diff --git a/DESCRIPTION b/DESCRIPTION index 2c97d4f..7ccbd7f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -26,7 +26,8 @@ Imports: future.apply, data.table, expm, - gridBase + gridBase, + dplyr Suggests: knitr, rmarkdown, diff --git a/R/boruta.R b/R/boruta.R index 0ba714b..ec87a56 100644 --- a/R/boruta.R +++ b/R/boruta.R @@ -27,7 +27,6 @@ #' #' @keywords tree models multivariate #' @export - boruta <- function(model, data, control = NULL, @@ -38,63 +37,21 @@ boruta <- function(model, verbose=FALSE, ...) { - # initial checks - stopifnot(percentile_threshold>=0) - stopifnot(percentile_threshold<=1) - stopifnot(is.numeric(rounds)) - stopifnot(rounds>0) - - preds_important <- c() - preds_unimportant <- c() - - cur_round = 1 - temp_vims <- list() - - while(cur_round <= rounds) { - vim_boruta <- .boruta(model=model, - data=data, - control=control, - predictors=predictors, - percentile_threshold = percentile_threshold, - ...) - browser() - # add predictors to list of unimportant variables - preds_unimportant <- c(preds_unimportant, names(vim_boruta$filter)[!vim_boruta$filter]) - # remove them from the dataset - data <- data[, -c(preds_unimportant)] - temp_vims[[cur_round]] <-vim_boruta - } - - result <- list( - preds_unimportant, - rounds = rounds - ) - return(result) -} - -.boruta <- function(model, - data, - control = NULL, - predictors = NULL, - percentile_threshold = 1, - num_shadows = 1, - ...) { - - # make sure that no column names start with "shadow_" prefix - stopifnot(all(sapply(names(data), function(x) {!startsWith(x, "shadow_")}))) - # detect model (warning: duplicated code) - if (inherits(model, "MxModel") || inherits(model, "MxRAMModel")) { - tmp <- getPredictorsOpenMx(mxmodel = model, dataset = data, covariates = predictors) - - } else if (inherits(model,"lavaan")){ + if (inherits(model,"MxModel") || inherits(model,"MxRAMModel")) { + tmp <- getPredictorsOpenMx(mxmodel=model, dataset=data, covariates=predictors) + model.ids <- tmp[[1]] + covariate.ids <- tmp[[2]] + # } else if (inherits(model,"lavaan")){ - tmp <- getPredictorsLavaan(model, data, predictors) + # } else if ((inherits(model,"ctsemFit")) || (inherits(model,"ctsemInit"))) { + # } else { - ui_stop("Unknown model type selected. Use OpenMx or lavaanified lavaan models!") + ui_stop("Unknown model type selected. Use OpenMx or lavaanified lavaan models!"); } - + + # initial checks # Checks on x & y from the boruta package if(length(grep('^shadow',covariate.ids)>0)) stop('Attributes with names starting from "shadow" are reserved for internal use. Please rename them.') @@ -114,7 +71,7 @@ boruta <- function(model, names(impHistory) <- c(predictors, "shadowMin", "shadowMean", "shadowMax") decisionList <- data.frame(predictor=predictors, decision = "Tentative", hitCount = 0, raw.p=NA, adjusted.p=NA) - + # TODO: Parallelize the first five runs. for(runNo in 1:maxRuns) { if(verbose) { @@ -143,10 +100,10 @@ boruta <- function(model, # run the forest forest <- semforest(model, current.data, control, current.predictors, ...) - + # run variable importance vim <- varimp(forest) - + # get variable importance from shadow features shadow_names <- names(current.data)[shadow.ids] agvim <- aggregateVarimp(vim, aggregate="mean") @@ -195,12 +152,43 @@ boruta <- function(model, vim$impHistory <- impHistory vim$decisions <- decisionList$decision vim$details <- decisionList - + vim$filter <- decisionList$decision == "Confirmed" # Turns into hitreg vim$boruta <- TRUE + class(vim) <- "boruta" + # TODO: Loop ends here with some reporting. return(vim) } +#' @exportS3Method plot boruta +plot.boruta = function(vim, type=0, ...) { + decisionList = vim$details + impHistory = vim$impHistory + impHistory <- impHistory |> + dplyr::mutate(rnd=1:nrow(impHistory)) |> + tidyr::pivot_longer(cols = -last_col()) |> #everything()) |> + dplyr::left_join(data.frame(decisionList), by=join_by("name"=="predictor")) |> + dplyr::mutate(decision = case_when(is.na(decision)~"Shadow", .default=decision)) |> + dplyr::group_by(name) |> mutate(median_value = median(value,na.rm=TRUE)) + + if (type==0) { + ggplot2::ggplot(impHistory, + aes(x=stats::reorder(name, median_value), + y=value, color=decision)) + + ggplot2::geom_boxplot()+ + ggplot2::xlab("")+ + ggplot2::ylab("Importance")+ + scale_color_discrete(name = "Predictor")+ + theme(axis.text.x = element_text(angle = 45, hjust = 1)) + } else if (type==1) { + ggplot2::ggplot(impHistory, + aes(x=rnd, y=value,group=name,col=name))+geom_line()+ geom_hline(aes(yintercept=median_value,col=name),lwd=2)+ + xlab("Round")+ylab("Importance")+scale_color_discrete(name = "Predictor") + } else { + stop("Unknown graph type. Please choose 0 or 1.") + } + +} \ No newline at end of file diff --git a/R/toTable.R b/R/toTable.R index 337c235..2be2bf7 100644 --- a/R/toTable.R +++ b/R/toTable.R @@ -66,7 +66,6 @@ alls <- unique(alls) # create table -#covariate.names <-simplify2array(tree$result$btn.matrix[2,]) covariate.names <- getCovariatesFromTree(tree) # default is to display all parameters @@ -79,10 +78,12 @@ all.names <- c(covariate.names, added.param.cols) str.matrix <- matrix(NA, nrow = length(rowdata),ncol=length(all.names)) +# convert to a data frame to avoid coercion to string +str.matrix <- data.frame(str.matrix) + colnames(str.matrix) <- all.names for (i in 1:length(rowdata)) { -# result.row <- rep(" ",length(covariate.names)) myrow <- rowdata[[i]][[1]] for (j in 1:length(myrow)) { myitem <- myrow[[j]] @@ -126,7 +127,6 @@ for (i in 1:length(rowdata)) { } } -# result.string <- paste(result.string,paste(result.row,collapse="\t"),"\n") } ## prune empty columns? @@ -139,7 +139,7 @@ if (length(is.col.empty)>0) { sortby <- apply(str.matrix,2,function(x){sum(!is.na(x))}) if (!is.null(added.param.cols)) { remids <- (dim(str.matrix)[2]-length(added.param.cols)+1):(dim(str.matrix)[2]) - sortby[remids] <- sortby[remids]-999999 + sortby[remids] <- sortby[remids]-999999 # bad style ^^ #TODO } sort.ix <- sort(sortby,index.return=TRUE,decreasing = TRUE)$ix str.matrix <- str.matrix[, sort.ix] @@ -156,7 +156,6 @@ str.matrix[is.na(str.matrix)]<-"" ## and display #cat(result.string) -#require("openxlsx") return(str.matrix) diff --git a/man/boruta.Rd b/man/boruta.Rd index 4921d25..6a8cc8a 100644 --- a/man/boruta.Rd +++ b/man/boruta.Rd @@ -6,7 +6,17 @@ \alias{print.boruta} \title{Run the Boruta algorithm on a sem tree} \usage{ -boruta(model, data, control = NULL, predictors = NULL, ...) +boruta( + model, + data, + control = NULL, + predictors = NULL, + maxRuns = 30, + pAdjMethod = "none", + alpha = 0.05, + verbose = FALSE, + ... +) } \arguments{ \item{model}{A template SEM. Same as in \code{semtree}.} @@ -17,12 +27,25 @@ boruta(model, data, control = NULL, predictors = NULL, ...) \item{predictors}{An optional list of covariates. See semtree code example.} -\item{\dots}{Optional parameters.} +\item{maxRuns}{Maximum number of boruta search cycles} -\item{constraints}{An optional list of covariates. See semtree code example.} +\item{pAdjMethod}{A value from \link{stats::p.adjust.methods} defining a +multiple testing correction method} + +\item{alpha}{p-value cutoff for decisionmaking. Default .05} + +\item{verbose}{Verbosity level for boruta processing +similar to the same argument in \link{semtree.control} and +\link{semforest.control}} + +\item{\dots}{Optional parameters to undefined subfunctions} } \value{ -A boruta object. +A vim object with several elements that need work. + Of particular note, `$importance` carries mean importance; + `$decision` denotes Accepted/Rejected/Tentative; + `$impHistory` has the entire varimp history; and + `$details` has exit values for each parameter. } \description{ Grows a series of SEM Forests following the boruta algorithm to determine diff --git a/man/semtree.control.Rd b/man/semtree.control.Rd index 7ce8847..df9d886 100644 --- a/man/semtree.control.Rd +++ b/man/semtree.control.Rd @@ -8,7 +8,7 @@ \usage{ semtree.control( method = c("naive", "score", "fair", "fair3"), - min.N = 20, + min.N = NULL, max.depth = NA, alpha = 0.05, alpha.invariance = NA, @@ -24,7 +24,7 @@ semtree.control( report.level = 0, exclude.code = NA, linear = TRUE, - min.bucket = 10, + min.bucket = NULL, naive.bonferroni.type = 0, missing = "ignore", use.maxlm = FALSE,