Skip to content

Commit

Permalink
export boruta
Browse files Browse the repository at this point in the history
set semtree control defaults to NULL
  • Loading branch information
brandmaier committed Apr 16, 2024
1 parent b479679 commit 8661dee
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 18 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ S3method(summary,semforest)
S3method(summary,semtree)
S3method(toLatex,semtree)
export(biodiversity)
export(boruta)
export(diversityMatrix)
export(evaluateTree)
export(fitSubmodels)
Expand Down Expand Up @@ -65,6 +66,7 @@ export(semforest_score_control)
export(semtree)
export(semtree.constraints)
export(semtree.control)
export(semtree_control)
export(strip)
export(subforest)
export(subtree)
Expand Down
91 changes: 84 additions & 7 deletions R/boruta.R
Original file line number Diff line number Diff line change
@@ -1,35 +1,107 @@
#' BORUTA algorithm for SEM trees
#'
#' This is an experimental feature. Use cautiously.
#'
#' @aliases boruta
#' @param model A template model specification from \code{\link{OpenMx}} using
#' the \code{\link{mxModel}} function (or a \code{\link[lavaan]{lavaan}} model
#' using the \code{\link[lavaan]{lavaan}} function with option fit=FALSE).
#' Model must be syntactically correct within the framework chosen, and
#' converge to a solution.
#' @param data Data.frame used in the model creation using
#' \code{\link{mxModel}} or \code{\link[lavaan]{lavaan}} are input here. Order
#' of modeled variables and predictors is not important when providing a
#' dataset to \code{semtree}.
#' @param control \code{\link{semtree}} model specifications from
#' \code{\link{semtree.control}} are input here. Any changes from the default
#' setting can be specified here.
#' @param percentile_threshold Numeric.
#' @param rounds Numeric. Number of rounds of the BORUTA algorithm.
#'
#' @export
#'
boruta <- function(model,
data,
control = NULL,
predictors = NULL,
percentile_threshold = 1,
rounds = 1,
...) {
# TODO: make sure that no column names start with "shadow_" prefix

# initial checks
stopifnot(percentile_threshold>=0)
stopifnot(percentile_threshold<=1)
stopifnot(is.numeric(rounds))
stopifnot(rounds>0)

preds_important <- c()
preds_unimportant <- c()

cur_round = 1
temp_vims <- list()

while(cur_round <= rounds) {
vim_boruta <- .boruta(model=model,
data=data,
control=control,
predictors=predictors,
percentile_threshold = percentile_threshold,
...)
browser()
# add predictors to list of unimportant variables
preds_unimportant <- c(preds_unimportant, names(vim_boruta$filter)[!vim_boruta$filter])
# remove them from the dataset
data <- data[, -c(preds_unimportant)]
temp_vims[[cur_round]] <-vim_boruta
}

result <- list(
preds_unimportant,
rounds = rounds
)

return(result)
}

.boruta <- function(model,
data,
control = NULL,
predictors = NULL,
percentile_threshold = 1,
num_shadows = 1,
...) {

# make sure that no column names start with "shadow_" prefix
stopifnot(all(sapply(names(data), function(x) {!startsWith(x, "shadow_")})))

# detect model (warning: duplicated code)
if (inherits(model, "MxModel") || inherits(model, "MxRAMModel")) {
tmp <- getPredictorsOpenMx(mxmodel = model, dataset = data, covariates = predictors)
model.ids <- tmp[[1]]
covariate.ids <- tmp[[2]]
# } else if (inherits(model,"lavaan")){

# } else if ((inherits(model,"ctsemFit")) || (inherits(model,"ctsemInit"))) {
#
} else if (inherits(model,"lavaan")){

tmp <- getPredictorsLavaan(model, data, predictors)
} else {
ui_stop("Unknown model type selected. Use OpenMx or lavaanified lavaan models!")
}
model.ids <- tmp[[1]]
covariate.ids <- tmp[[2]]

# stage 1 - create shadow features

shadow.ids <- (ncol(data) + 1):(ncol(data) + length(covariate.ids))

for (cur_cov_id in covariate.ids) {
for (rep_id in 1:num_shadows) {
# pick column and shuffle
temp_column <- data[, cur_cov_id]
temp_column <- sample(temp_column, length(temp_column), replace = FALSE)
# add to dataset as shadow feature
temp_colname <- paste0("shadow_", names(data)[cur_cov_id], collapse = "")
if (num_shadows>1) temp_colname <- paste0(temp_colname, rep_id, collapse = "")
data[temp_colname] <- temp_column
if (!is.null(predictors)) predictors <- c(predictors, temp_colname)
}
}

# run the forest
Expand All @@ -41,14 +113,19 @@ boruta <- function(model,
# get variable importance from shadow features
shadow_names <- names(data)[shadow.ids]
agvim <- aggregateVarimp(vim, aggregate = "mean")
max_shadow_importance <- max(agvim[names(agvim) %in% shadow_names])

vals <- agvim[names(agvim) %in% shadow_names]
#max_shadow_importance <- max(vals)
max_shadow_importance <- quantile(vals, percentile_threshold)

agvim_filtered <- agvim[!(names(agvim) %in% shadow_names)]

df <- data.frame(importance = agvim_filtered, predictor = names(agvim_filtered))

vim$filter <- agvim_filtered > max_shadow_importance
vim$boruta <- TRUE
vim$boruta_threshold <- max_shadow_importance
vim$percentile_threshold <- percentile_threshold

return(vim)
}
29 changes: 27 additions & 2 deletions R/semtree.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,41 @@ semtree <- function(model, data = NULL, control = NULL, constraints = NULL,
}
}

# here we decide between four cases depending
# on whether min.N is given and/or min.bucket is given
# this is a really dumb heuristic
# please can someone replace this with something more useful
# this based on (Bentler & Chou, 1987; see also Bollen, 1989)

if (is.null(control$min.N)) {
control$min.N <- 5 * npar(model)

if (is.null(control$min.bucket)) {
# both values were not specified
control$min.N <- max(20, 5 * npar(model))
control$min.bucket <- max(10, control$min.N / 2)
} else {
# only min.bucket was given, min.N was not specified
control$min.N <- control$min.bucket * 2
}
} else {
if (is.null(control$min.bucket)) {
# only min.N was given, min.bucket was not specified
control$min.bucket <- max(10, control$min.N / 2)
} else {
# do nothing, both values were specified
if (control$min.bucket > control$min.N) {
warning("Min.bucket parameter should probably be smaller than min.N!")
}
}
}

if (is.null(control$min.N)) {

}

# set min.bucket and min.N heuristically
if (is.null(control$min.bucket)) {
control$min.bucket <- control$min.N / 2

}

if (control$method == "cv") {
Expand Down
6 changes: 2 additions & 4 deletions R/semtree.control.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
#' @export
semtree.control <-
function(method = c("naive","score","fair","fair3"),
min.N = 20,
min.N = NULL,
max.depth = NA,
alpha = .05,
alpha.invariance = NA,
Expand All @@ -119,7 +119,7 @@ semtree.control <-
# ordinal = 'maxLMo', # and maxLM are available
# metric = 'maxLM'),
linear = TRUE,
min.bucket = 10,
min.bucket = NULL,
naive.bonferroni.type = 0,
missing = 'ignore',
use.maxlm = FALSE,
Expand Down Expand Up @@ -200,8 +200,6 @@ semtree.control <-
return(options)
}



#' @export
semtree_control <- function(...) {
semtree.control(...)
Expand Down
9 changes: 6 additions & 3 deletions R/varimp.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,14 @@ varimp <- function(forest,
colnames(result$importance.level1) <- var.names
}


if (dim(result$importance)[1] == 1) {
#result$importance<-t(result$importance)
result$importance<-t(result$importance)

# TODO: this is stupid, should be as.matrix?! or something else
result$ll.baselines <-
t(t(result$ll.baselines)) # TODO: this is stupid, should be as.matrix?!
}
t(t(result$ll.baselines))
}

colnames(result$importance) <- var.names
result$var.names <- var.names
Expand Down
39 changes: 39 additions & 0 deletions man/boruta.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/semtree.control.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test_boruta.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

# skip long running tests on CRAN
skip_on_cran()
testthat::skip_on_cran()


library("semtree")
Expand Down Expand Up @@ -75,7 +75,7 @@ model <- lgcModel
data <- lgcm
control <- semforest_score_control()

vim_boruta <- boruta(lgcModel, lgcm)
vim_boruta <- boruta(lgcModel, lgcm,percentile_threshold = 1)

print(vim_boruta)

Expand Down

0 comments on commit 8661dee

Please sign in to comment.