Skip to content

Commit

Permalink
removed old effect mod functions
Browse files Browse the repository at this point in the history
  • Loading branch information
blind-contours committed Apr 6, 2024
1 parent 3134d16 commit 5054cc7
Show file tree
Hide file tree
Showing 5 changed files with 0 additions and 313 deletions.
5 changes: 0 additions & 5 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
export(InterXshift)
export(calc_CIs)
export(calc_basis_freq)
export(calc_final_effect_mod_param)
export(calc_final_ind_shift_param)
export(calc_final_joint_shift_param)
export(calc_intxn_results)
Expand All @@ -24,7 +23,6 @@ export(indiv_stoch_shift_est_Q)
export(indiv_stoch_shift_est_g_exp)
export(joint_stoch_shift_est_Q)
export(joint_stoch_shift_est_g_exp)
export(list_rules_party)
export(quad_integrand_q_g_r)
export(scale_to_original)
export(scale_to_unit)
Expand All @@ -47,11 +45,9 @@ importFrom(data.table,setnames)
importFrom(dplyr,bind_rows)
importFrom(dplyr,filter)
importFrom(dplyr,group_by)
importFrom(dplyr,mutate)
importFrom(dplyr,top_n)
importFrom(foreach,"%dopar%")
importFrom(magrittr,"%>%")
importFrom(partykit,glmtree)
importFrom(purrr,is_empty)
importFrom(purrr,map)
importFrom(rlang,":=")
Expand All @@ -66,7 +62,6 @@ importFrom(stats,cov)
importFrom(stats,fitted)
importFrom(stats,glm)
importFrom(stats,lm)
importFrom(stats,median)
importFrom(stats,model.matrix)
importFrom(stats,p.adjust)
importFrom(stats,plogis)
Expand Down
163 changes: 0 additions & 163 deletions R/final_result_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,169 +107,6 @@ calc_final_ind_shift_param <- function(tmle_fit, exposure, fold_k) {
return(results)
}

#' Calculates the Effect Modification Shift Parameter
#'
#' @description This function uses a decision tree estimator to regress the
#' difference in expected outcomes of an exposure (shifted minus unshifted)
#' onto a covariate. Both the exposure and the covariate have been determined
#' through data-adaptive methods. The tree is fit on the training fold and the
#' resulting rules are applied to the validation fold. If no rules are found,
#' a median split is applied.
#'
#' @param tmle_fit_av TMLE results for the validation fold. The elements
#' \code{qn_shift_star}, \code{qn_noshift_star}, \code{eif} and
#' \code{noshift_eif} are read.
#' @param tmle_fit_at TMLE results for the training fold. The elements
#' \code{qn_shift_star} and \code{qn_noshift_star} are read.
#' @param exposure The identified exposure variable as a \code{character} string.
#' @param at Training dataset.
#' @param av Validation dataset.
#' @param effect_m_name The name of the effect modifier variable as a \code{character} string.
#' @param fold_k The fold in which the effect modification was found.
#' @return A \code{data.table} with two rows per retained rule (the estimate
#' inside and outside the region defined by the rule), de-duplicated on
#' \code{Psi}; \code{NULL} when no rule splits the validation sample into two
#' non-empty groups.
#' @importFrom partykit glmtree
#' @importFrom stats median
#' @importFrom dplyr mutate
#' @export
#'

calc_final_effect_mod_param <- function(tmle_fit_av,
                                        tmle_fit_at,
                                        exposure,
                                        at,
                                        av,
                                        effect_m_name,
                                        fold_k) {
  # Pseudo-outcome: shifted minus unshifted outcome prediction, per fold
  pseudo_outcome_at <- tmle_fit_at$qn_shift_star - tmle_fit_at$qn_noshift_star
  pseudo_outcome_av <- tmle_fit_av$qn_shift_star - tmle_fit_av$qn_noshift_star

  # Rename columns
  names(pseudo_outcome_at) <- "pseudo_outcome_at"
  names(pseudo_outcome_av) <- "pseudo_outcome_av"

  # Prepare data for the tree model (training) and for applying the rules
  # (validation)
  em_model_data_at <- cbind.data.frame(
    pseudo_outcome_at,
    subset(at, select = effect_m_name)
  )
  em_model_data_av <- cbind.data.frame(
    pseudo_outcome_av,
    subset(av, select = effect_m_name)
  )

  # Fit the tree on the training fold and extract one rule per terminal node
  tree_formula <- stats::as.formula(
    paste("pseudo_outcome_at", "~", effect_m_name)
  )
  tree_model <- partykit::glmtree(tree_formula, data = em_model_data_at)
  rules <- list_rules_party(tree_model)

  # If the tree made no split, fall back to a median split: equality with the
  # median for 0/1 modifiers, "<= median" otherwise
  if (all(rules == "")) {
    modifier_cols <- lapply(effect_m_name, function(nm) at[[nm]])
    is_binary <- all(vapply(
      modifier_cols,
      function(values) all(values %in% 0:1),
      logical(1)
    ))
    median_labels <- vapply(
      modifier_cols,
      function(values) as.character(stats::median(values)),
      character(1)
    )
    comparison <- if (is_binary) "==" else "<="
    rules <- as.list(paste(effect_m_name, comparison, median_labels))
  }

  # Quantities that do not depend on the rule are computed once
  shift_diff <- tmle_fit_av$qn_shift_star - tmle_fit_av$qn_noshift_star
  eif_diff <- tmle_fit_av$eif - tmle_fit_av$noshift_eif
  n_obs <- length(tmle_fit_av$eif)

  results_list <- lapply(rules, function(rule) {
    # The rule is an R expression over the modifier column(s); it is generated
    # above (tree or median split), not supplied by the user. Evaluate it
    # against the validation data explicitly.
    in_rule <- eval(parse(text = rule), envir = em_model_data_av)
    ind <- ifelse(in_rule, 1, 0)

    # which() drops NA, so rows with an undetermined indicator belong to
    # neither group instead of breaking the positional table() lookup
    idx_one <- which(ind == 1)
    idx_zero <- which(ind == 0)
    n_one <- length(idx_one)
    n_zero <- length(idx_zero)

    # If the rule does not partition the validation sample, skip it
    if (n_one == 0L || n_zero == 0L) {
      return(NULL)
    }

    # Inverse of the empirical probability of falling in each group
    n_ind <- length(ind)
    weight_one <- 1 / (n_one / n_ind)
    weight_zero <- 1 / (n_zero / n_ind)

    # Mean shift difference within each group
    psi_em_one <- mean(shift_diff[idx_one])
    psi_em_zero <- mean(shift_diff[idx_zero])

    # Variance of the weighted EIF difference within each group, divided by
    # the group size
    # NOTE(review): confirm this within-group scaling is the intended
    # variance estimator for the subgroup parameter.
    psi_one_var <- stats::var(weight_one * eif_diff[idx_one]) / n_one
    psi_zero_var <- stats::var(weight_zero * eif_diff[idx_zero]) / n_zero

    # Confidence intervals (standard error is passed)
    em_one_ci <- calc_CIs(psi_em_one, sqrt(psi_one_var))
    em_zero_ci <- calc_CIs(psi_em_zero, sqrt(psi_zero_var))

    # P-values
    # NOTE(review): the variance (not the SE) is passed here, as in the
    # original code -- confirm this matches what calc_pvals expects.
    level_1_p_val <- calc_pvals(psi_em_one, psi_one_var)
    level_0_p_val <- calc_pvals(psi_em_zero, psi_zero_var)

    data.table::data.table(
      Condition = c(
        paste("Level 1 Shift Diff in", rule),
        paste("Level 0 Shift Diff in", rule)
      ),
      Psi = c(psi_em_one, psi_em_zero),
      Variance = c(psi_one_var, psi_zero_var),
      SE = c(sqrt(psi_one_var), sqrt(psi_zero_var)),
      Lower_CI = c(em_one_ci[1], em_zero_ci[1]),
      Upper_CI = c(em_one_ci[2], em_zero_ci[2]),
      P_value = c(level_1_p_val, level_0_p_val),
      Fold = fold_k,
      Type = "Effect Mod",
      Variables = paste0(exposure, effect_m_name),
      N = n_obs
    )
  })

  # Drop skipped rules; names are removed so they are not passed to rbind()
  # as argument names
  results_list <- unname(Filter(Negate(is.null), results_list))
  if (length(results_list) == 0L) {
    return(NULL)
  }

  # Combine all results into a single table
  results_df <- do.call(rbind, results_list)

  # Complementary rules produce the same pair of estimates; keep the first
  # occurrence of each Psi
  results_df[!duplicated(results_df$Psi), ]
}


#' @title Calculates the Joint Shift Parameter
#' @description Estimates the shift parameter for a joint shift
Expand Down
88 changes: 0 additions & 88 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,91 +135,3 @@ is.InterXshift <- function(x) {
is.InterXshift <- function(x) {
  # Returns a single TRUE/FALSE: TRUE when "InterXshift_msm" is among the
  # classes of `x`. inherits() is used instead of `class(x) == ...` because
  # the latter yields one value per class for objects carrying several
  # classes, which is not usable as an `if` condition.
  inherits(x, "InterXshift_msm")
}

###################################################################
#' @title Get rules from partykit object in rule fitting
#' @param x Partykit glmtree model object
#' @param i Node id(s) or node name(s) to build rules for. When \code{NULL},
#' the ids of all terminal nodes of \code{x} are used.
#' @param ... additional arguments (accepted; not used in the function body)
#' @return For a single node, one character string holding the split
#' conditions on the path from the root to that node, joined by \code{" & "}
#' (an empty string for the root). For several nodes, a character vector of
#' such strings named by node.
#'
#' @export
# Copied from internal partykit function
list_rules_party <- function(x, i = NULL, ...) {
  # Default: every terminal node
  if (is.null(i)) {
    i <- partykit::nodeids(x, terminal = TRUE)
  }

  # Several nodes: build each rule separately and name the result
  if (length(i) > 1) {
    ret <- vapply(
      i,
      function(node_id) list_rules_party(x, node_id),
      character(1)
    )
    if (is.character(i)) {
      names(ret) <- i
    } else {
      names(ret) <- names(x)[i]
    }
    return(ret)
  }

  # Translate a node name into its position
  if (is.character(i) && !is.null(names(x))) {
    i <- which(names(x) %in% i)
  }
  stopifnot(length(i) == 1 & is.numeric(i))
  stopifnot(i <= length(x) & i >= 1)
  i <- as.integer(i)

  # Data used to look up variable names and factor levels: the node data up
  # to (excluding) the "(fitted)" column, or the model data when nothing is
  # left before that column or the object has no fitted component
  dat <- partykit::data_party(x, i)
  if (is.null(x$fitted)) {
    dat <- x$data
  } else {
    fitted_pos <- which("(fitted)" == names(dat))[1]
    dat <- dat[, -(fitted_pos:ncol(dat)), drop = FALSE]
    if (ncol(dat) == 0) {
      dat <- x$data
    }
  }

  # Walk from the root down to node i, collecting one condition per split
  conditions <- c()
  node <- partykit::node_party(x)
  while (partykit::id_node(node) != i) {
    kid_ids <- vapply(
      partykit::kids_node(node),
      partykit::id_node,
      numeric(1)
    )
    # The child to descend into is the last one whose id does not exceed i
    whichkid <- max(which(kid_ids <= i))

    node_split <- partykit::split_node(node)
    svar <- names(dat)[partykit::varid_split(node_split)]
    index <- partykit::index_split(node_split)

    if (is.factor(dat[, svar])) {
      if (is.null(index)) {
        index <- (seq_len(nlevels(dat[, svar])) >
          partykit::breaks_split(node_split)) + 1
      }
      slevels <- levels(dat[, svar])[index == whichkid]
      srule <- paste0(
        svar, " %in% c(\"",
        paste(slevels, collapse = "\", \""),
        "\")"
      )
    } else {
      if (is.null(index)) {
        index <- seq_along(kid_ids)
      }
      split_breaks <- partykit::breaks_split(node_split)
      bounds <- cbind(c(-Inf, split_breaks), c(split_breaks, Inf))
      sbreak <- bounds[index == whichkid, ]
      right <- partykit::right_split(node_split)

      # Lower and upper bound of the interval, omitting infinite ends
      pieces <- c()
      if (is.finite(sbreak[1])) {
        lower_op <- if (right) ">" else ">="
        pieces <- c(pieces, paste(svar, lower_op, sbreak[1]))
      }
      if (is.finite(sbreak[2])) {
        upper_op <- if (right) "<=" else "<"
        pieces <- c(pieces, paste(svar, upper_op, sbreak[2]))
      }
      srule <- paste(pieces, collapse = " & ")
    }

    conditions <- c(conditions, srule)
    node <- node[[whichkid]]
  }

  paste(conditions, collapse = " & ")
}
36 changes: 0 additions & 36 deletions man/calc_final_effect_mod_param.Rd

This file was deleted.

21 changes: 0 additions & 21 deletions man/list_rules_party.Rd

This file was deleted.

0 comments on commit 5054cc7

Please sign in to comment.