diff --git a/R/checkControl.R b/R/checkControl.R index 4bcd8f1..09e5c21 100644 --- a/R/checkControl.R +++ b/R/checkControl.R @@ -11,6 +11,12 @@ checkControl <- function(control, fail = TRUE) { check.semtree.control <- function(control, fail = TRUE) { attr <- attributes(control)$names def.attr <- attributes(semtree.control())$names + + # add NULL-defaults + null_def <- c("min.N","min.bucket") + attr <- unique(c(attr, null_def)) + def.attr <- unique(c(def.attr, null_def)) + if ((length(intersect(attr, def.attr)) != length(attr))) { unknown <- setdiff(attr, def.attr) msg <- diff --git a/R/checkModel.R b/R/checkModel.R index bd2ded8..2c11479 100644 --- a/R/checkModel.R +++ b/R/checkModel.R @@ -24,6 +24,3 @@ checkModel <- function(model, control) return(TRUE); } - -#inherits(model1,"lavaan") -#model1@Fit@converged diff --git a/R/growTree.R b/R/growTree.R index e47c4a3..630426b 100644 --- a/R/growTree.R +++ b/R/growTree.R @@ -64,7 +64,40 @@ growTree <- function(model = NULL, mydata = NULL, ui_message("Subsampled predictors: ", paste(node$colnames[meta$covariate.ids])) } } + + # override forced split? + arguments <- list(...) + if ("forced_splits" %in% names(arguments) && !is.null(arguments$forced_splits)) { + forced_splits <- arguments$forced_splits + + # get names of model variables before forcing + model.names <- names(mydata)[meta$model.ids] + covariate.names <- names(mydata)[meta$covariate.ids] + + # select subset with model variables and single, forced predictor + forcedsplit.name <- forced_splits[1] + + if (control$verbose) { + cat("FORCED split: ",forcedsplit.name,"\n") + } + + mydata <- fulldata[, c(model.names, forcedsplit.name) ] + node$colnames <- colnames(mydata) + + # get new model ids after sampling by name + meta$model.ids <- sapply(model.names, function(x) { + which(x == names(mydata)) + }) + names(meta$model.ids) <- NULL + meta$covariate.ids <- unlist(lapply(covariate.names, function(x) { + which(x == names(mydata)) + })) + + } else { + forced_splits <- NULL + } + # determine whether split evaluation can be done on p values node$p.values.valid <- control$method != "cv" @@ -432,6 +465,31 @@ growTree <- function(model = NULL, mydata = NULL, mydata <- fulldata meta <- fullmeta } + + # restore mydata if forced split was true + # and (potentially) force continuation of splitting + if (!is.null(forced_splits)) { + + + # also need to remap col.max to original data! + if (!is.null(result$col.max) && !is.na(result$col.max)) { + col.max.name <- names(mydata)[result$col.max] + result$col.max <- which(names(fulldata) == col.max.name) + } else { + col.max.name <- NULL + } + + mydata <- fulldata + meta <- fullmeta + + # pop first element + forced_splits <- forced_splits[-1] + # set to NULL if no splits left + if (length(forced_splits)==0) forced_splits <- NULL + + # force continuation of splitting ? + cont.split <- TRUE + } if ((!is.null(cont.split)) && (!is.na(cont.split)) && (cont.split)) { if (control$report.level > 10) { @@ -563,8 +621,8 @@ growTree <- function(model = NULL, mydata = NULL, # recursively continue splitting # result1 - RHS; result2 - LHS - result2 <- growTree(node$model, sub2, control, invariance, meta, edgelabel = 0, depth = depth + 1, constraints) - result1 <- growTree(node$model, sub1, control, invariance, meta, edgelabel = 1, depth = depth + 1, constraints) + result2 <- growTree(node$model, sub2, control, invariance, meta, edgelabel = 0, depth = depth + 1, constraints, forced_splits = forced_splits) + result1 <- growTree(node$model, sub1, control, invariance, meta, edgelabel = 1, depth = depth + 1, constraints, forced_splits = forced_splits) # store results in recursive list structure node$left_child <- result2 diff --git a/tests/control.R b/tests/control.R index a247aa7..d7b4465 100644 --- a/tests/control.R +++ b/tests/control.R @@ -80,4 +80,15 @@ controlOptions <- semtree.control(method = "naive",max.depth = 0,min.N=NULL, tree <- semtree(model=lgcModel, data=lgcm, control = controlOptions) stopifnot(tree$control$min.N==50) -stopifnot(tree$control$min.bucket==25) \ No newline at end of file +stopifnot(tree$control$min.bucket==25) + + + +x<-semtree_control() +semtree:::check.semtree.control(x) + +x<-semtree_control(min.N=100) +semtree:::check.semtree.control(x) + +x<-semtree_control(min.N=100, min.bucket=10) +semtree:::check.semtree.control(x) diff --git a/tests/testthat/forced_splitl.R b/tests/testthat/forced_splitl.R new file mode 100644 index 0000000..b8cbdf7 --- /dev/null +++ b/tests/testthat/forced_splitl.R @@ -0,0 +1,28 @@ +library(lavaan) +library(semtree) +set.seed(1238) + +N <- 500 + +# simulate data +da <- data.frame(y = c(rnorm(N/2, mean = -1), rnorm(N/2, mean = 1)), + z = factor(rep(c(0,1),each=N/2)),k=rnorm(N),m=rnorm(N) ) + +m_lav <- ' +y ~~ y +y ~ 1 +' + +fit_lav <- lavaan(model = m_lav, data = da) + + +tree = semtree(model=fit_lav, data=da, + control = semtree_control(method="score"), + forced_splits=NULL) + + + +tree_forced_m = semtree(model=fit_lav, data=da, + control = semtree_control(method="score"), + forced_splits=c("m")) +