Skip to content

Commit

Permalink
added quantiles
Browse files Browse the repository at this point in the history
  • Loading branch information
brandmaier committed Sep 18, 2024
1 parent 222cb97 commit b9d6ba4
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions R/boruta.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,13 @@ boruta <- function(model,
)

# TODO: Parallelize the first five runs.
end_time <- NULL
for (runNo in 1:maxRuns) {
start_time <- Sys.time()
if (verbose) {
message(paste("Beginning Run", runNo))
time_info <- ""
if (!is.null(end_time)) time_info <- paste0("Last run took", (end_time-start_time))
message(paste("Beginning Run", runNo," ",time_info))
}

# stage 1 - create shadow features
Expand Down Expand Up @@ -133,9 +137,9 @@ boruta <- function(model,
# Compute shadow stats
shadow_importances <- agvim[names(agvim) %in% shadow_names]
impHistory[runNo, "shadowMax"] <-
max_shadow_importance <- max(shadow_importances)
impHistory[runNo, "shadowMin"] <- min(shadow_importances)
impHistory[runNo, "shadowMean"] <- mean(shadow_importances)
max_shadow_importance <- max(shadow_importances, na.rm=TRUE)
impHistory[runNo, "shadowMin"] <- min(shadow_importances, na.rm=TRUE)
impHistory[runNo, "shadowMean"] <- mean(shadow_importances, na.rm=TRUE)
agvim_filtered <- agvim[!(names(agvim) %in% shadow_names)]
impHistory[runNo, names(agvim_filtered)] <- agvim_filtered

Expand Down Expand Up @@ -172,6 +176,8 @@ boruta <- function(model,
break
}

end_time <- Sys.time()

}

vim$importance <- colMeans(impHistory, na.rm = TRUE)
Expand Down Expand Up @@ -201,22 +207,29 @@ plot.boruta = function(vim, type = 0, ...) {
dplyr::left_join(data.frame(decisionList),
by = dplyr::join_by("name" == "predictor")) |>
dplyr::mutate(decision =
case_when(is.na(decision) ~ "Shadow", .default = decision)) |>
dplyr::case_when(is.na(decision) ~ "Shadow", .default = decision)) |>
dplyr::group_by(name) |>
dplyr::mutate(median_value = median(value, na.rm = TRUE))

if (type == 0) {

# sort Inf values to the left
#impHistory[is.na(impHistory$median_value)] <- -Inf
impHistory$sort_value <- impHistory$median_value
# mv <- min(impHistory$sort_value,na.rm=TRUE)
# impHistory$sort_value[impHistory$decision=="Rejected"]<- (mv-1)

ggplot2::ggplot(impHistory,
ggplot2::aes(
x = stats::reorder(name, median_value),
x = stats::reorder(name, sort_value),
y = value,
color = decision
)) +
ggplot2::geom_boxplot() +
ggplot2::xlab("") +
ggplot2::ylab("Importance") +
ggplot2::scale_color_discrete(name = "Decision") +
ggplot2::theme(axis.text.x = element_text(angle = 45, hjust = 1))
ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1))
} else if (type == 1) {
ggplot2::ggplot(impHistory,
ggplot2::aes(
Expand All @@ -226,7 +239,7 @@ plot.boruta = function(vim, type = 0, ...) {
col = name
)) +
ggplot2::geom_line() +
ggplot2::geom_hline(aes(yintercept = median_value, col = name), lwd =
ggplot2::geom_hline(ggplot2::aes(yintercept = median_value, col = name), lwd =
2) +
ggplot2::xlab("Round") +
ggplot2::ylab("Importance") +
Expand Down

0 comments on commit b9d6ba4

Please sign in to comment.