Skip to content

Commit

Permalink
fix for bug with dummy variables in growTree()
Browse files Browse the repository at this point in the history
  • Loading branch information
brandmaier committed Dec 6, 2023
1 parent cb9289f commit 7cc97e2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
20 changes: 15 additions & 5 deletions R/growTree.R
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,16 @@ growTree <- function(model=NULL, mydata=NULL,

# return values in result are:
# LL.max : numeric, log likelihood ratio of best split
# split.max : numeric, value to split best column on
# split.max : numeric, split point; value to split best column on
# (for metric variables)
# col.max : index of best column
# cov.name : name of best candidate
# name.max : name of best candidate
# type.max :
# btn.matrix : a matrix, which contains test statistics and
# more information for
# the various split points evaluated
# n.comp : the (implicit) number of multiple tests evaluated for
# determining the best split

# store the value of the selected test statistic
node$lr <- NA
Expand Down Expand Up @@ -435,18 +442,21 @@ growTree <- function(model=NULL, mydata=NULL,
}
test2 <- test2[,-1]

# if var.type==1, then split.max corresponds to the index of
# if level is categorical, then split.max corresponds to the index of
# the best column in the matrix that represents all subsets
# make sure that this is not casted to a string if there
# are predictors of other types (esp., factors)
# browser()
#browser()
if (!all(is.na(result$btn.matrix))) {
result$split.max <- as.integer(result$split.max)

#named <- colnames(result1$columns)[result$split.max]
# node$caption <- paste(colnames(result1$columns)[result$split.max])
best_subset_col_id = result$split.max
best_values = result1$expressions[ (best_subset_col_id-1)*3 +1]$value

} else {
best_values <- result$split.max
}
node$rule = list(variable=result$col.max, relation="%in%",
value=best_values,
name = result$name.max)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-dummy-split.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ set.seed(458)
n <- 500
var_unordered <- factor(sample(c("lightning","rain","sunshine","snow"),n,TRUE))
var_grp <- factor((var_unordered %in% c("rain","sunshine")))
x <- rnorm(n)+ifelse(var_grp,20,0)
x <- rnorm(n)+ifelse(var_grp==TRUE,20,0)

# data frame has only a dummy predictor
df <- data.frame(x=x, var_grp)
Expand Down

0 comments on commit 7cc97e2

Please sign in to comment.