Supports general submodels and a bunch of new parameters #75

Open · wants to merge 66 commits into base: devel

Commits (66)

5552f70  submodels (Sep 5, 2021)
15d0089  submodels (Sep 5, 2021)
8d86c53  supports general submodels. Tests pass (Sep 5, 2021)
4ceae2c  supports general submodel and weights. JK on previous commit (Sep 5, 2021)
97b765e  supports general submodel and weights. JK on previous commit (Sep 5, 2021)
3d252ac  supports general submodel and weights. JK on previous commit (Sep 5, 2021)
a7a1dda  fix ATT and ATC submodels (Sep 5, 2021)
742f1c9  REFERENCE CHANGES HERE (Sep 5, 2021)
64583b0  towards spCATE (Sep 6, 2021)
5a80be0  implemented spCATE PARAM (Sep 6, 2021)
870e1a1  added more sp params (Sep 6, 2021)
ceac71b  causalGLM seems to work (Sep 6, 2021)
ec88bb9  hi (Sep 6, 2021)
7ba9492  Update tmle3_Update.R (Larsvanderlaan, Sep 6, 2021)
9b64c82  Remove built in submodels/losses (Larsvanderlaan, Sep 6, 2021)
01b1116  Delete submodels_semiparametric.R (Larsvanderlaan, Sep 6, 2021)
a22e578  np params (Sep 6, 2021)
a8ca892  Merge branch 'general_submodels_devel' of https://github.com/tlverse/… (Sep 6, 2021)
1142b6b  more np (Sep 6, 2021)
a4584d1  minr bug fixes to npOR (Sep 6, 2021)
6bd79b4  testing (Sep 6, 2021)
0f15fc0  hi (Sep 6, 2021)
b017418  more tests (Sep 6, 2021)
eaff276  more tests (Sep 6, 2021)
d9f2ee3  ran make style (Sep 6, 2021)
e97a39f  ran make style (Sep 6, 2021)
667e24c  wait (Sep 6, 2021)
2d3aa6e  fix bug (Sep 6, 2021)
e349397  format (Sep 6, 2021)
de1a3f1  format (Sep 6, 2021)
759dca0  fix documentation bug (Sep 6, 2021)
d90cb6e  remove glm_sp docs (Sep 6, 2021)
63a41a1  changes (Sep 6, 2021)
fca50f6  changes (Sep 6, 2021)
7253a6c  changes (Sep 6, 2021)
436581f  change to default for spCausal (Sep 6, 2021)
03f22c7  change to default for spCausal (Sep 6, 2021)
c19995e  change to default for spCausal (Sep 6, 2021)
e4ee0f5  change to default for spCausal (Sep 6, 2021)
da0e4a8  change to default for spCausal (Sep 7, 2021)
fc438e4  change to default for spCausal (Sep 7, 2021)
e218916  fix bug tmle3_fit initial est if no full fit (Sep 7, 2021)
e094a76  fix (Sep 7, 2021)
701b5a5  fix (Sep 7, 2021)
2fff29b  fix (Sep 7, 2021)
2244687  fix (Sep 7, 2021)
59a9301  fix (Sep 7, 2021)
9808493  add npTSM (Sep 7, 2021)
3f5b0f7  npRR (Sep 7, 2021)
9c55096  bounded outcomes (Sep 7, 2021)
d75be43  bounded outcomes (Sep 7, 2021)
ed89de8  bounded outcomes (Sep 7, 2021)
469e7f8  bounded outcomes (Sep 7, 2021)
4d4fb2f  plz (Sep 7, 2021)
48324de  small (Sep 7, 2021)
ab16187  added Param_coxph (Sep 8, 2021)
f019a0f  added Param_coxph (Sep 8, 2021)
1fcd721  added Param_coxph (Sep 8, 2021)
a6b3587  change OR default submodel (Sep 8, 2021)
7969124  sort of fixed coxph (Sep 9, 2021)
e9d587f  sort of fixed coxph (Sep 9, 2021)
3a3bfa8  hi (Sep 9, 2021)
d340dc9  hi (Sep 9, 2021)
fef3613  plz dont break (Sep 9, 2021)
0787f9b  ATE weights (Larsvanderlaan, Apr 7, 2022)
dd189d8  ATE weights (Larsvanderlaan, Apr 7, 2022)

Changes from all commits

2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -54,5 +54,5 @@ Encoding: UTF-8
LazyData: yes
LazyLoad: yes
VignetteBuilder: knitr
RoxygenNote: 7.1.1.9001
RoxygenNote: 7.1.2
Roxygen: list(markdown = TRUE, r6 = FALSE)
23 changes: 22 additions & 1 deletion NAMESPACE
@@ -11,14 +11,24 @@ export(LF_static)
export(LF_targeted)
export(Likelihood)
export(Likelihood_cache)
export(Lrnr_glm_semiparametric)
export(Param_ATC)
export(Param_ATE)
export(Param_ATT)
export(Param_MSM)
export(Param_TSM)
export(Param_base)
export(Param_coxph)
export(Param_delta)
export(Param_mean)
export(Param_npCATE)
export(Param_npCATT)
export(Param_npOR)
export(Param_npRR)
export(Param_npTSM)
export(Param_spCATE)
export(Param_spOR)
export(Param_spRR)
export(Param_stratified)
export(Param_survival)
export(Targeted_Likelihood)
@@ -35,9 +45,16 @@ export(delta_param_RR)
export(density_formula)
export(discretize_variable)
export(fit_tmle3)
export(generate_loss_function_from_family)
export(generate_submodel_from_family)
export(get_propensity_scores)
export(get_submodel_spec)
export(loss_loglik)
export(loss_loglik_binomial)
export(loss_poisson)
export(make_CF_Likelihood)
export(make_Likelihood)
export(make_submodel_spec)
export(make_tmle3_Task)
export(plot_vim)
export(point_tx_likelihood)
@@ -46,7 +63,8 @@ export(point_tx_task)
export(process_missing)
export(propensity_score_plot)
export(propensity_score_table)
export(submodel_logit)
export(submodel_logistic_switch)
export(submodel_spec_logistic_switch)
export(summary_from_estimates)
export(survival_tx_likelihood)
export(survival_tx_npsem)
@@ -64,6 +82,9 @@ export(tmle3_Spec_OR)
export(tmle3_Spec_PAR)
export(tmle3_Spec_RR)
export(tmle3_Spec_TSM_all)
export(tmle3_Spec_coxph)
export(tmle3_Spec_npCausalGLM)
export(tmle3_Spec_spCausalGLM)
export(tmle3_Spec_stratified)
export(tmle3_Spec_survival)
export(tmle3_Task)
21 changes: 17 additions & 4 deletions R/LF_known.R
@@ -41,20 +41,29 @@ LF_known <- R6Class(
class = TRUE,
inherit = LF_base,
public = list(
initialize = function(name, mean_fun = stub_known, density_fun = stub_known, ..., type = "density") {
initialize = function(name, mean_fun = stub_known, density_fun = stub_known, base_likelihood = NULL, ..., type = "density") {
super$initialize(name, ..., type = type)
private$.mean_fun <- mean_fun
private$.density_fun <- density_fun
private$.base_likelihood <- base_likelihood
},
get_mean = function(tmle_task, fold_number) {
learner_task <- tmle_task$get_regression_task(self$name, scale = FALSE)
preds <- self$mean_fun(learner_task)
if (!is.null(self$base_likelihood)) {
preds <- self$mean_fun(learner_task, tmle_task, self$base_likelihood)
} else {
preds <- self$mean_fun(learner_task)
}

return(preds)
},
get_density = function(tmle_task, fold_number) {
learner_task <- tmle_task$get_regression_task(self$name, scale = FALSE)
preds <- self$density_fun(learner_task)
if (!is.null(self$base_likelihood)) {
preds <- self$density_fun(learner_task, tmle_task, self$base_likelihood)
} else {
preds <- self$density_fun(learner_task)
}

outcome_type <- learner_task$outcome_type
observed <- outcome_type$format(learner_task$Y)
@@ -79,11 +88,15 @@ LF_known <- R6Class(

density_fun = function() {
return(private$.density_fun)
},
base_likelihood = function() {
return(private$.base_likelihood)
}
),
private = list(
.name = NULL,
.mean_fun = NULL,
.density_fun = NULL
.density_fun = NULL,
.base_likelihood = NULL
)
)
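
For context, a minimal sketch of how the new base_likelihood pass-through might be used: when base_likelihood is supplied, LF_known calls the user-supplied mean/density function with three arguments instead of one. The function shifted_mean, the object base_lik, and the 0.1 shift below are illustrative assumptions, not part of this PR.

# Sketch only (assumed names): a known conditional-mean function that uses the
# fitted base likelihood made available by the new base_likelihood field.
shifted_mean <- function(learner_task, tmle_task, base_likelihood) {
  EY <- base_likelihood$get_likelihood(tmle_task, "Y") # E[Y | A, W] under the base fit
  pmin(pmax(EY + 0.1, 0), 1) # illustrative constant shift, kept in [0, 1]
}

lf_y <- LF_known$new(
  "Y",
  mean_fun = shifted_mean,
  base_likelihood = base_lik, # assumed: a previously fitted Likelihood object
  type = "mean"
)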
137 changes: 137 additions & 0 deletions R/Lrnr_glm_semiparametric.R
@@ -0,0 +1,137 @@
#' Semiparametric Generalized Linear Models
#'
#' This learner provides fitting procedures for semiparametric generalized linear models using a user-given baseline learner and
#' \code{\link[stats]{glm.fit}}. It supports models of the form `linkfun(E[Y|A,W]) = linkfun(E[Y|A=0,W]) + A * f(W)` where `A` is a binary or continuous interaction variable,
#' and `f(W)` is a user-specified parametric function (e.g. `f(W) = model.matrix(formula_sp, W)`). The baseline function `E[Y|A=0,W]` is fit using a user-specified Learner (possibly pooled over values of `A` and then projected onto the semiparametric model).
#'
#' @export
Lrnr_glm_semiparametric <- R6Class(
classname = "Lrnr_glm_semiparametric", inherit = Lrnr_base,
portable = TRUE, class = TRUE,
public = list(
initialize = function(formula_sp, lrnr_baseline, interaction_variable = "A", family = NULL, append_interaction_matrix = TRUE, return_matrix_predictions = FALSE, ...) {
params <- args_to_list()
super$initialize(params = params, ...)
}
),

private = list(
.properties = c("continuous", "binomial", "semiparametric", "weights"),

.train = function(task) {
args <- self$params
append_interaction_matrix <- args$append_interaction_matrix
outcome_type <- self$get_outcome_type(task)
trt <- args$interaction_variable
if (is.null(trt)) {
A <- rep(1, task$nrow)
} else {
A <- unlist(task$get_data(, trt))
}
if (!all(A %in% c(0, 1)) && !is.null(trt)) {
binary <- FALSE
} else {
binary <- TRUE
}
family <- args$family
lrnr_baseline <- args$lrnr_baseline
formula <- args$formula_sp
if (is.null(family)) {
family <- outcome_type$glm_family(return_object = TRUE)
}
# Interaction design matrix
Y <- task$Y
V <- model.matrix(formula, task$data)
colnames(V) <- paste0("V", 1:ncol(V))

covariates <- setdiff(task$nodes$covariates, trt)

if (!append_interaction_matrix && binary) {
task_baseline <- task$next_in_chain(covariates = covariates)
lrnr_baseline <- lrnr_baseline$train(task_baseline[A == 0])
Q0 <- lrnr_baseline$predict(task_baseline)
beta <- suppressWarnings(coef(glm.fit(A * V, Y, offset = family$linkfun(Q0), intercept = F, weights = task$weights, family = family)))
Q1 <- family$linkinv(family$linkfun(Q0) + V %*% beta)
Q <- ifelse(A == 1, Q1, Q0)
} else {
covariates <- setdiff(task$nodes$covariates, trt)

if (append_interaction_matrix) {
AV <- as.data.table(A * V)
X <- cbind(task$X[, covariates, with = F], AV)
X0 <- cbind(task$X[, covariates, with = F], 0 * V)
} else {
X <- cbind(task$X[, covariates, with = F], A)
X0 <- cbind(task$X[, covariates, with = F], A * 0)
}


column_names <- task$add_columns(X)
task_baseline <- task$next_in_chain(covariates = colnames(X), column_names = column_names)

column_names <- task$add_columns(X0)
task_baseline0 <- task$next_in_chain(covariates = colnames(X0), column_names = column_names)

lrnr_baseline <- lrnr_baseline$train(task_baseline)
Q <- lrnr_baseline$predict(task_baseline)
Q0 <- lrnr_baseline$predict(task_baseline0)
# Project onto model

beta <- suppressWarnings(coef(glm.fit(A * V, Q, offset = family$linkfun(Q0), intercept = F, weights = task$weights, family = family)))
}

fit_object <- list(
coefficients = beta, lrnr_baseline = lrnr_baseline, covariates = covariates, family = family, formula = formula,
append_interaction_matrix = append_interaction_matrix, binary = binary, task_baseline = task_baseline
)
return(fit_object)
},
.predict = function(task) {
fit_object <- self$fit_object
append_interaction_matrix <- fit_object$append_interaction_matrix
binary <- fit_object$binary
beta <- fit_object$coefficients
lrnr_baseline <- fit_object$lrnr_baseline
covariates <- fit_object$covariates
family <- fit_object$family
formula <- fit_object$formula

trt <- self$params$interaction_variable
if (is.null(trt)) {
A <- rep(1, task$nrow)
} else {
A <- unlist(task$get_data(, trt))
}
V <- model.matrix(formula, task$data)
colnames(V) <- paste0("V", 1:ncol(V))


if (!append_interaction_matrix && binary) {
task_baseline <- task$next_in_chain(covariates = covariates)
Q0 <- lrnr_baseline$predict(task_baseline)
} else {
if (append_interaction_matrix) {
X0 <- cbind(task$X[, covariates, with = F], 0 * V)
} else {
X0 <- cbind(task$X[, covariates, with = F], 0)
}
column_names <- task$add_columns(X0)
task_baseline0 <- task$next_in_chain(covariates = colnames(X0), column_names = column_names)
Q0 <- lrnr_baseline$predict(task_baseline0)
}
Q0 <- as.vector(Q0)
Q1 <- as.vector(family$linkinv(family$linkfun(Q0) + V %*% beta))
Q <- as.vector(family$linkinv(family$linkfun(Q0) + A * V %*% beta))
if (self$params$return_matrix_predictions && binary) {
predictions <- cbind(Q0, Q1, Q)
colnames(predictions) <- c("A=0", "A=1", "A")
predictions <- sl3::pack_predictions(predictions)
} else {
predictions <- Q
}


return(predictions)
}
)
)
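
A minimal usage sketch of the new learner (not part of the diff), assuming a data.table sim_data with binary treatment A, outcome Y, and covariates W1, W2; the constructor arguments match the signature shown above, while the data and learner choices are illustrative.

library(sl3)

# Sketch only: fit linkfun(E[Y|A,W]) = linkfun(E[Y|A=0,W]) + A * f(W)
# with f(W) = model.matrix(~ 1 + W1, W) and a GLM baseline learner.
task <- sl3_Task$new(
  data = sim_data, # assumed data.table containing A, Y, W1, W2
  covariates = c("A", "W1", "W2"),
  outcome = "Y"
)

lrnr_sp <- Lrnr_glm_semiparametric$new(
  formula_sp = ~ 1 + W1, # parametric interaction part f(W)
  lrnr_baseline = Lrnr_glm$new(),
  interaction_variable = "A",
  family = binomial(),
  append_interaction_matrix = TRUE
)

fit <- lrnr_sp$train(task)
preds <- fit$predict(task) # E[Y | A, W] under the semiparametric model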
4 changes: 4 additions & 0 deletions R/Param_ATC.R
@@ -61,6 +61,7 @@ Param_ATC <- R6Class(
private$.cf_likelihood_control <- CF_Likelihood$new(observed_likelihood, intervention_list_control)
private$.outcome_node <- outcome_node
private$.param_att <- Param_ATT$new(observed_likelihood, intervention_list_control, intervention_list_treatment, outcome_node)
private$.submodel <- private$.param_att$submodel
},
clever_covariates = function(tmle_task = NULL, fold_number = "full") {
att_cc <- self$param_att$clever_covariates(tmle_task, fold_number)
@@ -96,6 +97,9 @@ Param_ATC <- R6Class(
},
param_att = function() {
return(private$.param_att)
},
submodel = function(){
self$param_att$submodel
}
),
private = list(
4 changes: 2 additions & 2 deletions R/Param_ATE.R
@@ -114,10 +114,10 @@ Param_ATE <- R6Class(
EY1 <- self$observed_likelihood$get_likelihood(cf_task_treatment, self$outcome_node, fold_number)
EY0 <- self$observed_likelihood$get_likelihood(cf_task_control, self$outcome_node, fold_number)

psi <- mean(EY1 - EY0)
psi <- weighted.mean(EY1 - EY0, tmle_task$weights)

IC <- HA * (Y - EY) + (EY1 - EY0) - psi

IC <- IC * tmle_task$weights
result <- list(psi = psi, IC = IC)
return(result)
}
2 changes: 1 addition & 1 deletion R/Param_ATT.R
@@ -55,7 +55,7 @@ Param_ATT <- R6Class(
inherit = Param_base,
public = list(
initialize = function(observed_likelihood, intervention_list_treatment, intervention_list_control, outcome_node = "Y") {
super$initialize(observed_likelihood, list(), outcome_node)
super$initialize(observed_likelihood, list(), outcome_node, submodel = list("A" = "logistic_switch", "Y" = "binomial_logit"))
private$.cf_likelihood_treatment <- CF_Likelihood$new(observed_likelihood, intervention_list_treatment)
private$.cf_likelihood_control <- CF_Likelihood$new(observed_likelihood, intervention_list_control)
},
46 changes: 44 additions & 2 deletions R/Param_base.R
@@ -20,9 +20,23 @@ Param_base <- R6Class(
portable = TRUE,
class = TRUE,
public = list(
initialize = function(observed_likelihood, ..., outcome_node = "Y") {
initialize = function(observed_likelihood, ..., outcome_node = "Y", submodel = NULL) {
private$.observed_likelihood <- observed_likelihood
private$.outcome_node <- outcome_node
if(is.null(submodel)) { # Default submodel
submodel <- list("A" = get_submodel_spec("binomial_logit"), "Y" = get_submodel_spec("binomial_logit"), "default" = get_submodel_spec("binomial_logit"))
} else if (is.list(submodel)) { # Convert to submodel spec list
submodel_names <- names(submodel)

submodel <- lapply(submodel, get_submodel_spec) # For each node, convert to submodel spec list. #get_submodel_spec does nothing if item is already a list
names(submodel) <- submodel_names
} else {
submodel <- list("default" = get_submodel_spec(submodel))
}


private$.submodel <- submodel


if (!is.null(observed_likelihood$censoring_nodes[[outcome_node]])) {
if (!self$supports_outcome_censoring) {
@@ -50,6 +64,27 @@ Param_base <- R6Class(
},
print = function() {
cat(sprintf("%s: %s\n", class(self)[1], self$name))
},
supports_submodel = function(submodel_name, node) {
if (!(node %in% names(private$.submodel))) {
node <- "default"
}
return(submodel_name == private$.submodel[[node]]$name)
},
get_submodel_spec = function(update_node) {

if (!(update_node %in% names(self$submodel))) {
update_node <- "default"
}

spec <- self$submodel[[update_node]]
if(!is.list(spec)) {

spec <- get_submodel_spec(spec)
private$.submodel[[update_node]] <- spec
}

return(spec)
}
),
active = list(
@@ -71,14 +106,21 @@
},
targeted = function() {
return(private$.targeted)
},
submodel = function() {
return(private$.submodel)
},
weights = function() {
return(self$observed_likelihood$training_task$weights)
}
),
private = list(
.type = "undefined",
.observed_likelihood = NULL,
.outcome_node = NULL,
.targeted = TRUE,
.supports_outcome_censoring = FALSE
.supports_outcome_censoring = FALSE,
.submodel = NULL
)
)
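
To illustrate the new per-node submodel machinery (a hedged sketch, not code from the PR): Param_ATT above passes submodel = list("A" = "logistic_switch", "Y" = "binomial_logit"), Param_base normalizes each entry with get_submodel_spec(), and downstream code queries the resolved specs. The objects lik, intervention_treatment, and intervention_control are assumed to already exist.

# Sketch only, assuming a fitted Targeted_Likelihood `lik` and intervention
# lists built with LF_static for A = 1 and A = 0.
param_att <- Param_ATT$new(lik, intervention_treatment, intervention_control,
  outcome_node = "Y"
)

param_att$submodel # list with per-node specs for "A" and "Y"
param_att$get_submodel_spec("A")$name # presumably "logistic_switch"
param_att$supports_submodel("binomial_logit", "Y") # TRUE if the spec name matches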
