Commit

manual copy of BORUTA progress into master (because of mess with author names in commits)
brandmaier committed Sep 30, 2024
1 parent c014e57 commit c14c555
Showing 7 changed files with 407 additions and 231 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
@@ -17,6 +17,7 @@ Imports:
cluster,
ggplot2,
tidyr,
dplyr,
methods,
strucchange,
sandwich,
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -5,9 +5,11 @@ S3method("[",semtree)
S3method(evaluate,semforest)
S3method(logLik,semtree)
S3method(merge,semforest)
S3method(merge,semforest.varimp)
S3method(nobs,semtree)
S3method(partialDependence,semforest)
S3method(partialDependence,semforest_stripped)
S3method(plot,boruta)
S3method(plot,diversityMatrix)
S3method(plot,partialDependence)
S3method(plot,semforest.proximity)
344 changes: 234 additions & 110 deletions R/boruta.R
@@ -1,131 +1,255 @@
#' BORUTA algorithm for SEM trees
#'
#' This is an experimental feature. Use cautiously.
#'
#' @aliases boruta
#' @param model A template model specification from \code{\link{OpenMx}} using
#' the \code{\link{mxModel}} function (or a \code{\link[lavaan]{lavaan}} model
#' using the \code{\link[lavaan]{lavaan}} function with option fit=FALSE).
#' Model must be syntactically correct within the framework chosen, and
#' converge to a solution.
#' @param data The data.frame used in the model creation with
#' \code{\link{mxModel}} or \code{\link[lavaan]{lavaan}}. The order
#' of modeled variables and predictors is not important when providing a
#' dataset to \code{semtree}.
#' @param control \code{\link{semtree}} model specifications from
#' \code{\link{semtree.control}} are input here. Any changes from the default
#' setting can be specified here.
#' @param percentile_threshold Numeric.
#' @param rounds Numeric. Number of rounds of the BORUTA algorithm.
#' Run the Boruta algorithm on a SEM tree
#'
#' Grows a series of SEM forests following the Boruta algorithm to determine
#' feature importance as moderators of the underlying model.
#'
#'
#' @aliases boruta plot.boruta print.boruta
#' @param model A template SEM. Same as in \code{semtree}.
#' @param data A data.frame to run Boruta on. Same as in \code{semtree}.
#' @param control A semforest control object to set forest parameters.
#' @param predictors An optional list of covariates. See semtree code example.
#' @param maxRuns Maximum number of Boruta search cycles.
#' @param pAdjMethod A value from \code{\link[stats]{p.adjust.methods}} defining a
#' multiple testing correction method.
#' @param alpha p-value cutoff for decision making. Defaults to .05.
#' @param verbose Verbosity level for Boruta processing,
#' similar to the same argument in \code{\link{semtree.control}} and
#' \code{\link{semforest.control}}.
#' @param \dots Optional parameters passed to subfunctions.
#' @return A vim object with several elements.
#' Of particular note, `$importance` carries the mean importance;
#' `$decisions` denotes Confirmed/Rejected/Tentative;
#' `$impHistory` has the entire varimp history; and
#' `$details` has exit values for each predictor.
#' @author Priyanka Paul, Timothy R. Brick, Andreas Brandmaier
#' @seealso \code{\link{semtree}} \code{\link{semforest}}
#'
#' @keywords tree models multivariate
#' @export
#'
boruta <- function(model,
data,
control = NULL,
predictors = NULL,
percentile_threshold = 1,
rounds = 1,
maxRuns = 30,
pAdjMethod = "none",
alpha = .05,
verbose = FALSE,
quant = 1,
...) {
# detect model (warning: duplicated code)
if (inherits(model, "MxModel") || inherits(model, "MxRAMModel")) {
tmp <-
getPredictorsOpenMx(mxmodel = model,
dataset = data,
covariates = predictors)
model.ids <- tmp[[1]]
covariate.ids <- tmp[[2]]
} else if (inherits(model, "lavaan")) {
tmp <-
getPredictorsLavaan(model, dataset = data, covariates = predictors)
model.ids <- tmp[[1]]
covariate.ids <- tmp[[2]]
} else {
ui_stop("Unknown model type selected. Use OpenMx or lavaanified lavaan models!")

}

# initial checks
stopifnot(percentile_threshold>=0)
stopifnot(percentile_threshold<=1)
stopifnot(is.numeric(rounds))
stopifnot(rounds>0)
# Checks on x & y from the boruta package
  if (length(grep('^shadow', covariate.ids)) > 0)
stop(
'Attributes with names starting from "shadow" are reserved for internal use. Please rename them.'
)
if (maxRuns < 11)
stop('maxRuns must be greater than 10.')
if (!pAdjMethod %in% stats::p.adjust.methods)
stop(c(
'P-value adjustment method not found. Must be one of:',
stats::p.adjust.methods
))

preds_important <- c()
preds_unimportant <- c()

cur_round = 1
temp_vims <- list()

while(cur_round <= rounds) {
vim_boruta <- .boruta(model=model,
data=data,
control=control,
predictors=predictors,
percentile_threshold = percentile_threshold,
...)
    # browser()  # debugging hook, disabled
# add predictors to list of unimportant variables
preds_unimportant <- c(preds_unimportant, names(vim_boruta$filter)[!vim_boruta$filter])
# remove them from the dataset
    data <- data[, !(names(data) %in% preds_unimportant), drop = FALSE]
temp_vims[[cur_round]] <-vim_boruta
# Might clash with some other semtrees stuff
if (is.null(predictors)) {
predictors <- names(data)[covariate.ids]
}

result <- list(
preds_unimportant,
rounds = rounds
)
# Initialize and then loop over runs:
impHistory <-
data.frame(matrix(NA, nrow = 0, ncol = length(predictors) + 3))
names(impHistory) <-
c(predictors, "shadowMin", "shadowMean", "shadowMax")
decisionList <-
data.frame(
predictor = predictors,
decision = "Tentative",
hitCount = 0,
raw.p = NA,
adjusted.p = NA
)
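
  # impHistory collects one row of importances per run (real predictors plus
  # the shadowMin/shadowMean/shadowMax summaries); decisionList tracks the
  # running Confirmed/Rejected/Tentative status and hit count per predictor.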

return(result)
}
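
# A minimal usage sketch, assuming the interface documented in the roxygen
# block above; `my_model` and `my_data` are hypothetical placeholders:
#
#   vim <- boruta(model = my_model, data = my_data,
#                 control = semforest.control(), verbose = TRUE)
#   vim$decisions   # Confirmed / Rejected / Tentative per predictor
#   vim$impHistory  # importance of each predictor and shadow feature per run
#   plot(vim)       # importance boxplots, colored by decision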

.boruta <- function(model,
data,
control = NULL,
predictors = NULL,
percentile_threshold = 1,
num_shadows = 1,
...) {

# make sure that no column names start with "shadow_" prefix
stopifnot(all(sapply(names(data), function(x) {!startsWith(x, "shadow_")})))

# detect model (warning: duplicated code)
if (inherits(model, "MxModel") || inherits(model, "MxRAMModel")) {
tmp <- getPredictorsOpenMx(mxmodel = model, dataset = data, covariates = predictors)

} else if (inherits(model,"lavaan")){
# TODO: Parallelize the first five runs.
end_time <- NULL
for (runNo in 1:maxRuns) {
    if (verbose) {
      time_info <- ""
      if (!is.null(end_time)) time_info <- paste0("Last run took ", format(end_time - start_time))
      message(paste("Beginning Run", runNo, " ", time_info))
    }
    start_time <- Sys.time()

tmp <- getPredictorsLavaan(model, data, predictors)
} else {
ui_stop("Unknown model type selected. Use OpenMx or lavaanified lavaan models!")
}
model.ids <- tmp[[1]]
covariate.ids <- tmp[[2]]

# stage 1 - create shadow features

shadow.ids <- (ncol(data) + 1):(ncol(data) + length(covariate.ids))

for (cur_cov_id in covariate.ids) {
for (rep_id in 1:num_shadows) {
# pick column and shuffle
temp_column <- data[, cur_cov_id]
temp_column <- sample(temp_column, length(temp_column), replace = FALSE)
# add to dataset as shadow feature
temp_colname <- paste0("shadow_", names(data)[cur_cov_id], collapse = "")
if (num_shadows>1) temp_colname <- paste0(temp_colname, rep_id, collapse = "")
data[temp_colname] <- temp_column
if (!is.null(predictors)) predictors <- c(predictors, temp_colname)
# stage 1 - create shadow features
rejected <-
decisionList$predictor[decisionList$decision == "Rejected"]
current.predictors <- setdiff(predictors, rejected)
  current.data <- data[, setdiff(names(data), rejected)]
  current.covariate.ids <-
    which(names(current.data) %in% names(data)[covariate.ids])

shadow.ids <-
(ncol(current.data) + 1):(ncol(current.data) + length(current.covariate.ids))
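
  # Shadow features are permuted copies of the remaining covariates: shuffling
  # breaks any real relation to the model, so their importances estimate what
  # an uninformative predictor scores by chance.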

for (cur_cov_id in current.covariate.ids) {
# pick column and shuffle
temp_column <- current.data[, cur_cov_id]
temp_column <-
sample(temp_column, length(temp_column), replace = FALSE)
# add to dataset as shadow feature
temp_colname <-
paste0("shadow_", names(current.data)[cur_cov_id], collapse = "")
current.data[temp_colname] <- temp_column
if (!is.null(current.predictors))
current.predictors <- c(current.predictors, temp_colname)
}

# TODO: Pre-run model if needed.

# run the forest
forest <-
semforest(model, current.data, control, current.predictors, ...)

# run variable importance
vim <- varimp(forest)

# get variable importance from shadow features
shadow_names <- names(current.data)[shadow.ids]
agvim <- aggregateVarimp(vim, aggregate = "mean")

# Compute shadow stats
shadow_importances <- agvim[names(agvim) %in% shadow_names]
impHistory[runNo, "shadowMax"] <- max(shadow_importances, na.rm=TRUE)

max_shadow_importance <- stats::quantile(shadow_importances,
probs=quant,na.rm=TRUE)
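
  # With the default quant = 1, stats::quantile() returns the largest shadow
  # importance, i.e. the classic Boruta cutoff; smaller values of quant give a
  # more lenient threshold for counting hits.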

impHistory[runNo, "shadowMin"] <- min(shadow_importances, na.rm=TRUE)
impHistory[runNo, "shadowMean"] <- mean(shadow_importances, na.rm=TRUE)
agvim_filtered <- agvim[!(names(agvim) %in% shadow_names)]
impHistory[runNo, names(agvim_filtered)] <- agvim_filtered

# Compute "hits"
hits <-
decisionList$predictor %in% names(agvim_filtered[agvim_filtered > max_shadow_importance])
decisionList$hitCount[hits] <- decisionList$hitCount[hits] + 1

# Run tests.
# The biasing here means that there are no decisions without correction
# before 5 runs and no decisions with Bonferroni before 7 runs.
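  # (Illustrative check of that claim for the uncorrected case: the smallest
  # attainable p-value after runNo runs is stats::pbinom(runNo - 1, runNo, 0.5,
  # lower.tail = FALSE) = 0.5^runNo, which first drops below alpha = .05 at
  # runNo = 5, where it equals 0.03125.)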

# Run confirmation tests (pulled from Boruta package)
newPs <-
stats::pbinom(decisionList$hitCount - 1, runNo, 0.5, lower.tail = FALSE)
adjPs <- stats::p.adjust(newPs, method = pAdjMethod)
acceptable <- adjPs < alpha
updateList <- acceptable & decisionList$decision == "Tentative"
decisionList$raw.p[updateList] <- newPs[updateList]
decisionList$adjusted.p[updateList] <- adjPs[updateList]
decisionList$decision[updateList] <- "Confirmed"

# Run rejection tests (pulled from Boruta package)
newPs <-
stats::pbinom(decisionList$hitCount, runNo, 0.5, lower.tail = TRUE)
adjPs <- stats::p.adjust(newPs, method = pAdjMethod)
acceptable <- adjPs < alpha
updateList <- acceptable & decisionList$decision == "Tentative"
decisionList$raw.p[updateList] <- newPs[updateList]
decisionList$adjusted.p[updateList] <- adjPs[updateList]
decisionList$decision[updateList] <- "Rejected"

if (!any(decisionList$decision == "Tentative")) {
break
}

end_time <- Sys.time()

}

# run the forest
forest <- semforest(model, data, control, predictors, ...)

# run variable importance
vim <- varimp(forest)

# get variable importance from shadow features
shadow_names <- names(data)[shadow.ids]
agvim <- aggregateVarimp(vim, aggregate = "mean")

vals <- agvim[names(agvim) %in% shadow_names]
#max_shadow_importance <- max(vals)
max_shadow_importance <- quantile(vals, percentile_threshold)
vim$importance <- colMeans(impHistory, na.rm = TRUE)
vim$impHistory <- impHistory
vim$decisions <- decisionList$decision
vim$details <- decisionList

agvim_filtered <- agvim[!(names(agvim) %in% shadow_names)]

df <- data.frame(importance = agvim_filtered, predictor = names(agvim_filtered))

vim$filter <- agvim_filtered > max_shadow_importance
vim$filter <-
decisionList$decision == "Confirmed" # Turns into hitreg
vim$boruta <- TRUE
vim$boruta_threshold <- max_shadow_importance
vim$percentile_threshold <- percentile_threshold


class(vim) <- "boruta"

# TODO: Loop ends here with some reporting.

return(vim)
}


#' @exportS3Method plot boruta
plot.boruta = function(vim, type = 0, ...) {
decisionList = vim$details
impHistory = vim$impHistory
impHistory <- impHistory |>
dplyr::mutate(rnd = 1:nrow(impHistory)) |>
tidyr::pivot_longer(cols = -last_col()) |> #everything()) |>
dplyr::left_join(data.frame(decisionList),
by = dplyr::join_by("name" == "predictor")) |>
dplyr::mutate(decision =
dplyr::case_when(is.na(decision) ~ "Shadow", .default = decision)) |>
dplyr::group_by(name) |>
dplyr::mutate(median_value = median(value, na.rm = TRUE))

if (type == 0) {

# sort Inf values to the left
#impHistory[is.na(impHistory$median_value)] <- -Inf
impHistory$sort_value <- impHistory$median_value
# mv <- min(impHistory$sort_value,na.rm=TRUE)
# impHistory$sort_value[impHistory$decision=="Rejected"]<- (mv-1)

ggplot2::ggplot(impHistory,
ggplot2::aes(
x = stats::reorder(name, sort_value),
y = value,
color = decision
)) +
ggplot2::geom_boxplot() +
ggplot2::xlab("") +
ggplot2::ylab("Importance") +
ggplot2::scale_color_discrete(name = "Decision") +
ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1))
} else if (type == 1) {
ggplot2::ggplot(impHistory,
ggplot2::aes(
x = rnd,
y = value,
group = name,
col = name
)) +
ggplot2::geom_line() +
      ggplot2::geom_hline(ggplot2::aes(yintercept = median_value, col = name), lwd = 2) +
ggplot2::xlab("Round") +
ggplot2::ylab("Importance") +
ggplot2::scale_color_discrete(name = "Predictor")
} else {
stop("Unknown graph type. Please choose 0 or 1.")
}

}
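
# A usage sketch for the two plot types (assuming `vim` is a boruta() result):
#
#   plot(vim)            # type = 0: importance boxplots per predictor,
#                        #   sorted by median and colored by decision
#   plot(vim, type = 1)  # type = 1: importance trajectories across runs,
#                        #   with a horizontal line at each predictor's median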
