From 773e428fb3328e6b336de2fb2d9b2cd81521582d Mon Sep 17 00:00:00 2001 From: rachaelvp Date: Mon, 6 Jun 2022 12:37:16 -0700 Subject: [PATCH] remake code --- R_code/06-sl3.R | 108 +++++++++++++++++++++++++++------------- R_code/08-tmle3mopttx.R | 22 ++++---- 2 files changed, 84 insertions(+), 46 deletions(-) diff --git a/R_code/06-sl3.R b/R_code/06-sl3.R index 17c5406..0110d89 100644 --- a/R_code/06-sl3.R +++ b/R_code/06-sl3.R @@ -106,13 +106,12 @@ sl <- Lrnr_sl$new(learners = stack, metalearner = Lrnr_nnls$new()) ## ----train-sl----------------------------------------------------------------- -# we will also set a stopwatch so we can measure how long this takes -start_time <- proc.time() +start_time <- proc.time() # start time set.seed(4197) sl_fit <- sl$train(task = task) -runtime_sl_fit <- proc.time() - start_time +runtime_sl_fit <- proc.time() - start_time # end time - start time = run time runtime_sl_fit @@ -121,11 +120,22 @@ sl_preds <- sl_fit$predict(task = task) head(sl_preds) -## ----lrnr-predictions--------------------------------------------------------- +## ----glm-predictions---------------------------------------------------------- glm_preds <- sl_fit$learner_fits$Lrnr_glm_TRUE$predict(task = task) head(glm_preds) +## ----glm-predictions-fullfit-------------------------------------------------- +# we can also access the candidate learner full fits directly and obtain +# the same "full fit" candidate predictions from there +# (we split this into two lines to avoid overflow) +stack_full_fits <- sl_fit$fit_object$full_fit$learner_fits$Stack$learner_fits +glm_preds_full_fit <- stack_full_fits$Lrnr_glm_TRUE$predict(task) + +# check that they are identical +identical(glm_preds, glm_preds_full_fit) + + ## ----predvobs----------------------------------------------------------------- # table of observed and predicted outcome values and arrange by observed values df_plot <- data.table( @@ -148,7 +158,7 @@ if (knitr::is_latex_output()) { scroll_box(width = "100%", height = "300px") } -## ----predobs-plot------------------------------------------------------------- +## ----predobs-plot, out.width = '60%', fig.asp = .62--------------------------- # melt the table so we can plot observed and predicted values df_plot$id <- seq(1:nrow(df_plot)) df_plot_melted <- melt( @@ -192,6 +202,16 @@ if (knitr::is_latex_output()) { } +## ----glm-predict-fold--------------------------------------------------------- +full_fit_preds <- sl_fit$fit_object$cv_fit$predict_fold( + task = task, fold_number = "full" +) +glm_full_fit_preds <- full_fit_preds$Lrnr_glm_TRUE + +# check that they are identical +identical(glm_preds, glm_full_fit_preds) + + ## ----cv-predictions-long------------------------------------------------------ ##### CV predictions "by hand" ##### # for each fold, i, we obtain validation set predictions: @@ -209,14 +229,20 @@ cv_preds_list <- lapply(seq_along(task$folds), function(i){ # get predicted outcomes for fold i's validation dataset, using candidates # trained to fold i's training dataset - v_preds <- sl_fit$fit_object$cv_fit$fit_object$fold_fits[[i]]$predict( - task = v_task + v_preds <- sl_fit$fit_object$cv_fit$predict_fold( + task = v_task, fold_number = i ) # note: v_preds is a matrix of candidate learner predictions, where the # number of rows is the number of observations in fold i's validation dataset # and the number of columns is the number of candidate learners (excluding # any that might have failed) + # an identical way to get v_preds, which is used when we calculate the + # 
cv risk by hand in a later part of this chapter: + # v_preds <- sl_fit$fit_object$cv_fit$fit_object$fold_fits[[i]]$predict( + # task = v_task + # ) + # we will also return the row indices for fold i's validation set, so we # can later reorder the CV predictions and make sure they are equal to what # we obtained above @@ -236,7 +262,7 @@ identical(cv_preds_option1, cv_preds_byhand_ordered) ## ----predictions-new-task, eval = FALSE--------------------------------------- ## prediction_task <- make_sl3_Task( -## data = new_data, # assuming we have some new data for predictions +## data = washb_data_new, # assuming we have some new data for predictions ## covariates = c("tr", "fracode", "month", "aged", "sex", "momage", "momedu", ## "momheight", "hfiacat", "Nlt18", "Ncomp", "watmin", "elec", ## "floor", "walls", "roof", "asset_wardrobe", "asset_table", @@ -332,8 +358,7 @@ identical( ## xlab("Learner") -## ----cvsl, eval=F------------------------------------------------------------- -## # we will also set a stopwatch so we can measure how long this takes +## ----cvsl, eval = FALSE------------------------------------------------------- ## start_time <- proc.time() ## ## set.seed(569) @@ -343,38 +368,51 @@ identical( ## runtime_cv_sl_fit -## ----cvsl-risk-summary, eval = FALSE------------------------------------------ -## cv_sl_fit$cv_risk[,c(1:3)] +## ----cvsl-save, eval = FALSE, echo = FALSE------------------------------------ +## library(here) +## save(cv_sl_fit, file=here("data", "fit_objects", "cv_sl_fit.Rdata"), compress=T) +## save(runtime_cv_sl_fit, file=here("data", "fit_objects", "runtime_cv_sl_fit.Rdata")) -## ----cvsl-risk-summary-handbook, echo = FALSE, eval=F------------------------- -## if (knitr::is_latex_output()) { -## cv_sl_fit$cv_risk[,c(1:3)] %>% -## kable(format = "latex") -## } else if (knitr::is_html_output()) { -## cv_sl_fit$cv_risk[,c(1:3)] %>% -## kable() %>% -## kableExtra:::kable_styling(fixed_thead = TRUE) %>% -## scroll_box(width = "100%", height = "300px") -## } +## ----cvsl-load, eval = TRUE, echo = FALSE------------------------------------- +library(here) +load(here("data", "fit_objects", "cv_sl_fit.Rdata")) +load(here("data", "fit_objects", "runtime_cv_sl_fit.Rdata")) +runtime_cv_sl_fit + + +## ----cvsl-risk-summary-------------------------------------------------------- +cv_sl_fit$cv_risk[,c(1:3)] + + +## ----cvsl-risk-summary-handbook, echo = FALSE--------------------------------- +if (knitr::is_latex_output()) { + cv_sl_fit$cv_risk[,c(1:3)] %>% + kable(format = "latex") +} else if (knitr::is_html_output()) { + cv_sl_fit$cv_risk[,c(1:3)] %>% + kable() %>% + kableExtra:::kable_styling(fixed_thead = TRUE) %>% + scroll_box(width = "100%", height = "300px") +} ## ----sl-revere-risk----------------------------------------------------------- -cv_risk_table_revere <- sl_fit$cv_risk( +cv_risk_w_sl_revere <- sl_fit$cv_risk( eval_fun = loss_squared_error, get_sl_revere_risk = TRUE ) ## ----sl-revere-risk-summary, eval = FALSE------------------------------------- -## cv_risk_table_revere[,c(1:3)] +## cv_risk_w_sl_revere[,c(1:3)] ## ----sl-revere-risk-handbook, eval = FALSE, echo = FALSE---------------------- ## if (knitr::is_latex_output()) { -## cv_risk_table_revere[,c(1:3)] %>% +## cv_risk_w_sl_revere[,c(1:3)] %>% ## kable(format = "latex") ## } else if (knitr::is_html_output()) { -## cv_risk_table_revere[,c(1:3)] %>% +## cv_risk_w_sl_revere[,c(1:3)] %>% ## kable() %>% ## kableExtra:::kable_styling(fixed_thead = TRUE) %>% ## scroll_box(width = "100%", height = "300px") 
@@ -409,13 +447,13 @@ sl_revere_risk_list <- lapply(seq_along(task$folds), function(i){ # get predicted outcomes for fold i's metalevel dataset, using the fitted # metalearner, cv_meta_fit - sl_revere_v_preds <- sl_fit$fit_object$cv_meta_fit$predict(task = v_meta_task) + sl_revere_v_preds <- sl_fit$fit_object$cv_meta_fit$predict(task=v_meta_task) # note: cv_meta_fit was trained on the metalevel dataset, which contains the # candidates' cv predictions and validation dataset outcomes across ALL folds, # so cv_meta_fit has already seen fold i's validation dataset outcomes. # calculate predictive performance for fold i for the SL - eval_function <- loss_squared_error # valid for estimation of conditional mean + eval_function <- loss_squared_error # valid for estimation of conditional mean # note: by evaluating the predictive performance of the SL using outcomes # that were already seen by the metalearner, this is not a cross-validated # measure of predictive performance for the SL. @@ -430,11 +468,11 @@ sl_revere_risk_byhand <- mean(unlist(sl_revere_risk_list)) sl_revere_risk_byhand # check that our calculation by hand equals what is output in cv_risk_table_revere -sl_revere_risk <- as.numeric(cv_risk_table_revere[learner == "SuperLearner", "MSE"]) -identical(sl_revere_risk_byhand, sl_revere_risk) +sl_revere_risk <- as.numeric(cv_risk_w_sl_revere[learner=="SuperLearner","MSE"]) +identical(sl_revere_risk, sl_revere_risk_byhand) -## ----make-Lrnr_cv_selector---------------------------------------------------- +## ----make-Lrnr-cv-selector---------------------------------------------------- cv_selector <- Lrnr_cv_selector$new(eval_function = loss_squared_error) @@ -517,7 +555,7 @@ task$data[some_rows_with_missingness, colSums(is.na(task$data)) -## ----fruit-------------------------------------------------------------------- +## ----kitty-------------------------------------------------------------------- cats <- c("calico", "tabby", "cow", "ragdoll", "mancoon", "dwarf", "calico") cats <- factor(cats) cats_onehot <- factor_to_indicators(cats) @@ -670,12 +708,12 @@ if (knitr::is_latex_output()) { } -## ----varimp-plot, out.width = "100%"------------------------------------------ +## ----varimp-plot, out.width = '60%', fig.asp = .62---------------------------- # plot variable importance importance_plot(x = washb_varimp) -## ----cde_using_locscale, eval = FALSE----------------------------------------- +## ----cde-using-locscale, eval = FALSE----------------------------------------- ## # semiparametric density estimator with homoscedastic errors (HOSE) ## hose_hal_lrnr <- Lrnr_density_semiparametric$new( ## mean_learner = Lrnr_hal9001$new() @@ -693,7 +731,7 @@ importance_plot(x = washb_varimp) ## ) -## ----cde_using_pooledhaz, eval = FALSE---------------------------------------- +## ----cde-using-pooledhaz, eval = FALSE---------------------------------------- ## # learners used for conditional densities for (g_n) ## haldensify_lrnr <- Lrnr_haldensify$new( ## n_bins = c(5, 10) diff --git a/R_code/08-tmle3mopttx.R b/R_code/08-tmle3mopttx.R index fa96940..e0b38d2 100644 --- a/R_code/08-tmle3mopttx.R +++ b/R_code/08-tmle3mopttx.R @@ -243,19 +243,19 @@ delta_learner <- Lrnr_sl$new( learner_list <- list(Y = Q_learner, A = g_learner, B = b_learner, delta_Y=delta_learner) -## ----spec_init_missingness---------------------------------------------------- -# initialize a tmle specification -tmle_spec_cat_miss <- tmle3_mopttx_blip_revere( - V = c("W1", "W2", "W3", "W4"), type = "blip2", - learners = learner_list, 
maximize = TRUE, complex = TRUE, - realistic = FALSE -) +## ----spec_init_missingness, eval = FALSE-------------------------------------- +## # initialize a tmle specification +## tmle_spec_cat_miss <- tmle3_mopttx_blip_revere( +## V = c("W1", "W2", "W3", "W4"), type = "blip2", +## learners = learner_list, maximize = TRUE, complex = TRUE, +## realistic = FALSE +## ) -## ----fit_tmle_auto2, eval=T--------------------------------------------------- -# fit the TML estimator -fit_cat_miss <- tmle3(tmle_spec_cat_miss, data_missing, node_list, learner_list) -fit_cat_miss +## ----fit_tmle_auto2, eval = FALSE--------------------------------------------- +## # fit the TML estimator +## fit_cat_miss <- tmle3(tmle_spec_cat_miss, data_missing, node_list, learner_list) +## fit_cat_miss ## ----spec_init_Qlearning2, eval=FALSE-----------------------------------------
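A minimal sketch of the fit-caching pattern used in the cvsl-save and cvsl-load chunks of 06-sl3.R above: time an expensive fit once, save it (and its runtime) under data/fit_objects/ with here(), and load the cached objects when the handbook is rendered instead of re-fitting. It assumes the `sl` learner and `task` defined earlier in 06-sl3.R and a writable data/fit_objects/ directory; the chunk label, the `cached_sl_fit` and `runtime_cached_sl_fit` names, and the file names are illustrative placeholders, not objects defined in this patch.

## ----cache-slow-fit-sketch, eval = FALSE---------------------------------------
## library(here)
## 
## # fit and time the expensive object once
## start_time <- proc.time()                          # start time
## set.seed(4197)
## cached_sl_fit <- sl$train(task = task)             # any long-running fit
## runtime_cached_sl_fit <- proc.time() - start_time  # end time - start time = run time
## 
## # save the fit and its runtime so later builds can skip the re-fit
## save(cached_sl_fit,
##      file = here("data", "fit_objects", "cached_sl_fit.Rdata"), compress = TRUE)
## save(runtime_cached_sl_fit,
##      file = here("data", "fit_objects", "runtime_cached_sl_fit.Rdata"))
## 
## # at render time, load the cached objects instead of re-fitting
## load(here("data", "fit_objects", "cached_sl_fit.Rdata"))
## load(here("data", "fit_objects", "runtime_cached_sl_fit.Rdata"))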