From 7cc97e2bbc1a05602d553d2df132365f3ad2c4dd Mon Sep 17 00:00:00 2001 From: Andreas Brandmaier Date: Wed, 6 Dec 2023 13:52:02 +0100 Subject: [PATCH] fix for bug with dummy variables in growTree() --- R/growTree.R | 20 +++++++++++++++----- tests/testthat/test-dummy-split.R | 2 +- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/R/growTree.R b/R/growTree.R index 7097a27..50e2e78 100644 --- a/R/growTree.R +++ b/R/growTree.R @@ -304,9 +304,16 @@ growTree <- function(model=NULL, mydata=NULL, # return values in result are: # LL.max : numeric, log likelihood ratio of best split - # split.max : numeric, value to split best column on + # split.max : numeric, split point; value to split best column on + # (for metric variables) # col.max : index of best column - # cov.name : name of best candidate + # name.max : name of best candidate + # type.max : + # btn.matrix : a matrix, which contains test statistics and + # more information for + # the various split points evaluated + # n.comp : the (implicit) number of multiple tests evaluated for + # determining the best split # store the value of the selected test statistic node$lr <- NA @@ -435,18 +442,21 @@ growTree <- function(model=NULL, mydata=NULL, } test2 <- test2[,-1] - # if var.type==1, then split.max corresponds to the index of + # if level is categorical, then split.max corresponds to the index of # the best column in the matrix that represents all subsets # make sure that this is not casted to a string if there # are predictors of other types (esp., factors) - # browser() + #browser() + if (!all(is.na(result$btn.matrix))) { result$split.max <- as.integer(result$split.max) #named <- colnames(result1$columns)[result$split.max] # node$caption <- paste(colnames(result1$columns)[result$split.max]) best_subset_col_id = result$split.max best_values = result1$expressions[ (best_subset_col_id-1)*3 +1]$value - + } else { + best_values <- result$split.max + } node$rule = list(variable=result$col.max, relation="%in%", value=best_values, name = result$name.max) diff --git a/tests/testthat/test-dummy-split.R b/tests/testthat/test-dummy-split.R index 48f0af1..3893b86 100644 --- a/tests/testthat/test-dummy-split.R +++ b/tests/testthat/test-dummy-split.R @@ -5,7 +5,7 @@ set.seed(458) n <- 500 var_unordered <- factor(sample(c("lightning","rain","sunshine","snow"),n,TRUE)) var_grp <- factor((var_unordered %in% c("rain","sunshine"))) -x <- rnorm(n)+ifelse(var_grp,20,0) +x <- rnorm(n)+ifelse(var_grp==TRUE,20,0) # data frame has only a dummy predictor df <- data.frame(x=x, var_grp)