remake code
rachaelvp committed Jun 6, 2022
1 parent a84957b commit 773e428
Showing 2 changed files with 84 additions and 46 deletions.
108 changes: 73 additions & 35 deletions R_code/06-sl3.R
@@ -106,13 +106,12 @@ sl <- Lrnr_sl$new(learners = stack, metalearner = Lrnr_nnls$new())


## ----train-sl-----------------------------------------------------------------
# we will also set a stopwatch so we can measure how long this takes
start_time <- proc.time()
start_time <- proc.time() # start time

set.seed(4197)
sl_fit <- sl$train(task = task)

runtime_sl_fit <- proc.time() - start_time
runtime_sl_fit <- proc.time() - start_time # end time - start time = run time
runtime_sl_fit

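# A minimal follow-up sketch (an illustration, not part of the script above):
# runtime_sl_fit is the named vector returned by proc.time(), so the elapsed
# wall-clock time in minutes can be pulled out directly, and -- assuming the
# usual Lrnr_sl accessor -- the fitted NNLS metalearner weights can be inspected
round(runtime_sl_fit[["elapsed"]] / 60, 2)  # elapsed run time, in minutes
round(sl_fit$coefficients, 3)               # metalearner weight per candidate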

@@ -121,11 +120,22 @@ sl_preds <- sl_fit$predict(task = task)
head(sl_preds)


## ----lrnr-predictions---------------------------------------------------------
## ----glm-predictions----------------------------------------------------------
glm_preds <- sl_fit$learner_fits$Lrnr_glm_TRUE$predict(task = task)
head(glm_preds)

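# A quick hedged aside: learner_fits behaves like a named list keyed by the
# candidates' names (e.g., "Lrnr_glm_TRUE" above), so the available keys can
# be listed before indexing into it
names(sl_fit$learner_fits)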

## ----glm-predictions-fullfit--------------------------------------------------
# we can also access the candidate learner full fits directly and obtain
# the same "full fit" candidate predictions from there
# (we split this into two lines to avoid overflow)
stack_full_fits <- sl_fit$fit_object$full_fit$learner_fits$Stack$learner_fits
glm_preds_full_fit <- stack_full_fits$Lrnr_glm_TRUE$predict(task)

# check that they are identical
identical(glm_preds, glm_preds_full_fit)

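# A small aside (illustrative only): identical() demands exact equality; for
# numeric predictions a tolerance-aware comparison is often the safer check
isTRUE(all.equal(glm_preds, glm_preds_full_fit))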

## ----predvobs-----------------------------------------------------------------
# table of observed and predicted outcome values and arrange by observed values
df_plot <- data.table(
@@ -148,7 +158,7 @@ if (knitr::is_latex_output()) {
scroll_box(width = "100%", height = "300px")
}

## ----predobs-plot-------------------------------------------------------------
## ----predobs-plot, out.width = '60%', fig.asp = .62---------------------------
# melt the table so we can plot observed and predicted values
df_plot$id <- seq(1:nrow(df_plot))
df_plot_melted <- melt(
@@ -192,6 +202,16 @@ if (knitr::is_latex_output()) {
}


## ----glm-predict-fold---------------------------------------------------------
full_fit_preds <- sl_fit$fit_object$cv_fit$predict_fold(
task = task, fold_number = "full"
)
glm_full_fit_preds <- full_fit_preds$Lrnr_glm_TRUE

# check that they are identical
identical(glm_preds, glm_full_fit_preds)

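# A hedged sketch: the same accessor also accepts fold_number = "validation",
# which (assuming the usual sl3 behavior) returns each candidate's out-of-fold,
# i.e., cross-validated, predictions rather than the full-fit ones
cv_preds_all <- sl_fit$fit_object$cv_fit$predict_fold(
  task = task, fold_number = "validation"
)
head(cv_preds_all$Lrnr_glm_TRUE)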

## ----cv-predictions-long------------------------------------------------------
##### CV predictions "by hand" #####
# for each fold, i, we obtain validation set predictions:
@@ -209,14 +229,20 @@ cv_preds_list <- lapply(seq_along(task$folds), function(i){

# get predicted outcomes for fold i's validation dataset, using candidates
# trained to fold i's training dataset
v_preds <- sl_fit$fit_object$cv_fit$fit_object$fold_fits[[i]]$predict(
task = v_task
v_preds <- sl_fit$fit_object$cv_fit$predict_fold(
task = v_task, fold_number = i
)
# note: v_preds is a matrix of candidate learner predictions, where the
# number of rows is the number of observations in fold i's validation dataset
# and the number of columns is the number of candidate learners (excluding
# any that might have failed)

# an identical way to get v_preds, which is used when we calculate the
# cv risk by hand in a later part of this chapter:
# v_preds <- sl_fit$fit_object$cv_fit$fit_object$fold_fits[[i]]$predict(
# task = v_task
# )

# we will also return the row indices for fold i's validation set, so we
# can later reorder the CV predictions and make sure they are equal to what
# we obtained above
@@ -236,7 +262,7 @@ identical(cv_preds_option1, cv_preds_byhand_ordered)

## ----predictions-new-task, eval = FALSE---------------------------------------
## prediction_task <- make_sl3_Task(
## data = new_data, # assuming we have some new data for predictions
## data = washb_data_new, # assuming we have some new data for predictions
## covariates = c("tr", "fracode", "month", "aged", "sex", "momage", "momedu",
## "momheight", "hfiacat", "Nlt18", "Ncomp", "watmin", "elec",
## "floor", "walls", "roof", "asset_wardrobe", "asset_table",
@@ -332,8 +358,7 @@ identical(
## xlab("Learner")


## ----cvsl, eval=F-------------------------------------------------------------
## # we will also set a stopwatch so we can measure how long this takes
## ----cvsl, eval = FALSE-------------------------------------------------------
## start_time <- proc.time()
##
## set.seed(569)
@@ -343,38 +368,51 @@ identical(
## runtime_cv_sl_fit


## ----cvsl-risk-summary, eval = FALSE------------------------------------------
## cv_sl_fit$cv_risk[,c(1:3)]
## ----cvsl-save, eval = FALSE, echo = FALSE------------------------------------
## library(here)
## save(cv_sl_fit, file=here("data", "fit_objects", "cv_sl_fit.Rdata"), compress=T)
## save(runtime_cv_sl_fit, file=here("data", "fit_objects", "runtime_cv_sl_fit.Rdata"))


## ----cvsl-risk-summary-handbook, echo = FALSE, eval=F-------------------------
## if (knitr::is_latex_output()) {
## cv_sl_fit$cv_risk[,c(1:3)] %>%
## kable(format = "latex")
## } else if (knitr::is_html_output()) {
## cv_sl_fit$cv_risk[,c(1:3)] %>%
## kable() %>%
## kableExtra:::kable_styling(fixed_thead = TRUE) %>%
## scroll_box(width = "100%", height = "300px")
## }
## ----cvsl-load, eval = TRUE, echo = FALSE-------------------------------------
library(here)
load(here("data", "fit_objects", "cv_sl_fit.Rdata"))
load(here("data", "fit_objects", "runtime_cv_sl_fit.Rdata"))
runtime_cv_sl_fit

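# An illustrative guard (not in the script): these load() calls assume the
# .Rdata files were produced by the save chunk above, so it can be worth
# confirming the files exist before loading
file.exists(here("data", "fit_objects", "cv_sl_fit.Rdata"))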

## ----cvsl-risk-summary--------------------------------------------------------
cv_sl_fit$cv_risk[,c(1:3)]


## ----cvsl-risk-summary-handbook, echo = FALSE---------------------------------
if (knitr::is_latex_output()) {
cv_sl_fit$cv_risk[,c(1:3)] %>%
kable(format = "latex")
} else if (knitr::is_html_output()) {
cv_sl_fit$cv_risk[,c(1:3)] %>%
kable() %>%
kableExtra:::kable_styling(fixed_thead = TRUE) %>%
scroll_box(width = "100%", height = "300px")
}


## ----sl-revere-risk-----------------------------------------------------------
cv_risk_table_revere <- sl_fit$cv_risk(
cv_risk_w_sl_revere <- sl_fit$cv_risk(
eval_fun = loss_squared_error, get_sl_revere_risk = TRUE
)

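# A hedged aside: calling the same method without get_sl_revere_risk (assuming
# sl3's default behavior) returns the candidates' cross-validated risks only,
# without the revere-based row for the SuperLearner itself
cv_risk_wo_sl <- sl_fit$cv_risk(eval_fun = loss_squared_error)
cv_risk_wo_sl[, c(1:3)]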

## ----sl-revere-risk-summary, eval = FALSE-------------------------------------
## cv_risk_table_revere[,c(1:3)]
## cv_risk_w_sl_revere[,c(1:3)]


## ----sl-revere-risk-handbook, eval = FALSE, echo = FALSE----------------------
## if (knitr::is_latex_output()) {
## cv_risk_table_revere[,c(1:3)] %>%
## cv_risk_w_sl_revere[,c(1:3)] %>%
## kable(format = "latex")
## } else if (knitr::is_html_output()) {
## cv_risk_table_revere[,c(1:3)] %>%
## cv_risk_w_sl_revere[,c(1:3)] %>%
## kable() %>%
## kableExtra:::kable_styling(fixed_thead = TRUE) %>%
## scroll_box(width = "100%", height = "300px")
@@ -409,13 +447,13 @@ sl_revere_risk_list <- lapply(seq_along(task$folds), function(i){

# get predicted outcomes for fold i's metalevel dataset, using the fitted
# metalearner, cv_meta_fit
sl_revere_v_preds <- sl_fit$fit_object$cv_meta_fit$predict(task = v_meta_task)
sl_revere_v_preds <- sl_fit$fit_object$cv_meta_fit$predict(task=v_meta_task)
# note: cv_meta_fit was trained on the metalevel dataset, which contains the
# candidates' cv predictions and validation dataset outcomes across ALL folds,
# so cv_meta_fit has already seen fold i's validation dataset outcomes.

# calculate predictive performance for fold i for the SL
eval_function <- loss_squared_error # valid for estimation of conditional mean
eval_function <- loss_squared_error # valid for estimation of conditional mean
# note: by evaluating the predictive performance of the SL using outcomes
# that were already seen by the metalearner, this is not a cross-validated
# measure of predictive performance for the SL.
Expand All @@ -430,11 +468,11 @@ sl_revere_risk_byhand <- mean(unlist(sl_revere_risk_list))
sl_revere_risk_byhand

# check that our calculation by hand equals what is output in cv_risk_w_sl_revere
sl_revere_risk <- as.numeric(cv_risk_table_revere[learner == "SuperLearner", "MSE"])
identical(sl_revere_risk_byhand, sl_revere_risk)
sl_revere_risk <- as.numeric(cv_risk_w_sl_revere[learner=="SuperLearner","MSE"])
identical(sl_revere_risk, sl_revere_risk_byhand)


## ----make-Lrnr_cv_selector----------------------------------------------------
## ----make-Lrnr-cv-selector----------------------------------------------------
cv_selector <- Lrnr_cv_selector$new(eval_function = loss_squared_error)

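# A minimal sketch of how such a selector is typically used: plugging it in as
# the metalearner yields a discrete SuperLearner that picks the single
# candidate with the lowest cross-validated risk (reusing stack from above)
dsl <- Lrnr_sl$new(learners = stack, metalearner = cv_selector)
# dsl_fit <- dsl$train(task = task)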

@@ -517,7 +555,7 @@ task$data[some_rows_with_missingness,
colSums(is.na(task$data))


## ----fruit--------------------------------------------------------------------
## ----kitty--------------------------------------------------------------------
cats <- c("calico", "tabby", "cow", "ragdoll", "mancoon", "dwarf", "calico")
cats <- factor(cats)
cats_onehot <- factor_to_indicators(cats)
@@ -670,12 +708,12 @@ if (knitr::is_latex_output()) {
}


## ----varimp-plot, out.width = "100%"------------------------------------------
## ----varimp-plot, out.width = '60%', fig.asp = .62----------------------------
# plot variable importance
importance_plot(x = washb_varimp)

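# A tiny aside (illustrative): the importance table itself can be inspected
# directly before plotting, assuming washb_varimp is the table computed above
head(washb_varimp)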

## ----cde_using_locscale, eval = FALSE-----------------------------------------
## ----cde-using-locscale, eval = FALSE-----------------------------------------
## # semiparametric density estimator with homoscedastic errors (HOSE)
## hose_hal_lrnr <- Lrnr_density_semiparametric$new(
## mean_learner = Lrnr_hal9001$new()
@@ -693,7 +731,7 @@ importance_plot(x = washb_varimp)
## )


## ----cde_using_pooledhaz, eval = FALSE----------------------------------------
## ----cde-using-pooledhaz, eval = FALSE----------------------------------------
## # learners used for conditional densities for (g_n)
## haldensify_lrnr <- Lrnr_haldensify$new(
## n_bins = c(5, 10)
22 changes: 11 additions & 11 deletions R_code/08-tmle3mopttx.R
@@ -243,19 +243,19 @@ delta_learner <- Lrnr_sl$new(
learner_list <- list(Y = Q_learner, A = g_learner, B = b_learner, delta_Y=delta_learner)


## ----spec_init_missingness----------------------------------------------------
# initialize a tmle specification
tmle_spec_cat_miss <- tmle3_mopttx_blip_revere(
V = c("W1", "W2", "W3", "W4"), type = "blip2",
learners = learner_list, maximize = TRUE, complex = TRUE,
realistic = FALSE
)
## ----spec_init_missingness, eval = FALSE--------------------------------------
## # initialize a tmle specification
## tmle_spec_cat_miss <- tmle3_mopttx_blip_revere(
## V = c("W1", "W2", "W3", "W4"), type = "blip2",
## learners = learner_list, maximize = TRUE, complex = TRUE,
## realistic = FALSE
## )


## ----fit_tmle_auto2, eval=T---------------------------------------------------
# fit the TML estimator
fit_cat_miss <- tmle3(tmle_spec_cat_miss, data_missing, node_list, learner_list)
fit_cat_miss
## ----fit_tmle_auto2, eval = FALSE---------------------------------------------
## # fit the TML estimator
## fit_cat_miss <- tmle3(tmle_spec_cat_miss, data_missing, node_list, learner_list)
## fit_cat_miss


## ----spec_init_Qlearning2, eval=FALSE-----------------------------------------
