-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
manual copy of BORUTA progress into master (because of mess with auth…
…or names in commits)
- Loading branch information
1 parent
c014e57
commit c14c555
Showing
7 changed files
with
407 additions
and
231 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ Imports: | |
cluster, | ||
ggplot2, | ||
tidyr, | ||
dplyr, | ||
methods, | ||
strucchange, | ||
sandwich, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,131 +1,255 @@ | ||
#' BORUTA algorithm for SEM trees | ||
#' | ||
#' This is an experimental feature. Use cautiously. | ||
#' | ||
#' @aliases boruta | ||
#' @param model A template model specification from \code{\link{OpenMx}} using | ||
#' the \code{\link{mxModel}} function (or a \code{\link[lavaan]{lavaan}} model | ||
#' using the \code{\link[lavaan]{lavaan}} function with option fit=FALSE). | ||
#' Model must be syntactically correct within the framework chosen, and | ||
#' converge to a solution. | ||
#' @param data Data.frame used in the model creation using | ||
#' \code{\link{mxModel}} or \code{\link[lavaan]{lavaan}} are input here. Order | ||
#' of modeled variables and predictors is not important when providing a | ||
#' dataset to \code{semtree}. | ||
#' @param control \code{\link{semtree}} model specifications from | ||
#' \code{\link{semtree.control}} are input here. Any changes from the default | ||
#' setting can be specified here. | ||
#' @param percentile_threshold Numeric. | ||
#' @param rounds Numeric. Number of rounds of the BORUTA algorithm. | ||
#' Run the Boruta algorithm on a sem tree | ||
#' | ||
#' Grows a series of SEM Forests following the boruta algorithm to determine | ||
#' feature importance as moderators of the underlying model. | ||
#' | ||
#' | ||
#' @aliases boruta plot.boruta print.boruta | ||
#' @param model A template SEM. Same as in \code{semtree}. | ||
#' @param data A dataframe to boruta on. Same as in \code{semtree}. | ||
#' @param control A semforest control object to set forest parameters. | ||
#' @param predictors An optional list of covariates. See semtree code example. | ||
#' @param maxRuns Maximum number of boruta search cycles | ||
#' @param pAdjMethod A value from \link{stats::p.adjust.methods} defining a | ||
#' multiple testing correction method | ||
#' @param alpha p-value cutoff for decisionmaking. Default .05 | ||
#' @param verbose Verbosity level for boruta processing | ||
#' similar to the same argument in \link{semtree.control} and | ||
#' \link{semforest.control} | ||
#' @param \dots Optional parameters to undefined subfunctions | ||
#' @return A vim object with several elements that need work. | ||
#' Of particular note, `$importance` carries mean importance; | ||
#' `$decision` denotes Accepted/Rejected/Tentative; | ||
#' `$impHistory` has the entire varimp history; and | ||
#' `$details` has exit values for each parameter. | ||
#' @author Priyanka Paul, Timothy R. Brick, Andreas Brandmaier | ||
#' @seealso \code{\link{semtree}} \code{\link{semforest}} | ||
#' | ||
#' @keywords tree models multivariate | ||
#' @export | ||
#' | ||
boruta <- function(model, | ||
data, | ||
control = NULL, | ||
predictors = NULL, | ||
percentile_threshold = 1, | ||
rounds = 1, | ||
maxRuns = 30, | ||
pAdjMethod = "none", | ||
alpha = .05, | ||
verbose = FALSE, | ||
quant = 1, | ||
...) { | ||
# detect model (warning: duplicated code) | ||
if (inherits(model, "MxModel") || inherits(model, "MxRAMModel")) { | ||
tmp <- | ||
getPredictorsOpenMx(mxmodel = model, | ||
dataset = data, | ||
covariates = predictors) | ||
model.ids <- tmp[[1]] | ||
covariate.ids <- tmp[[2]] | ||
} else if (inherits(model, "lavaan")) { | ||
tmp <- | ||
getPredictorsLavaan(model, dataset = data, covariates = predictors) | ||
model.ids <- tmp[[1]] | ||
covariate.ids <- tmp[[2]] | ||
} else { | ||
ui_stop("Unknown model type selected. Use OpenMx or lavaanified lavaan models!") | ||
|
||
} | ||
|
||
# initial checks | ||
stopifnot(percentile_threshold>=0) | ||
stopifnot(percentile_threshold<=1) | ||
stopifnot(is.numeric(rounds)) | ||
stopifnot(rounds>0) | ||
# Checks on x & y from the boruta package | ||
if (length(grep('^shadow', covariate.ids) > 0)) | ||
stop( | ||
'Attributes with names starting from "shadow" are reserved for internal use. Please rename them.' | ||
) | ||
if (maxRuns < 11) | ||
stop('maxRuns must be greater than 10.') | ||
if (!pAdjMethod %in% stats::p.adjust.methods) | ||
stop(c( | ||
'P-value adjustment method not found. Must be one of:', | ||
stats::p.adjust.methods | ||
)) | ||
|
||
preds_important <- c() | ||
preds_unimportant <- c() | ||
|
||
cur_round = 1 | ||
temp_vims <- list() | ||
|
||
while(cur_round <= rounds) { | ||
vim_boruta <- .boruta(model=model, | ||
data=data, | ||
control=control, | ||
predictors=predictors, | ||
percentile_threshold = percentile_threshold, | ||
...) | ||
browser() | ||
# add predictors to list of unimportant variables | ||
preds_unimportant <- c(preds_unimportant, names(vim_boruta$filter)[!vim_boruta$filter]) | ||
# remove them from the dataset | ||
data <- data[, -c(preds_unimportant)] | ||
temp_vims[[cur_round]] <-vim_boruta | ||
# Might clash with some other semtrees stuff | ||
if (is.null(predictors)) { | ||
predictors <- names(data)[covariate.ids] | ||
} | ||
|
||
result <- list( | ||
preds_unimportant, | ||
rounds = rounds | ||
) | ||
# Initialize and then loop over runs: | ||
impHistory <- | ||
data.frame(matrix(NA, nrow = 0, ncol = length(predictors) + 3)) | ||
names(impHistory) <- | ||
c(predictors, "shadowMin", "shadowMean", "shadowMax") | ||
decisionList <- | ||
data.frame( | ||
predictor = predictors, | ||
decision = "Tentative", | ||
hitCount = 0, | ||
raw.p = NA, | ||
adjusted.p = NA | ||
) | ||
|
||
return(result) | ||
} | ||
|
||
.boruta <- function(model, | ||
data, | ||
control = NULL, | ||
predictors = NULL, | ||
percentile_threshold = 1, | ||
num_shadows = 1, | ||
...) { | ||
|
||
# make sure that no column names start with "shadow_" prefix | ||
stopifnot(all(sapply(names(data), function(x) {!startsWith(x, "shadow_")}))) | ||
|
||
# detect model (warning: duplicated code) | ||
if (inherits(model, "MxModel") || inherits(model, "MxRAMModel")) { | ||
tmp <- getPredictorsOpenMx(mxmodel = model, dataset = data, covariates = predictors) | ||
|
||
} else if (inherits(model,"lavaan")){ | ||
# TODO: Parallelize the first five runs. | ||
end_time <- NULL | ||
for (runNo in 1:maxRuns) { | ||
start_time <- Sys.time() | ||
if (verbose) { | ||
time_info <- "" | ||
if (!is.null(end_time)) time_info <- paste0("Last run took", (end_time-start_time)) | ||
message(paste("Beginning Run", runNo," ",time_info)) | ||
} | ||
|
||
tmp <- getPredictorsLavaan(model, data, predictors) | ||
} else { | ||
ui_stop("Unknown model type selected. Use OpenMx or lavaanified lavaan models!") | ||
} | ||
model.ids <- tmp[[1]] | ||
covariate.ids <- tmp[[2]] | ||
|
||
# stage 1 - create shadow features | ||
|
||
shadow.ids <- (ncol(data) + 1):(ncol(data) + length(covariate.ids)) | ||
|
||
for (cur_cov_id in covariate.ids) { | ||
for (rep_id in 1:num_shadows) { | ||
# pick column and shuffle | ||
temp_column <- data[, cur_cov_id] | ||
temp_column <- sample(temp_column, length(temp_column), replace = FALSE) | ||
# add to dataset as shadow feature | ||
temp_colname <- paste0("shadow_", names(data)[cur_cov_id], collapse = "") | ||
if (num_shadows>1) temp_colname <- paste0(temp_colname, rep_id, collapse = "") | ||
data[temp_colname] <- temp_column | ||
if (!is.null(predictors)) predictors <- c(predictors, temp_colname) | ||
# stage 1 - create shadow features | ||
rejected <- | ||
decisionList$predictor[decisionList$decision == "Rejected"] | ||
current.predictors <- setdiff(predictors, rejected) | ||
current.covariate.ids <- | ||
setdiff(covariate.ids, names(data) %in% rejected) | ||
current.data <- data[, setdiff(names(data), rejected)] | ||
|
||
shadow.ids <- | ||
(ncol(current.data) + 1):(ncol(current.data) + length(current.covariate.ids)) | ||
|
||
for (cur_cov_id in current.covariate.ids) { | ||
# pick column and shuffle | ||
temp_column <- current.data[, cur_cov_id] | ||
temp_column <- | ||
sample(temp_column, length(temp_column), replace = FALSE) | ||
# add to dataset as shadow feature | ||
temp_colname <- | ||
paste0("shadow_", names(current.data)[cur_cov_id], collapse = "") | ||
current.data[temp_colname] <- temp_column | ||
if (!is.null(current.predictors)) | ||
current.predictors <- c(current.predictors, temp_colname) | ||
} | ||
|
||
# TODO: Pre-run model if needed. | ||
|
||
# run the forest | ||
forest <- | ||
semforest(model, current.data, control, current.predictors, ...) | ||
|
||
# run variable importance | ||
vim <- varimp(forest) | ||
|
||
# get variable importance from shadow features | ||
shadow_names <- names(current.data)[shadow.ids] | ||
agvim <- aggregateVarimp(vim, aggregate = "mean") | ||
|
||
# Compute shadow stats | ||
shadow_importances <- agvim[names(agvim) %in% shadow_names] | ||
impHistory[runNo, "shadowMax"] <- max(shadow_importances, na.rm=TRUE) | ||
|
||
max_shadow_importance <- stats::quantile(shadow_importances, | ||
probs=quant,na.rm=TRUE) | ||
|
||
impHistory[runNo, "shadowMin"] <- min(shadow_importances, na.rm=TRUE) | ||
impHistory[runNo, "shadowMean"] <- mean(shadow_importances, na.rm=TRUE) | ||
agvim_filtered <- agvim[!(names(agvim) %in% shadow_names)] | ||
impHistory[runNo, names(agvim_filtered)] <- agvim_filtered | ||
|
||
# Compute "hits" | ||
hits <- | ||
decisionList$predictor %in% names(agvim_filtered[agvim_filtered > max_shadow_importance]) | ||
decisionList$hitCount[hits] <- decisionList$hitCount[hits] + 1 | ||
|
||
# Run tests. | ||
# The biasing here means that there are no decisions without correction | ||
# before 5 runs and no decisions with Bonferroni before 7 runs. | ||
|
||
# Run confirmation tests (pulled from Boruta package) | ||
newPs <- | ||
stats::pbinom(decisionList$hitCount - 1, runNo, 0.5, lower.tail = FALSE) | ||
adjPs <- stats::p.adjust(newPs, method = pAdjMethod) | ||
acceptable <- adjPs < alpha | ||
updateList <- acceptable & decisionList$decision == "Tentative" | ||
decisionList$raw.p[updateList] <- newPs[updateList] | ||
decisionList$adjusted.p[updateList] <- adjPs[updateList] | ||
decisionList$decision[updateList] <- "Confirmed" | ||
|
||
# Run rejection tests (pulled from Boruta package) | ||
newPs <- | ||
stats::pbinom(decisionList$hitCount, runNo, 0.5, lower.tail = TRUE) | ||
adjPs <- stats::p.adjust(newPs, method = pAdjMethod) | ||
acceptable <- adjPs < alpha | ||
updateList <- acceptable & decisionList$decision == "Tentative" | ||
decisionList$raw.p[updateList] <- newPs[updateList] | ||
decisionList$adjusted.p[updateList] <- adjPs[updateList] | ||
decisionList$decision[updateList] <- "Rejected" | ||
|
||
if (!any(decisionList$decision == "Tentative")) { | ||
break | ||
} | ||
|
||
end_time <- Sys.time() | ||
|
||
} | ||
|
||
# run the forest | ||
forest <- semforest(model, data, control, predictors, ...) | ||
|
||
# run variable importance | ||
vim <- varimp(forest) | ||
|
||
# get variable importance from shadow features | ||
shadow_names <- names(data)[shadow.ids] | ||
agvim <- aggregateVarimp(vim, aggregate = "mean") | ||
|
||
vals <- agvim[names(agvim) %in% shadow_names] | ||
#max_shadow_importance <- max(vals) | ||
max_shadow_importance <- quantile(vals, percentile_threshold) | ||
vim$importance <- colMeans(impHistory, na.rm = TRUE) | ||
vim$impHistory <- impHistory | ||
vim$decisions <- decisionList$decision | ||
vim$details <- decisionList | ||
|
||
agvim_filtered <- agvim[!(names(agvim) %in% shadow_names)] | ||
|
||
df <- data.frame(importance = agvim_filtered, predictor = names(agvim_filtered)) | ||
|
||
vim$filter <- agvim_filtered > max_shadow_importance | ||
vim$filter <- | ||
decisionList$decision == "Confirmed" # Turns into hitreg | ||
vim$boruta <- TRUE | ||
vim$boruta_threshold <- max_shadow_importance | ||
vim$percentile_threshold <- percentile_threshold | ||
|
||
|
||
class(vim) <- "boruta" | ||
|
||
# TODO: Loop ends here with some reporting. | ||
|
||
return(vim) | ||
} | ||
|
||
|
||
#' @exportS3Method plot boruta | ||
plot.boruta = function(vim, type = 0, ...) { | ||
decisionList = vim$details | ||
impHistory = vim$impHistory | ||
impHistory <- impHistory |> | ||
dplyr::mutate(rnd = 1:nrow(impHistory)) |> | ||
tidyr::pivot_longer(cols = -last_col()) |> #everything()) |> | ||
dplyr::left_join(data.frame(decisionList), | ||
by = dplyr::join_by("name" == "predictor")) |> | ||
dplyr::mutate(decision = | ||
dplyr::case_when(is.na(decision) ~ "Shadow", .default = decision)) |> | ||
dplyr::group_by(name) |> | ||
dplyr::mutate(median_value = median(value, na.rm = TRUE)) | ||
|
||
if (type == 0) { | ||
|
||
# sort Inf values to the left | ||
#impHistory[is.na(impHistory$median_value)] <- -Inf | ||
impHistory$sort_value <- impHistory$median_value | ||
# mv <- min(impHistory$sort_value,na.rm=TRUE) | ||
# impHistory$sort_value[impHistory$decision=="Rejected"]<- (mv-1) | ||
|
||
ggplot2::ggplot(impHistory, | ||
ggplot2::aes( | ||
x = stats::reorder(name, sort_value), | ||
y = value, | ||
color = decision | ||
)) + | ||
ggplot2::geom_boxplot() + | ||
ggplot2::xlab("") + | ||
ggplot2::ylab("Importance") + | ||
ggplot2::scale_color_discrete(name = "Decision") + | ||
ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)) | ||
} else if (type == 1) { | ||
ggplot2::ggplot(impHistory, | ||
ggplot2::aes( | ||
x = rnd, | ||
y = value, | ||
group = name, | ||
col = name | ||
)) + | ||
ggplot2::geom_line() + | ||
ggplot2::geom_hline(ggplot2::aes(yintercept = median_value, col = name), lwd = | ||
2) + | ||
ggplot2::xlab("Round") + | ||
ggplot2::ylab("Importance") + | ||
ggplot2::scale_color_discrete(name = "Predictor") | ||
} else { | ||
stop("Unknown graph type. Please choose 0 or 1.") | ||
} | ||
|
||
} |
Oops, something went wrong.