Skip to content

Commit

Permalink
improved plotting functionality for BORUTA
Browse files Browse the repository at this point in the history
fixed coercion problem in tables
  • Loading branch information
LeonieHagitte committed Sep 5, 2024
1 parent e196b4b commit 731bebb
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 69 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ Imports:
future.apply,
data.table,
expm,
gridBase
gridBase,
dplyr
Suggests:
knitr,
rmarkdown,
Expand Down
102 changes: 45 additions & 57 deletions R/boruta.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#'
#' @keywords tree models multivariate
#' @export

boruta <- function(model,
data,
control = NULL,
Expand All @@ -38,63 +37,21 @@ boruta <- function(model,
verbose=FALSE,
...) {

# initial checks
stopifnot(percentile_threshold>=0)
stopifnot(percentile_threshold<=1)
stopifnot(is.numeric(rounds))
stopifnot(rounds>0)

preds_important <- c()
preds_unimportant <- c()

cur_round = 1
temp_vims <- list()

while(cur_round <= rounds) {
vim_boruta <- .boruta(model=model,
data=data,
control=control,
predictors=predictors,
percentile_threshold = percentile_threshold,
...)
browser()
# add predictors to list of unimportant variables
preds_unimportant <- c(preds_unimportant, names(vim_boruta$filter)[!vim_boruta$filter])
# remove them from the dataset
data <- data[, -c(preds_unimportant)]
temp_vims[[cur_round]] <-vim_boruta
}

result <- list(
preds_unimportant,
rounds = rounds
)

return(result)
}

.boruta <- function(model,
data,
control = NULL,
predictors = NULL,
percentile_threshold = 1,
num_shadows = 1,
...) {

# make sure that no column names start with "shadow_" prefix
stopifnot(all(sapply(names(data), function(x) {!startsWith(x, "shadow_")})))

# detect model (warning: duplicated code)
if (inherits(model, "MxModel") || inherits(model, "MxRAMModel")) {
tmp <- getPredictorsOpenMx(mxmodel = model, dataset = data, covariates = predictors)

} else if (inherits(model,"lavaan")){
if (inherits(model,"MxModel") || inherits(model,"MxRAMModel")) {
tmp <- getPredictorsOpenMx(mxmodel=model, dataset=data, covariates=predictors)
model.ids <- tmp[[1]]
covariate.ids <- tmp[[2]]
# } else if (inherits(model,"lavaan")){

tmp <- getPredictorsLavaan(model, data, predictors)
# } else if ((inherits(model,"ctsemFit")) || (inherits(model,"ctsemInit"))) {
#
} else {
ui_stop("Unknown model type selected. Use OpenMx or lavaanified lavaan models!")
ui_stop("Unknown model type selected. Use OpenMx or lavaanified lavaan models!");
}


# initial checks
# Checks on x & y from the boruta package
if(length(grep('^shadow',covariate.ids)>0))
stop('Attributes with names starting from "shadow" are reserved for internal use. Please rename them.')
Expand All @@ -114,7 +71,7 @@ boruta <- function(model,
names(impHistory) <- c(predictors, "shadowMin", "shadowMean", "shadowMax")
decisionList <- data.frame(predictor=predictors, decision = "Tentative",
hitCount = 0, raw.p=NA, adjusted.p=NA)

# TODO: Parallelize the first five runs.
for(runNo in 1:maxRuns) {
if(verbose) {
Expand Down Expand Up @@ -143,10 +100,10 @@ boruta <- function(model,

# run the forest
forest <- semforest(model, current.data, control, current.predictors, ...)

# run variable importance
vim <- varimp(forest)

# get variable importance from shadow features
shadow_names <- names(current.data)[shadow.ids]
agvim <- aggregateVarimp(vim, aggregate="mean")
Expand Down Expand Up @@ -195,12 +152,43 @@ boruta <- function(model,
vim$impHistory <- impHistory
vim$decisions <- decisionList$decision
vim$details <- decisionList

vim$filter <- decisionList$decision == "Confirmed" # Turns into hitreg
vim$boruta <- TRUE

class(vim) <- "boruta"

# TODO: Loop ends here with some reporting.

return(vim)
}

#' @exportS3Method plot boruta
plot.boruta = function(vim, type=0, ...) {
decisionList = vim$details
impHistory = vim$impHistory
impHistory <- impHistory |>
dplyr::mutate(rnd=1:nrow(impHistory)) |>
tidyr::pivot_longer(cols = -last_col()) |> #everything()) |>
dplyr::left_join(data.frame(decisionList), by=join_by("name"=="predictor")) |>
dplyr::mutate(decision = case_when(is.na(decision)~"Shadow", .default=decision)) |>
dplyr::group_by(name) |> mutate(median_value = median(value,na.rm=TRUE))

if (type==0) {
ggplot2::ggplot(impHistory,
aes(x=stats::reorder(name, median_value),
y=value, color=decision)) +
ggplot2::geom_boxplot()+
ggplot2::xlab("")+
ggplot2::ylab("Importance")+
scale_color_discrete(name = "Predictor")+
theme(axis.text.x = element_text(angle = 45, hjust = 1))
} else if (type==1) {
ggplot2::ggplot(impHistory,
aes(x=rnd, y=value,group=name,col=name))+geom_line()+ geom_hline(aes(yintercept=median_value,col=name),lwd=2)+
xlab("Round")+ylab("Importance")+scale_color_discrete(name = "Predictor")
} else {
stop("Unknown graph type. Please choose 0 or 1.")
}

}
9 changes: 4 additions & 5 deletions R/toTable.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ alls <- unique(alls)


# create table
#covariate.names <-simplify2array(tree$result$btn.matrix[2,])
covariate.names <- getCovariatesFromTree(tree)

# default is to display all parameters
Expand All @@ -79,10 +78,12 @@ all.names <- c(covariate.names, added.param.cols)

str.matrix <- matrix(NA, nrow = length(rowdata),ncol=length(all.names))

# convert to a data frame to avoid coercion to string
str.matrix <- data.frame(str.matrix)

colnames(str.matrix) <- all.names

for (i in 1:length(rowdata)) {
# result.row <- rep(" ",length(covariate.names))
myrow <- rowdata[[i]][[1]]
for (j in 1:length(myrow)) {
myitem <- myrow[[j]]
Expand Down Expand Up @@ -126,7 +127,6 @@ for (i in 1:length(rowdata)) {
}
}

# result.string <- paste(result.string,paste(result.row,collapse="\t"),"\n")
}

## prune empty columns?
Expand All @@ -139,7 +139,7 @@ if (length(is.col.empty)>0) {
sortby <- apply(str.matrix,2,function(x){sum(!is.na(x))})
if (!is.null(added.param.cols)) {
remids <- (dim(str.matrix)[2]-length(added.param.cols)+1):(dim(str.matrix)[2])
sortby[remids] <- sortby[remids]-999999
sortby[remids] <- sortby[remids]-999999 # bad style ^^ #TODO
}
sort.ix <- sort(sortby,index.return=TRUE,decreasing = TRUE)$ix
str.matrix <- str.matrix[, sort.ix]
Expand All @@ -156,7 +156,6 @@ str.matrix[is.na(str.matrix)]<-""
## and display
#cat(result.string)

#require("openxlsx")

return(str.matrix)

Expand Down
31 changes: 27 additions & 4 deletions man/boruta.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/semtree.control.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 731bebb

Please sign in to comment.