diff --git a/R_code/02-roadmap.R b/R_code/02-roadmap.R
index cca11ce..550f53c 100644
--- a/R_code/02-roadmap.R
+++ b/R_code/02-roadmap.R
@@ -1,4 +1,4 @@
-## ----simple-DAG---------------------------------------------------------------
+## ----simple-DAG, out.width = "60%"--------------------------------------------
 library(dagitty)
 library(ggdag)
 
diff --git a/R_code/03-tlverse.R b/R_code/03-tlverse.R
index 3e3e67b..c75e15d 100644
--- a/R_code/03-tlverse.R
+++ b/R_code/03-tlverse.R
@@ -4,5 +4,5 @@
 
 ## ----renviron-example, results="asis", eval=FALSE-----------------------------
 ## GITHUB_PAT=yourPAT
-##
+
diff --git a/R_code/04-data.R b/R_code/04-data.R
index e5d963f..94326db 100644
--- a/R_code/04-data.R
+++ b/R_code/04-data.R
@@ -18,38 +18,3 @@ if (knitr::is_latex_output()) {
   skim(dat)
 }
-
-## ----load_ist_data_intro, message=FALSE, warning=FALSE------------------------
-# read in data
-ist <- read_csv(
-  paste0(
-    "https://raw.githubusercontent.com/tlverse/tlverse-handbook/master/",
-    "data/ist_sample.csv"
-  )
-)
-
-
-## ----skim_ist_data, results="asis", echo=FALSE--------------------------------
-if (knitr::is_latex_output()) {
-  knitr::kable(skim_no_sparks(ist), format = "latex")
-} else {
-  skim(ist)
-}
-
-
-## ----load_nhefs_data_intro-----------------------------------------------------
-# read in data
-nhefs_data <- read_csv(
-  paste0(
-    "https://raw.githubusercontent.com/tlverse/tlverse-handbook/master/",
-    "data/NHEFS.csv"
-  )
-)
-
-
-## ----skim_nhefs_data, results="asis", echo=FALSE--------------------------------
-if (knitr::is_latex_output()) {
-  knitr::kable(skim_no_sparks(nhefs_data), format = "latex")
-} else {
-  skim(nhefs_data)
-}
diff --git a/R_code/05-origami.R b/R_code/05-origami.R
index e94c0cf..8e69d27 100644
--- a/R_code/05-origami.R
+++ b/R_code/05-origami.R
@@ -73,7 +73,7 @@ t <- length(AP)
 
 
 ## ---- fig.cap="Rolling origin CV", results="asis", echo=FALSE-----------------
-knitr::include_graphics(path = "img/image/rolling_origin.png")
+knitr::include_graphics(path = "img/png/rolling_origin.png")
 
 
 ## ----rolling_origin-------------------------------------------------------------
@@ -86,7 +86,7 @@ folds[[2]]
 
 
 ## ---- fig.cap="Rolling window CV", results="asis", echo=FALSE-----------------
-knitr::include_graphics(path = "img/image/rolling_window.png")
+knitr::include_graphics(path = "img/png/rolling_window.png")
 
 
 ## ----rolling_window-------------------------------------------------------------
@@ -99,11 +99,11 @@ folds[[2]]
 
 
 ## ---- fig.cap="Rolling origin V-fold CV", results="asis", echo=FALSE----------
-knitr::include_graphics(path = "img/image/rolling_origin_v_fold.png")
+knitr::include_graphics(path = "img/png/rolling_origin_v_fold.png")
 
 
 ## ---- fig.cap="Rolling window V-fold CV", results="asis", echo=FALSE----------
-knitr::include_graphics(path = "img/image/rolling_window_v_fold.png")
+knitr::include_graphics(path = "img/png/rolling_window_v_fold.png")
 
 
 ## ----setup_ex---------------------------------------------------------------------
@@ -227,7 +227,8 @@ cv_rf <- function(fold, data, reg_form) {
 
 # now, let's cross-validate...
 folds <- make_folds(washb_data)
 cvrf_results <- cross_validate(
-  cv_fun = cv_rf, folds = folds, data = washb_data, reg_form = "whz ~ .",
+  cv_fun = cv_rf, folds = folds,
+  data = washb_data, reg_form = "whz ~ .",
   use_future = FALSE
 )
 mean(cvrf_results$SE)
@@ -300,3 +301,4 @@ mses <- cross_validate(
 )
 mses$mse
 colMeans(mses$mse[, c("arima", "arima2")])
+
diff --git a/R_code/06-sl3.R b/R_code/06-sl3.R
index 78071be..8f7548a 100644
--- a/R_code/06-sl3.R
+++ b/R_code/06-sl3.R
@@ -1,5 +1,9 @@
 ## ----cv_fig, fig.show="hold", echo = FALSE------------------------------------
-knitr::include_graphics("img/misc/SLKaiserNew.pdf")
+knitr::include_graphics("img/png/vs.png")
+
+
+## ----cv_sl_alg, fig.show="hold", echo = FALSE----------------------------------
+knitr::include_graphics("img/png/SLKaiserNew.png")
 
 
 ## ----setup-----------------------------------------------------------------------
@@ -21,18 +25,10 @@ washb_data <- fread(
   ),
   stringsAsFactors = TRUE
 )
-
-
-## ----sl3_washb_example_table1, echo=FALSE----------------------------------------
-if (knitr::is_latex_output()) {
-  head(washb_data) %>%
-    kable(format = "latex")
-} else if (knitr::is_html_output()) {
-  head(washb_data) %>%
-    kable() %>%
-    kable_styling(fixed_thead = TRUE) %>%
-    scroll_box(width = "100%", height = "300px")
-}
+head(washb_data) %>%
+  kable() %>%
+  kableExtra::kable_styling(fixed_thead = TRUE) %>%
+  scroll_box(width = "100%", height = "300px")
 
 
 ## ----task, warning=TRUE------------------------------------------------------------
@@ -59,8 +55,7 @@ head(washb_task$folds[[1]]$training_set) # row indexes for fold 1 training
 head(washb_task$folds[[1]]$validation_set) # row indexes for fold 1 validation
 any(
-  washb_task$folds[[1]]$training_set %in%
-    washb_task$folds[[1]]$validation_set
+  washb_task$folds[[1]]$training_set %in% washb_task$folds[[1]]$validation_set
 )
@@ -100,7 +95,7 @@ lrn_interaction <- make_learner(Lrnr_define_interactions, interactions)
 
 
 ## ----interaction-pipe----------------------------------------------------------------
-# we already instantiated a linear model learner above, no need to do it again
+# we already instantiated a linear model learner, no need to do that again
 lrn_glm_interaction <- make_learner(Pipeline, lrn_interaction, lrn_glm)
 lrn_glm_interaction
@@ -122,7 +117,6 @@ lrn_bayesglm <- Lrnr_pkg_SuperLearner$new("SL.bayesglm")
 ##   do.call(Lrnr_svm$new, as.list(tuning_params))
 ## })
 
-
 ## ----extra-lrnr-mindblown-xgboost------------------------------------------------
 grid_params <- list(
   max_depth = c(2, 4, 6),
   eta = c(0.001, 0.1, 0.3),
   nrounds = c(20, 50)
 )
 grid <- expand.grid(grid_params, KEEP.OUT.ATTRS = FALSE)
 xgb_learners <- apply(grid, MARGIN = 1, function(tuning_params) {
   do.call(Lrnr_xgboost$new, as.list(tuning_params))
 })
 xgb_learners
 
 
-## ----carotene, eval=FALSE-----------------------------------------------------
+## ----carotene, eval = FALSE---------------------------------------------------
 ## # Unlike xgboost, I have no idea how to tune a neural net or BART machine, so
 ## # I let caret take the reins
 ## lrnr_caret_nnet <- make_learner(Lrnr_caret, algorithm = "nnet")
@@ -158,7 +152,9 @@ stack
 
 
 ## ----alt-stack---------------------------------------------------------------------
 # named vector of learners first
-learners <- c(lrn_glm, lrn_polspline, lrn_enet.5, lrn_ridge, lrn_lasso, xgb_50)
+learners <- c(
+  lrn_glm, lrn_polspline, lrn_enet.5, lrn_ridge, lrn_lasso, xgb_50
+)
 names(learners) <- c(
   "glm", "polspline", "enet.5", "ridge", "lasso", "xgboost50"
 )
@@ -245,17 +241,18 @@ head(sl_preds)
 
 
 ## ---- plot-predvobs-woohoo, eval=FALSE----------------------------------------
+##
 ## # df_plot <- data.frame(Observed = washb_data[["whz"]], Predicted = sl_preds,
 ## #   count = seq(1:nrow(washb_data))
-##
+##
 ## # df_plot_melted <- melt(df_plot, id.vars = "count",
 ## #   measure.vars = c("Observed", "Predicted"))
-##
+##
 ## # ggplot(df_plot_melted, aes(value, count, color = variable)) + geom_point()
 
 
 ## ---- sl-summary----------------------------------------------------------------
-sl_fit_summary <- sl_fit$print()
+sl_fit$cv_risk(loss_fun = loss_squared_error)
 
 
 ## ----CVsl-----------------------------------------------------------------------
@@ -268,37 +265,21 @@ washb_task_new <- make_sl3_Task(
 CVsl <- CV_lrnr_sl(
   lrnr_sl = sl_fit, task = washb_task_new, loss_fun = loss_squared_error
 )
-
-
-## ----CVsl_table------------------------------------------------------------------
-if (knitr::is_latex_output()) {
-  CVsl %>%
-    kable(format = "latex")
-} else if (knitr::is_html_output()) {
-  CVsl %>%
-    kable() %>%
-    kable_styling(fixed_thead = TRUE) %>%
-    scroll_box(width = "100%", height = "300px")
-}
+CVsl %>%
+  kable(digits = 4) %>%
+  kableExtra::kable_styling(fixed_thead = TRUE) %>%
+  scroll_box(width = "100%", height = "300px")
 
 
 ## ----varimp-----------------------------------------------------------------------
 washb_varimp <- importance(sl_fit, loss = loss_squared_error, type = "permute")
+washb_varimp %>%
+  kable(digits = 4) %>%
+  kableExtra::kable_styling(fixed_thead = TRUE) %>%
+  scroll_box(width = "100%", height = "300px")
 
-
-## ----varimp_table-------------------------------------------------------------------
-if (knitr::is_latex_output()) {
-  washb_varimp %>%
-    kable(format = "latex")
-} else if (knitr::is_html_output()) {
-  washb_varimp %>%
-    kable() %>%
-    kable_styling(fixed_thead = TRUE) %>%
-    scroll_box(width = "100%", height = "300px")
-}
-
-
-## ----varimp-plot--------------------------------------------------------------------
+## ----varimp-plot, out.width = "100%"------------------------------------------------
 # plot variable importance
 importance_plot(
   washb_varimp,
@@ -329,56 +310,37 @@ if (knitr::is_latex_output()) {
 }
 
 
-## ----ex-setup2-----------------------------------------------------------------------
-ist_data <- paste0(
-  "https://raw.githubusercontent.com/tlverse/",
-  "tlverse-handbook/master/data/ist_sample.csv"
-) %>% fread()
-
-# number 3 help
-ist_task_CVsl <- make_sl3_Task(
-  data = ist_data,
-  outcome = "DRSISC",
-  covariates = colnames(ist_data)[-which(names(ist_data) == "DRSISC")],
-  drop_missing_outcome = TRUE,
-  folds = origami::make_folds(
-    n = sum(!is.na(ist_data$DRSISC)),
-    fold_fun = folds_vfold,
-    V = 5
-  )
-)
-
-
 ## ----ex-key, eval=FALSE----------------------------------------------------------------
 ## db_data <- url(
 ##   "https://raw.githubusercontent.com/benkeser/sllecture/master/chspred.csv"
 ## )
 ## chspred <- read_csv(file = db_data, col_names = TRUE)
-##
+## data.table::setDT(chspred)
+##
 ## # make task
 ## chspred_task <- make_sl3_Task(
 ##   data = chspred,
-##   covariates = head(colnames(chspred), -1),
+##   covariates = colnames(chspred)[-1],
 ##   outcome = "mi"
 ## )
-##
+##
 ## # make learners
 ## glm_learner <- Lrnr_glm$new()
 ## lasso_learner <- Lrnr_glmnet$new(alpha = 1)
 ## ridge_learner <- Lrnr_glmnet$new(alpha = 0)
 ## enet_learner <- Lrnr_glmnet$new(alpha = 0.5)
-## # curated_glm_learner uses formula = "mi ~ smoke + beta + waist"
-## curated_glm_learner <- Lrnr_glm_fast$new(covariates = c("smoke, beta, waist"))
+## # curated_glm_learner uses formula = "mi ~ smoke + beta"
+## curated_glm_learner <- Lrnr_glm_fast$new(covariates = c("smoke", "beta"))
 ## mean_learner <- Lrnr_mean$new() # That is one mean learner!
 ## glm_fast_learner <- Lrnr_glm_fast$new()
 ## ranger_learner <- Lrnr_ranger$new()
 ## svm_learner <- Lrnr_svm$new()
 ## xgb_learner <- Lrnr_xgboost$new()
-##
+##
 ## # screening
 ## screen_cor <- make_learner(Lrnr_screener_correlation)
 ## glm_pipeline <- make_learner(Pipeline, screen_cor, glm_learner)
-##
+##
 ## # stack learners together
 ## stack <- make_learner(
 ##   Stack,
@@ -387,32 +349,29 @@ ist_task_CVsl <- make_sl3_Task(
 ##   curated_glm_learner, mean_learner, glm_fast_learner,
 ##   ranger_learner, svm_learner, xgb_learner
 ## )
-##
+##
 ## # make and train SL
 ## sl <- Lrnr_sl$new(
 ##   learners = stack
 ## )
 ## sl_fit <- sl$train(chspred_task)
-## sl_fit$print()
-##
-## CVsl <- CV_lrnr_sl(sl_fit, chspred_task, loss_loglik_binomial)
+## sl_fit$cv_risk(loss_squared_error)
+##
+## CVsl <- CV_lrnr_sl(sl_fit, chspred_task, loss_squared_error)
 ## CVsl
-##
-## varimp <- importance(sl_fit, type = "permute")
-## varimp %>%
-##   importance_plot(
-##     main = "sl3 Variable Importance for Myocardial Infarction Prediction"
-##   )
+##
+## varimp <- importance(sl_fit)
+## importance_plot(varimp)
 
 
 ## ----ex2-key, eval=FALSE---------------------------------------------------------
 ## library(ROCR) # for AUC calculation
-##
+##
 ## ist_data <- paste0(
 ##   "https://raw.githubusercontent.com/tlverse/",
 ##   "tlverse-handbook/master/data/ist_sample.csv"
 ## ) %>% fread()
-##
+##
 ## # stack
 ## ist_task <- make_sl3_Task(
 ##   data = ist_data,
@@ -420,7 +379,7 @@ ist_task_CVsl <- make_sl3_Task(
 ##   covariates = colnames(ist_data)[-which(names(ist_data) == "DRSISC")],
 ##   drop_missing_outcome = TRUE
 ## )
-##
+##
 ## # learner library
 ## lrn_glm <- Lrnr_glm$new()
 ## lrn_lasso <- Lrnr_glmnet$new(alpha = 1)
@@ -435,9 +394,8 @@ ist_task_CVsl <- make_sl3_Task(
 ##   eta = c(0.01, 0.15, 0.3)
 ## )
 ## grid <- expand.grid(grid_params, KEEP.OUT.ATTRS = FALSE)
-## params_default <- list(nthread = getOption("sl.cores.learners", 1))
 ## xgb_learners <- apply(grid, MARGIN = 1, function(params_tune) {
-##   do.call(Lrnr_xgboost$new, c(params_default, as.list(params_tune)))
+##   do.call(Lrnr_xgboost$new, as.list(params_tune))
 ## })
 ## learners <- unlist(list(
 ##   xgb_learners, lrn_ridge, lrn_mean, lrn_lasso,
@@ -445,17 +403,17 @@ ist_task_CVsl <- make_sl3_Task(
 ##   ),
 ##   recursive = TRUE
 ## )
-##
+##
 ## # SL
 ## sl <- Lrnr_sl$new(learners)
 ## sl_fit <- sl$train(ist_task)
-##
+##
 ## # AUC
 ## preds <- sl_fit$predict()
 ## obs <- c(na.omit(ist_data$DRSISC))
-## AUC <- performance(prediction(sl_preds, obs), measure = "auc")@y.values[[1]]
-## plot(performance(prediction(sl_preds, obs), "tpr", "fpr"))
-##
+## AUC <- performance(prediction(preds, obs), measure = "auc")@y.values[[1]]
+## plot(performance(prediction(preds, obs), "tpr", "fpr"))
+##
 ## # CVsl
 ## ist_task_CVsl <- make_sl3_Task(
 ##   data = ist_data,
@@ -470,7 +428,7 @@ ist_task_CVsl <- make_sl3_Task(
 ##   )
 ## )
 ## CVsl <- CV_lrnr_sl(sl_fit, ist_task_CVsl, loss_loglik_binomial)
 ## CVsl
-##
+##
 ## # sl3 variable importance plot
 ## ist_varimp <- importance(sl_fit, type = "permute")
 ## ist_varimp %>%
@@ -481,3 +439,4 @@ ist_task_CVsl <- make_sl3_Task(
 
 ## ----ex3-key, eval=FALSE--------------------------------------------------------
 ## # TODO
+
diff --git a/R_code/07-tmle3.R b/R_code/07-tmle3.R
index 3b75622..1e71e93 100644
--- a/R_code/07-tmle3.R
+++ b/R_code/07-tmle3.R
@@ -1,13 +1,13 @@
 ## ----tmle_fig1, results="asis", echo = FALSE------------------------------------
-knitr::include_graphics("img/misc/tmle_sim/schematic_1_truedgd.png")
+knitr::include_graphics("img/png/schematic_1_truedgd.png")
 
 
 ## ----tmle_fig2, results="asis", echo = FALSE------------------------------------
-knitr::include_graphics("img/misc/tmle_sim/schematic_2b_sllik.png")
+knitr::include_graphics("img/png/schematic_2b_sllik.png") ## ----tmle_fig3, results="asis", echo = FALSE---------------------------------- -knitr::include_graphics("img/misc/tmle_sim/schematic_3_effects.png") +knitr::include_graphics("img/png/schematic_3_effects.png") ## ----tmle3-load-data---------------------------------------------------------- @@ -175,12 +175,3 @@ metalearner <- make_learner( learner_function = metalearner_logistic_binomial ) - -## ----tmle3-ex2---------------------------------------------------------------- -ist_data <- fread( - paste0( - "https://raw.githubusercontent.com/tlverse/deming2019-workshop/", - "master/data/ist_sample.csv" - ) -) - diff --git a/R_code/08-tmle3mopttx.R b/R_code/08-tmle3mopttx.R index 2eff64c..3398e18 100644 --- a/R_code/08-tmle3mopttx.R +++ b/R_code/08-tmle3mopttx.R @@ -1,5 +1,5 @@ ## ---- fig.cap="Dynamic Treatment Regime in a Clinical Setting", results="asis", echo=FALSE---- -knitr::include_graphics(path = "img/image/DynamicA_Illustration.png") +knitr::include_graphics(path = "img/png/DynamicA_Illustration.png") ## ----setup-mopttx------------------------------------------------------------- @@ -7,6 +7,9 @@ library(data.table) library(sl3) library(tmle3) library(tmle3mopttx) +library(devtools) + +set.seed(111) ## ----load sim_bin_data-------------------------------------------------------- @@ -27,31 +30,27 @@ node_list <- list( # Define sl3 library and metalearners: lrn_xgboost_50 <- Lrnr_xgboost$new(nrounds = 50) lrn_xgboost_100 <- Lrnr_xgboost$new(nrounds = 100) -lrn_xgboost_300 <- Lrnr_xgboost$new(nrounds = 300) +lrn_xgboost_500 <- Lrnr_xgboost$new(nrounds = 500) + lrn_mean <- Lrnr_mean$new() lrn_glm <- Lrnr_glm_fast$new() +lrn_lasso <- Lrnr_glmnet$new() ## Define the Q learner: Q_learner <- Lrnr_sl$new( - learners = list( - lrn_xgboost_50, lrn_xgboost_100, - lrn_xgboost_300, lrn_mean, lrn_glm - ), + learners = list(lrn_lasso, lrn_mean, lrn_glm), metalearner = Lrnr_nnls$new() ) ## Define the g learner: g_learner <- Lrnr_sl$new( - learners = list(lrn_xgboost_100, lrn_glm), + learners = list(lrn_lasso, lrn_glm), metalearner = Lrnr_nnls$new() ) ## Define the B learner: b_learner <- Lrnr_sl$new( - learners = list( - lrn_xgboost_50, lrn_xgboost_100, - lrn_xgboost_300, lrn_mean, lrn_glm - ), + learners = list(lrn_lasso,lrn_mean, lrn_glm), metalearner = Lrnr_nnls$new() ) @@ -66,7 +65,8 @@ learner_list <- list(Y = Q_learner, A = g_learner, B = b_learner) tmle_spec <- tmle3_mopttx_blip_revere( V = c("W1", "W2", "W3"), type = "blip1", learners = learner_list, - maximize = TRUE, complex = TRUE, realistic = FALSE + maximize = TRUE, complex = TRUE, + realistic = FALSE, resource = 1 ) @@ -76,6 +76,46 @@ fit <- tmle3(tmle_spec, data, node_list, learner_list) fit +## ----mopttx_spec_init_complex_resource---------------------------------------- +# initialize a tmle specification +tmle_spec_resource <- tmle3_mopttx_blip_revere( + V = c("W1", "W2", "W3"), type = "blip1", + learners = learner_list, + maximize = TRUE, complex = TRUE, + realistic = FALSE, resource = 0.90 +) + + +## ----mopttx_fit_tmle_auto_blip_revere_complex_resource, eval=T---------------- +# fit the TML estimator +fit_resource <- tmle3(tmle_spec_resource, data, node_list, learner_list) +fit_resource + + +## ----mopttx_compare_resource-------------------------------------------------- +# Number of individuals getting treatment (no resource constraint): +table(tmle_spec$return_rule) + +# Number of individuals getting treatment (resource constraint): +table(tmle_spec_resource$return_rule) + 
+
+
+## ----mopttx_spec_init_complex_V_empty------------------------------------------------
+# initialize a tmle specification, without specifying V
+tmle_spec_V_empty <- tmle3_mopttx_blip_revere(
+  type = "blip1",
+  learners = learner_list,
+  maximize = TRUE, complex = TRUE,
+  realistic = FALSE, resource = 0.90
+)
+
+
+## ----mopttx_fit_tmle_auto_blip_revere_complex_V_empty, eval=TRUE---------------------
+# fit the TML estimator
+fit_V_empty <- tmle3(tmle_spec_V_empty, data, node_list, learner_list)
+fit_V_empty
+
+
 ## ----load sim_cat_data-----------------------------------------------------------
 data("data_cat_realistic")
@@ -90,13 +130,22 @@ node_list <- list(
 )
 
 
+## ----data_cats-mopttx---------------------------------------------------------------
+# inspect the distribution of the categorical treatment
+table(data$A)
+
+
 ## ----sl3_lrnrs-mopttx----------------------------------------------------------------
+# Initialize a few of the learners:
+lrn_xgboost_50 <- Lrnr_xgboost$new(nrounds = 50)
+lrn_xgboost_100 <- Lrnr_xgboost$new(nrounds = 100)
+lrn_xgboost_500 <- Lrnr_xgboost$new(nrounds = 500)
+lrn_mean <- Lrnr_mean$new()
+lrn_glm <- Lrnr_glm_fast$new()
+
 ## Define the Q learner, which is just a regular learner:
 Q_learner <- Lrnr_sl$new(
-  learners = list(
-    lrn_xgboost_50, lrn_xgboost_100, lrn_xgboost_300,
-    lrn_mean, lrn_glm
-  ),
+  learners = list(lrn_xgboost_100, lrn_mean, lrn_glm),
   metalearner = Lrnr_nnls$new()
 )
@@ -104,20 +153,12 @@ Q_learner <- Lrnr_sl$new(
 # specify the appropriate loss of the multinomial learner:
 mn_metalearner <- make_learner(Lrnr_solnp,
   loss_function = loss_loglik_multinomial,
-  learner_function =
-    metalearner_linear_multinomial
-)
-g_learner <- make_learner(
-  Lrnr_sl,
-  list(lrn_xgboost_100, lrn_xgboost_300, lrn_mean),
-  mn_metalearner
+  learner_function = metalearner_linear_multinomial
 )
+g_learner <- make_learner(
+  Lrnr_sl, list(lrn_xgboost_100, lrn_xgboost_500, lrn_mean), mn_metalearner
+)
 
 # Define the Blip learner, which is a multivariate learner:
-learners <- list(
-  lrn_xgboost_50, lrn_xgboost_100, lrn_xgboost_300, lrn_mean,
-  lrn_glm
-)
+learners <- list(
+  lrn_xgboost_50, lrn_xgboost_100, lrn_xgboost_500, lrn_mean, lrn_glm
+)
 b_learner <- create_mv_learners(learners = learners)
@@ -133,7 +174,7 @@ learner_list <- list(Y = Q_learner, A = g_learner, B = b_learner)
 
 ## ----spec_init----------------------------------------------------------------------
 # initialize a tmle specification
-tmle_spec <- tmle3_mopttx_blip_revere(
+tmle_spec_cat <- tmle3_mopttx_blip_revere(
   V = c("W1", "W2", "W3", "W4"), type = "blip2",
   learners = learner_list, maximize = TRUE, complex = TRUE,
   realistic = FALSE
@@ -142,13 +183,16 @@ tmle_spec <- tmle3_mopttx_blip_revere(
 
 ## ----fit_tmle_auto------------------------------------------------------------------
 # fit the TML estimator
-fit <- tmle3(tmle_spec, data, node_list, learner_list)
-fit
+fit_cat <- tmle3(tmle_spec_cat, data, node_list, learner_list)
+fit_cat
+
+# How many individuals got assigned each treatment?
+table(tmle_spec_cat$return_rule)
 
 
 ## ----mopttx_spec_init_noncomplex------------------------------------------------------
 # initialize a tmle specification
-tmle_spec <- tmle3_mopttx_blip_revere(
+tmle_spec_cat_simple <- tmle3_mopttx_blip_revere(
   V = c("W4", "W3", "W2", "W1"), type = "blip2",
   learners = learner_list, maximize = TRUE, complex = FALSE,
   realistic = FALSE
@@ -157,13 +201,13 @@ tmle_spec <- tmle3_mopttx_blip_revere(
 
 ## ----mopttx_fit_tmle_auto_blip_revere_noncomplex---------------------------------------
 # fit the TML estimator
-fit <- tmle3(tmle_spec, data, node_list, learner_list)
-fit
+fit_cat_simple <- tmle3(tmle_spec_cat_simple, data, node_list, learner_list)
+fit_cat_simple
 
 
 ## ----mopttx_spec_init_realistic---------------------------------------------------------
 # initialize a tmle specification
-tmle_spec <- tmle3_mopttx_blip_revere(
+tmle_spec_cat_realistic <- tmle3_mopttx_blip_revere(
   V = c("W4", "W3", "W2", "W1"), type = "blip2",
   learners = learner_list, maximize = TRUE, complex = TRUE,
   realistic = TRUE
@@ -172,35 +216,68 @@ tmle_spec <- tmle3_mopttx_blip_revere(
 
 ## ----mopttx_fit_tmle_auto_blip_revere_realistic------------------------------------------
 # fit the TML estimator
-fit <- tmle3(tmle_spec, data, node_list, learner_list)
-fit
+fit_cat_realistic <- tmle3(
+  tmle_spec_cat_realistic, data, node_list, learner_list
+)
+fit_cat_realistic
 
 # How many individuals got assigned each treatment?
-table(tmle_spec$return_rule)
+table(tmle_spec_cat_realistic$return_rule)
+
+
+## ----data_nodes-add-missingness-mopttx------------------------------------------------
+data_missing <- data_cat_realistic
+
+# Add some random missingness to the outcome:
+rr <- sample(nrow(data_missing), 100, replace = FALSE)
+data_missing[rr, "Y"] <- NA
+
+summary(data_missing$Y)
+
+
+## ----sl3_lrnrs-add-mopttx---------------------------------------------------------------
+delta_learner <- Lrnr_sl$new(
+  learners = list(lrn_mean, lrn_glm),
+  metalearner = Lrnr_nnls$new()
+)
+
+# specify outcome, treatment, blip, and missingness regressions in the learner list
+learner_list <- list(
+  Y = Q_learner, A = g_learner, B = b_learner, delta_Y = delta_learner
+)
+
+
+## ----spec_init_missingness---------------------------------------------------------------
+# initialize a tmle specification
+tmle_spec_cat_miss <- tmle3_mopttx_blip_revere(
+  V = c("W1", "W2", "W3", "W4"), type = "blip2",
+  learners = learner_list, maximize = TRUE, complex = TRUE,
+  realistic = FALSE
+)
+
+
+## ----fit_tmle_auto2, eval=TRUE-------------------------------------------------------------
+# fit the TML estimator
+fit_cat_miss <- tmle3(tmle_spec_cat_miss, data_missing, node_list, learner_list)
+fit_cat_miss
 
 
 ## ----spec_init_Qlearning2, eval=FALSE------------------------------------------------------
 ## # initialize a tmle specification
 ## tmle_spec_Q <- tmle3_mopttx_Q(maximize = TRUE)
-##
+##
 ## # Define data:
 ## tmle_task <- tmle_spec_Q$make_tmle_task(data, node_list)
-##
+##
 ## # Define likelihood:
 ## initial_likelihood <- tmle_spec_Q$make_initial_likelihood(
 ##   tmle_task,
 ##   learner_list
 ## )
-##
+##
 ## # Estimate the parameter:
 ## Q_learning(tmle_spec_Q, initial_likelihood, tmle_task)[1]
 
 
 ## ----data_vim-nodes-mopttx------------------------------------------------------------
 # bin baseline covariates to 3 categories:
-data$W1 <- ifelse(data$W1 < quantile(data$W1)[2], 1,
-  ifelse(data$W1 < quantile(data$W1)[3], 2, 3)
-)
+data$W1<-ifelse(data$W1