Skip to content

Commit

Permalink
Add rf_hypopt()
Browse files Browse the repository at this point in the history
  • Loading branch information
kantonopoulos committed Jul 12, 2024
1 parent 70330d2 commit 07a7c2c
Showing 1 changed file with 93 additions and 9 deletions.
102 changes: 93 additions & 9 deletions R/classification_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ vis_hypopt <- function(tune_res,
#' @param exclude_cols (vector). Columns to exclude from the model. Default is NULL.
#' @param disease (character). Disease to predict.
#' @param type (character). Type of regularization. Default is "lasso". Other options are "ridge" and "elnet".
#' @param metric (function). Metric to optimize. Default is roc_auc.
#' @param cv_sets (numeric). Number of cross-validation sets. Default is 5.
#' @param grid_size (numeric). Size of the grid for hyperparameter optimization. Default is 10.
#' @param ncores (numeric). Number of cores to use for parallel processing. Default is 4.
Expand All @@ -184,7 +183,6 @@ elnet_hypopt <- function(train_data,
test_data,
disease,
type = "lasso",
metric = roc_auc,
cv_sets = 5,
grid_size = 10,
ncores = 4,
Expand Down Expand Up @@ -263,14 +261,98 @@ elnet_hypopt <- function(train_data,
hypopt_plot <- vis_hypopt(elnet_tune, "penalty", NULL, disease)
}
return(list("elnet_tune" = elnet_tune,
"wf" = elnet_wf,
"elnet_wf" = elnet_wf,
"train_set" = train_set,
"test_set" = test_set,
"hyperopt_vis" = hypopt_plot))
}

return(list("elnet_tune" = elnet_tune,
"wf" = elnet_wf,
"elnet_wf" = elnet_wf,
"train_set" = train_set,
"test_set" = test_set))
}


rf_hypopt <- function(train_data,
test_data,
disease,
cv_sets = 5,
grid_size = 10,
ncores = 4,
hypopt_vis = TRUE,
exclude_cols = NULL,
seed = 123
) {

if (ncores > 1) {
doParallel::registerDoParallel(cores = ncores)
}

# Prepare train data and create cross-validation sets with binary classifier
train_set <- train_data[[disease]] |>
dplyr::mutate(Disease = ifelse(Disease == disease, 1, 0)) |>
dplyr::mutate(Disease = as.factor(Disease)) |>
dplyr::select(-dplyr::any_of(exclude_cols)) |>
dplyr::mutate(dplyr::across(where(is.character), as.factor))

train_folds <- rsample::vfold_cv(train_set, v = cv_sets, strata = Disease)

test_set <- test_data[[disease]] |>
dplyr::mutate(Disease = ifelse(Disease == disease, 1, 0)) |>
dplyr::mutate(Disease = as.factor(Disease)) |>
dplyr::select(-dplyr::any_of(exclude_cols))

rf_rec <- recipes::recipe(Disease ~ ., data = train_set) |>
recipes::update_role(DAid, new_role = "id") |>
recipes::step_normalize(recipes::all_numeric()) |>
recipes::step_nzv(recipes::all_numeric()) |>
recipes::step_corr(recipes::all_numeric()) |>
recipes::step_impute_knn(recipes::all_numeric())

rf_spec <- parsnip::rand_forest(
trees = 1000,
mtry = tune::tune(),
min_n = tune::tune()
) |>
parsnip::set_mode("classification") |>
parsnip::set_engine("ranger", importance = "permutation")

disease_pred <- train_set |> dplyr::select(-dplyr::any_of(c("Disease", "DAid", "Sex", "Age", "BMI")))

rf_wf <- workflows::workflow() |>
workflows::add_model(rf_spec) |>
workflows::add_recipe(rf_rec)

rf_grid <- rf_wf |>
workflows::extract_parameter_set_dials() |>
dials::finalize(disease_pred) |>
dials::grid_latin_hypercube(size = grid_size)

roc_res <- yardstick::metric_set(yardstick::roc_auc)

set.seed(seed)
ctrl <- tune::control_grid(save_pred = TRUE, parallel_over = "everything")
rf_tune <- rf_wf |>
tune::tune_grid(
train_folds,
grid = rf_grid,
control = ctrl,
metrics = roc_res
)

if (hypopt_vis) {
hypopt_plot <- vis_hypopt(rf_tune, "min_n", "mtry", disease)

return(list("rf_tune" = rf_tune,
"rf_wf" = rf_wf,
"train_set" = train_set,
"test_set" = test_set,
"hyperopt_vis" = hypopt_plot))
}

return(list("rf_tune" = rf_tune,
"rf_wf" = rf_wf,
"train_set" = train_set,
"test_set" = test_set))
}
Expand All @@ -290,10 +372,10 @@ elnet_hypopt <- function(train_data,
#' - best_elnet (tibble). Best hyperparameters from hyperparameter optimization.
#' - final_wf (workflow). Final workflow object.
#' @keywords internal
elnet_finalfit <- function(train_set,
tune_res,
wf,
seed = 123) {
finalfit <- function(train_set,
tune_res,
wf,
seed = 123) {

best_elnet <- tune_res |>
tune::select_best(metric = "roc_auc") |>
Expand Down Expand Up @@ -639,7 +721,7 @@ do_elnet <- function(olink_data,

finalfit_res <- elnet_finalfit(hypopt_res$train_set,
hypopt_res$elnet_tune,
hypopt_res$wf,
hypopt_res$elnet_wf,
seed)

testfit_res <- elnet_testfit(hypopt_res$train_set,
Expand Down Expand Up @@ -684,3 +766,5 @@ do_elnet <- function(olink_data,

return(elnet_results)
}


0 comments on commit 07a7c2c

Please sign in to comment.