Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/brandmaier/semtree
Browse files Browse the repository at this point in the history
  • Loading branch information
brandmaier committed Jul 4, 2024
2 parents 02d6cc4 + cdca892 commit 3c4997e
Show file tree
Hide file tree
Showing 94 changed files with 708 additions and 237 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/R-CMD-check-windows.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
push:
branches: [main, master]
pull_request:
branches: [main, master]

name: R-CMD-check-Win

jobs:
R-CMD-check:
runs-on: windows-latest
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
R_KEEP_PKG_SOURCE: yes
steps:
- uses: actions/checkout@v3

- uses: r-lib/actions/setup-r@v2
with:
use-public-rspm: true

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
needs: check

- uses: r-lib/actions/check-r-package@v2
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ S3method(summary,semforest)
S3method(summary,semtree)
S3method(toLatex,semtree)
export(biodiversity)
export(boruta)
export(diversityMatrix)
export(evaluateTree)
export(fitSubmodels)
Expand Down Expand Up @@ -65,6 +66,7 @@ export(semforest_score_control)
export(semtree)
export(semtree.constraints)
export(semtree.control)
export(semtree_control)
export(strip)
export(subforest)
export(subtree)
Expand Down
91 changes: 84 additions & 7 deletions R/boruta.R
Original file line number Diff line number Diff line change
@@ -1,35 +1,107 @@
#' BORUTA algorithm for SEM trees
#'
#' This is an experimental feature. Use cautiously.
#'
#' @aliases boruta
#' @param model A template model specification from \code{\link{OpenMx}} using
#' the \code{\link{mxModel}} function (or a \code{\link[lavaan]{lavaan}} model
#' using the \code{\link[lavaan]{lavaan}} function with option fit=FALSE).
#' Model must be syntactically correct within the framework chosen, and
#' converge to a solution.
#' @param data Data.frame used in the model creation using
#' \code{\link{mxModel}} or \code{\link[lavaan]{lavaan}} are input here. Order
#' of modeled variables and predictors is not important when providing a
#' dataset to \code{semtree}.
#' @param control \code{\link{semtree}} model specifications from
#' \code{\link{semtree.control}} are input here. Any changes from the default
#' setting can be specified here.
#' @param percentile_threshold Numeric.
#' @param rounds Numeric. Number of rounds of the BORUTA algorithm.
#'
#' @export
#'
boruta <- function(model,
data,
control = NULL,
predictors = NULL,
percentile_threshold = 1,
rounds = 1,
...) {
# TODO: make sure that no column names start with "shadow_" prefix

# initial checks
stopifnot(percentile_threshold>=0)
stopifnot(percentile_threshold<=1)
stopifnot(is.numeric(rounds))
stopifnot(rounds>0)

preds_important <- c()
preds_unimportant <- c()

cur_round = 1
temp_vims <- list()

while(cur_round <= rounds) {
vim_boruta <- .boruta(model=model,
data=data,
control=control,
predictors=predictors,
percentile_threshold = percentile_threshold,
...)
browser()
# add predictors to list of unimportant variables
preds_unimportant <- c(preds_unimportant, names(vim_boruta$filter)[!vim_boruta$filter])
# remove them from the dataset
data <- data[, -c(preds_unimportant)]
temp_vims[[cur_round]] <-vim_boruta
}

result <- list(
preds_unimportant,
rounds = rounds
)

return(result)
}

.boruta <- function(model,
data,
control = NULL,
predictors = NULL,
percentile_threshold = 1,
num_shadows = 1,
...) {

# make sure that no column names start with "shadow_" prefix
stopifnot(all(sapply(names(data), function(x) {!startsWith(x, "shadow_")})))

# detect model (warning: duplicated code)
if (inherits(model, "MxModel") || inherits(model, "MxRAMModel")) {
tmp <- getPredictorsOpenMx(mxmodel = model, dataset = data, covariates = predictors)
model.ids <- tmp[[1]]
covariate.ids <- tmp[[2]]
# } else if (inherits(model,"lavaan")){

# } else if ((inherits(model,"ctsemFit")) || (inherits(model,"ctsemInit"))) {
#
} else if (inherits(model,"lavaan")){

tmp <- getPredictorsLavaan(model, data, predictors)
} else {
ui_stop("Unknown model type selected. Use OpenMx or lavaanified lavaan models!")
}
model.ids <- tmp[[1]]
covariate.ids <- tmp[[2]]

# stage 1 - create shadow features

shadow.ids <- (ncol(data) + 1):(ncol(data) + length(covariate.ids))

for (cur_cov_id in covariate.ids) {
for (rep_id in 1:num_shadows) {
# pick column and shuffle
temp_column <- data[, cur_cov_id]
temp_column <- sample(temp_column, length(temp_column), replace = FALSE)
# add to dataset as shadow feature
temp_colname <- paste0("shadow_", names(data)[cur_cov_id], collapse = "")
if (num_shadows>1) temp_colname <- paste0(temp_colname, rep_id, collapse = "")
data[temp_colname] <- temp_column
if (!is.null(predictors)) predictors <- c(predictors, temp_colname)
}
}

# run the forest
Expand All @@ -41,14 +113,19 @@ boruta <- function(model,
# get variable importance from shadow features
shadow_names <- names(data)[shadow.ids]
agvim <- aggregateVarimp(vim, aggregate = "mean")
max_shadow_importance <- max(agvim[names(agvim) %in% shadow_names])

vals <- agvim[names(agvim) %in% shadow_names]
#max_shadow_importance <- max(vals)
max_shadow_importance <- quantile(vals, percentile_threshold)

agvim_filtered <- agvim[!(names(agvim) %in% shadow_names)]

df <- data.frame(importance = agvim_filtered, predictor = names(agvim_filtered))

vim$filter <- agvim_filtered > max_shadow_importance
vim$boruta <- TRUE
vim$boruta_threshold <- max_shadow_importance
vim$percentile_threshold <- percentile_threshold

return(vim)
}
6 changes: 6 additions & 0 deletions R/checkControl.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ checkControl <- function(control, fail = TRUE) {
check.semtree.control <- function(control, fail = TRUE) {
attr <- attributes(control)$names
def.attr <- attributes(semtree.control())$names

# add NULL-defaults
null_def <- c("min.N","min.bucket","strucchange.to")
attr <- unique(c(attr, null_def))
def.attr <- unique(c(def.attr, null_def))

if ((length(intersect(attr, def.attr)) != length(attr))) {
unknown <- setdiff(attr, def.attr)
msg <-
Expand Down
3 changes: 0 additions & 3 deletions R/checkModel.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,3 @@ checkModel <- function(model, control)

return(TRUE);
}

#inherits(model1,"lavaan")
#model1@Fit@converged
62 changes: 60 additions & 2 deletions R/growTree.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,40 @@ growTree <- function(model = NULL, mydata = NULL,
ui_message("Subsampled predictors: ", paste(node$colnames[meta$covariate.ids]))
}
}

# override forced split?
arguments <- list(...)
if ("forced_splits" %in% names(arguments) && !is.null(arguments$forced_splits)) {
forced_splits <- arguments$forced_splits

# get names of model variables before forcing
model.names <- names(mydata)[meta$model.ids]
covariate.names <- names(mydata)[meta$covariate.ids]

# select subset with model variables and single, forced predictor
forcedsplit.name <- forced_splits[1]

if (control$verbose) {
cat("FORCED split: ",forcedsplit.name,"\n")
}


mydata <- fulldata[, c(model.names, forcedsplit.name) ]
node$colnames <- colnames(mydata)

# get new model ids after sampling by name
meta$model.ids <- sapply(model.names, function(x) {
which(x == names(mydata))
})
names(meta$model.ids) <- NULL
meta$covariate.ids <- unlist(lapply(covariate.names, function(x) {
which(x == names(mydata))
}))

} else {
forced_splits <- NULL
}

# determine whether split evaluation can be done on p values
node$p.values.valid <- control$method != "cv"

Expand Down Expand Up @@ -432,6 +465,31 @@ growTree <- function(model = NULL, mydata = NULL,
mydata <- fulldata
meta <- fullmeta
}

# restore mydata if forced split was true
# and (potentially) force continuation of splitting
if (!is.null(forced_splits)) {


# also need to remap col.max to original data!
if (!is.null(result$col.max) && !is.na(result$col.max)) {
col.max.name <- names(mydata)[result$col.max]
result$col.max <- which(names(fulldata) == col.max.name)
} else {
col.max.name <- NULL
}

mydata <- fulldata
meta <- fullmeta

# pop first element
forced_splits <- forced_splits[-1]
# set to NULL if no splits left
if (length(forced_splits)==0) forced_splits <- NULL

# force continuation of splitting ?
cont.split <- TRUE
}

if ((!is.null(cont.split)) && (!is.na(cont.split)) && (cont.split)) {
if (control$report.level > 10) {
Expand Down Expand Up @@ -563,8 +621,8 @@ growTree <- function(model = NULL, mydata = NULL,

# recursively continue splitting
# result1 - RHS; result2 - LHS
result2 <- growTree(node$model, sub2, control, invariance, meta, edgelabel = 0, depth = depth + 1, constraints)
result1 <- growTree(node$model, sub1, control, invariance, meta, edgelabel = 1, depth = depth + 1, constraints)
result2 <- growTree(node$model, sub2, control, invariance, meta, edgelabel = 0, depth = depth + 1, constraints, forced_splits = forced_splits)
result1 <- growTree(node$model, sub1, control, invariance, meta, edgelabel = 1, depth = depth + 1, constraints, forced_splits = forced_splits)

# store results in recursive list structure
node$left_child <- result2
Expand Down
29 changes: 27 additions & 2 deletions R/semtree.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,41 @@ semtree <- function(model, data = NULL, control = NULL, constraints = NULL,
}
}

# here we decide between four cases depending
# on whether min.N is given and/or min.bucket is given
# this is a really dumb heuristic
# please can someone replace this with something more useful
# this based on (Bentler & Chou, 1987; see also Bollen, 1989)

if (is.null(control$min.N)) {
control$min.N <- 5 * npar(model)

if (is.null(control$min.bucket)) {
# both values were not specified
control$min.N <- max(20, 5 * npar(model))
control$min.bucket <- max(10, control$min.N / 2)
} else {
# only min.bucket was given, min.N was not specified
control$min.N <- control$min.bucket * 2
}
} else {
if (is.null(control$min.bucket)) {
# only min.N was given, min.bucket was not specified
control$min.bucket <- max(10, control$min.N / 2)
} else {
# do nothing, both values were specified
if (control$min.bucket > control$min.N) {
warning("Min.bucket parameter should probably be smaller than min.N!")
}
}
}

if (is.null(control$min.N)) {

}

# set min.bucket and min.N heuristically
if (is.null(control$min.bucket)) {
control$min.bucket <- control$min.N / 2

}

if (control$method == "cv") {
Expand Down
6 changes: 2 additions & 4 deletions R/semtree.control.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
#' @export
semtree.control <-
function(method = c("naive","score","fair","fair3"),
min.N = 20,
min.N = NULL,
max.depth = NA,
alpha = .05,
alpha.invariance = NA,
Expand All @@ -119,7 +119,7 @@ semtree.control <-
# ordinal = 'maxLMo', # and maxLM are available
# metric = 'maxLM'),
linear = TRUE,
min.bucket = 10,
min.bucket = NULL,
naive.bonferroni.type = 0,
missing = 'ignore',
use.maxlm = FALSE,
Expand Down Expand Up @@ -200,8 +200,6 @@ semtree.control <-
return(options)
}



#' @export
semtree_control <- function(...) {
semtree.control(...)
Expand Down
9 changes: 6 additions & 3 deletions R/varimp.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,14 @@ varimp <- function(forest,
colnames(result$importance.level1) <- var.names
}


if (dim(result$importance)[1] == 1) {
#result$importance<-t(result$importance)
result$importance<-t(result$importance)

# TODO: this is stupid, should be as.matrix?! or something else
result$ll.baselines <-
t(t(result$ll.baselines)) # TODO: this is stupid, should be as.matrix?!
}
t(t(result$ll.baselines))
}

colnames(result$importance) <- var.names
result$var.names <- var.names
Expand Down
2 changes: 1 addition & 1 deletion docs/404.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion docs/CONTRIBUTE.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 3c4997e

Please sign in to comment.