Skip to content

Commit

Permalink
adding forced_split functionality proposed by CvL
Browse files Browse the repository at this point in the history
  • Loading branch information
brandmaier committed Apr 17, 2024
1 parent 45f3cce commit cffd547
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 2 deletions.
62 changes: 60 additions & 2 deletions R/growTree.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,40 @@ growTree <- function(model = NULL, mydata = NULL,
ui_message("Subsampled predictors: ", paste(node$colnames[meta$covariate.ids]))
}
}

# override forced split?
arguments <- list(...)
if ("forced_splits" %in% names(arguments) && !is.null(arguments$forced_splits)) {
forced_splits <- arguments$forced_splits

# get names of model variables before forcing
model.names <- names(mydata)[meta$model.ids]
covariate.names <- names(mydata)[meta$covariate.ids]

# select subset with model variables and single, forced predictor
forcedsplit.name <- forced_splits[1]

if (control$verbose) {
cat("FORCED split: ",forcedsplit.name,"\n")
}


mydata <- fulldata[, c(model.names, forcedsplit.name) ]
node$colnames <- colnames(mydata)

# get new model ids after sampling by name
meta$model.ids <- sapply(model.names, function(x) {
which(x == names(mydata))
})
names(meta$model.ids) <- NULL
meta$covariate.ids <- unlist(lapply(covariate.names, function(x) {
which(x == names(mydata))
}))

} else {
forced_splits <- NULL
}

# determine whether split evaluation can be done on p values
node$p.values.valid <- control$method != "cv"

Expand Down Expand Up @@ -432,6 +465,31 @@ growTree <- function(model = NULL, mydata = NULL,
mydata <- fulldata
meta <- fullmeta
}

# restore mydata if forced split was true
# and (potentially) force continuation of splitting
if (!is.null(forced_splits)) {


# also need to remap col.max to original data!
if (!is.null(result$col.max) && !is.na(result$col.max)) {
col.max.name <- names(mydata)[result$col.max]
result$col.max <- which(names(fulldata) == col.max.name)
} else {
col.max.name <- NULL
}

mydata <- fulldata
meta <- fullmeta

# pop first element
forced_splits <- forced_splits[-1]
# set to NULL if no splits left
if (length(forced_splits)==0) forced_splits <- NULL

# force continuation of splitting ?
cont.split <- TRUE
}

if ((!is.null(cont.split)) && (!is.na(cont.split)) && (cont.split)) {
if (control$report.level > 10) {
Expand Down Expand Up @@ -563,8 +621,8 @@ growTree <- function(model = NULL, mydata = NULL,

# recursively continue splitting
# result1 - RHS; result2 - LHS
result2 <- growTree(node$model, sub2, control, invariance, meta, edgelabel = 0, depth = depth + 1, constraints)
result1 <- growTree(node$model, sub1, control, invariance, meta, edgelabel = 1, depth = depth + 1, constraints)
result2 <- growTree(node$model, sub2, control, invariance, meta, edgelabel = 0, depth = depth + 1, constraints, forced_splits = forced_splits)
result1 <- growTree(node$model, sub1, control, invariance, meta, edgelabel = 1, depth = depth + 1, constraints, forced_splits = forced_splits)

# store results in recursive list structure
node$left_child <- result2
Expand Down
28 changes: 28 additions & 0 deletions tests/testthat/forced_splitl.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
library(lavaan)
library(semtree)
set.seed(1238)

N <- 500

# simulate data
da <- data.frame(y = c(rnorm(N/2, mean = -1), rnorm(N/2, mean = 1)),
z = factor(rep(c(0,1),each=N/2)),k=rnorm(N),m=rnorm(N) )

m_lav <- '
y ~~ y
y ~ 1
'

fit_lav <- lavaan(model = m_lav, data = da)


tree = semtree(model=fit_lav, data=da,
control = semtree_control(method="score"),
forced_splits=NULL)



tree_forced_m = semtree(model=fit_lav, data=da,
control = semtree_control(method="score"),
forced_splits=c("m"))

0 comments on commit cffd547

Please sign in to comment.