From 070a6db22e79a388686799020bdb3136d70205a4 Mon Sep 17 00:00:00 2001 From: rachaelvphillips Date: Mon, 15 Mar 2021 10:09:38 -0700 Subject: [PATCH] make code changes --- R_code/01-preface.R | 1 + R_code/02-roadmap.R | 1 + R_code/03-tlverse.R | 1 + R_code/04-data.R | 1 + R_code/05-origami.R | 5 +++-- R_code/06-sl3.R | 43 ++++++++++++++++++++------------------- R_code/07-tmle3.R | 1 + R_code/08-tmle3mopttx.R | 21 ++++++++++--------- R_code/09-tmle3shift.R | 1 + R_code/10-tmle3mediate.R | 1 + R_code/11-tmle3survival.R | 1 + R_code/99-primer_R6.R | 1 + 12 files changed, 45 insertions(+), 33 deletions(-) diff --git a/R_code/01-preface.R b/R_code/01-preface.R index e69de29..8b13789 100644 --- a/R_code/01-preface.R +++ b/R_code/01-preface.R @@ -0,0 +1 @@ + diff --git a/R_code/02-roadmap.R b/R_code/02-roadmap.R index 4c45389..cca11ce 100644 --- a/R_code/02-roadmap.R +++ b/R_code/02-roadmap.R @@ -18,3 +18,4 @@ tidy_dag <- tidy_dagitty(dag) # visualize DAG ggdag(tidy_dag) + theme_dag() + diff --git a/R_code/03-tlverse.R b/R_code/03-tlverse.R index c39cd42..be58d99 100644 --- a/R_code/03-tlverse.R +++ b/R_code/03-tlverse.R @@ -1,3 +1,4 @@ ## ----installation, eval=FALSE------------------------------------------------- ## install.packages("devtools") ## devtools::install_github("tlverse/tlverse") + diff --git a/R_code/04-data.R b/R_code/04-data.R index ff1c2e4..e0c8885 100644 --- a/R_code/04-data.R +++ b/R_code/04-data.R @@ -60,3 +60,4 @@ if (knitr::is_latex_output()) { } else { skim(nhefs_data) } + diff --git a/R_code/05-origami.R b/R_code/05-origami.R index c656fe7..5a0fe5c 100644 --- a/R_code/05-origami.R +++ b/R_code/05-origami.R @@ -16,7 +16,7 @@ washb_data <- fread( washb_data <- washb_data[1:30, ] head(washb_data) %>% kable() %>% - kableExtra::kable_styling(fixed_thead = TRUE) %>% + kable_styling(fixed_thead = TRUE) %>% scroll_box(width = "100%", height = "300px") @@ -122,7 +122,7 @@ covars <- colnames(washb_data)[-which(names(washb_data) == outcome)] head(washb_data) %>% kable() %>% - kableExtra::kable_styling(fixed_thead = TRUE) %>% + kable_styling(fixed_thead = TRUE) %>% scroll_box(width = "100%", height = "300px") @@ -285,3 +285,4 @@ mses <- cross_validate( ) mses$mse colMeans(mses$mse[, c("arima", "arima2")]) + diff --git a/R_code/06-sl3.R b/R_code/06-sl3.R index 2903823..1525df7 100644 --- a/R_code/06-sl3.R +++ b/R_code/06-sl3.R @@ -23,7 +23,7 @@ washb_data <- fread( ) head(washb_data) %>% kable() %>% - kableExtra::kable_styling(fixed_thead = T) %>% + kable_styling(fixed_thead = TRUE) %>% scroll_box(width = "100%", height = "300px") @@ -236,13 +236,13 @@ head(sl_preds) ## ---- plot-predvobs-woohoo, eval=FALSE---------------------------------------- -## +## ## # df_plot <- data.frame(Observed = washb_data[["whz"]], Predicted = sl_preds, ## # count = seq(1:nrow(washb_data)) -## +## ## # df_plot_melted <- melt(df_plot, id.vars = "count", ## # measure.vars = c("Observed", "Predicted")) -## +## ## # ggplot(df_plot_melted, aes(value, count, color = variable)) + geom_point() @@ -262,7 +262,7 @@ CVsl <- CV_lrnr_sl( ) CVsl %>% kable(digits = 4) %>% - kableExtra::kable_styling(fixed_thead = TRUE) %>% + kable_styling(fixed_thead = TRUE) %>% scroll_box(width = "100%", height = "300px") @@ -270,7 +270,7 @@ CVsl %>% washb_varimp <- importance(sl_fit, loss = loss_squared_error, type = "permute") washb_varimp %>% kable(digits = 4) %>% - kableExtra::kable_styling(fixed_thead = TRUE) %>% + kable_styling(fixed_thead = TRUE) %>% scroll_box(width = "100%", height = "300px") @@ -295,7 +295,7 @@ chspred <- read_csv(file = db_data, col_names = TRUE) # take a quick peek head(chspred) %>% kable(digits = 4) %>% - kableExtra::kable_styling(fixed_thead = TRUE) %>% + kable_styling(fixed_thead = TRUE) %>% scroll_box(width = "100%", height = "300px") @@ -324,14 +324,14 @@ ist_task_CVsl <- make_sl3_Task( ## "https://raw.githubusercontent.com/benkeser/sllecture/master/chspred.csv" ## ) ## chspred <- read_csv(file = db_data, col_names = TRUE) -## +## ## # make task ## chspred_task <- make_sl3_Task( ## data = chspred, ## covariates = head(colnames(chspred), -1), ## outcome = "mi" ## ) -## +## ## # make learners ## glm_learner <- Lrnr_glm$new() ## lasso_learner <- Lrnr_glmnet$new(alpha = 1) @@ -344,11 +344,11 @@ ist_task_CVsl <- make_sl3_Task( ## ranger_learner <- Lrnr_ranger$new() ## svm_learner <- Lrnr_svm$new() ## xgb_learner <- Lrnr_xgboost$new() -## +## ## # screening ## screen_cor <- make_learner(Lrnr_screener_correlation) ## glm_pipeline <- make_learner(Pipeline, screen_cor, glm_learner) -## +## ## # stack learners together ## stack <- make_learner( ## Stack, @@ -357,17 +357,17 @@ ist_task_CVsl <- make_sl3_Task( ## curated_glm_learner, mean_learner, glm_fast_learner, ## ranger_learner, svm_learner, xgb_learner ## ) -## +## ## # make and train SL ## sl <- Lrnr_sl$new( ## learners = stack ## ) ## sl_fit <- sl$train(chspred_task) ## sl_fit$print() -## +## ## CVsl <- CV_lrnr_sl(sl_fit, chspred_task, loss_loglik_binomial) ## CVsl -## +## ## varimp <- importance(sl_fit, type = "permute") ## varimp %>% ## importance_plot( @@ -377,12 +377,12 @@ ist_task_CVsl <- make_sl3_Task( ## ----ex2-key, eval=FALSE------------------------------------------------------ ## library(ROCR) # for AUC calculation -## +## ## ist_data <- paste0( ## "https://raw.githubusercontent.com/tlverse/", ## "tlverse-handbook/master/data/ist_sample.csv" ## ) %>% fread() -## +## ## # stack ## ist_task <- make_sl3_Task( ## data = ist_data, @@ -390,7 +390,7 @@ ist_task_CVsl <- make_sl3_Task( ## covariates = colnames(ist_data)[-which(names(ist_data) == "DRSISC")], ## drop_missing_outcome = TRUE ## ) -## +## ## # learner library ## lrn_glm <- Lrnr_glm$new() ## lrn_lasso <- Lrnr_glmnet$new(alpha = 1) @@ -415,17 +415,17 @@ ist_task_CVsl <- make_sl3_Task( ## ), ## recursive = TRUE ## ) -## +## ## # SL ## sl <- Lrnr_sl$new(learners) ## sl_fit <- sl$train(ist_task) -## +## ## # AUC ## preds <- sl_fit$predict() ## obs <- c(na.omit(ist_data$DRSISC)) ## AUC <- performance(prediction(sl_preds, obs), measure = "auc")@y.values[[1]] ## plot(performance(prediction(sl_preds, obs), "tpr", "fpr")) -## +## ## # CVsl ## ist_task_CVsl <- make_sl3_Task( ## data = ist_data, @@ -440,7 +440,7 @@ ist_task_CVsl <- make_sl3_Task( ## ) ## CVsl <- CV_lrnr_sl(sl_fit, ist_task_CVsl, loss_loglik_binomial) ## CVsl -## +## ## # sl3 variable importance plot ## ist_varimp <- importance(sl_fit, type = "permute") ## ist_varimp %>% @@ -451,3 +451,4 @@ ist_task_CVsl <- make_sl3_Task( ## ----ex3-key, eval=FALSE------------------------------------------------------ ## # TODO + diff --git a/R_code/07-tmle3.R b/R_code/07-tmle3.R index dff5eda..3b75622 100644 --- a/R_code/07-tmle3.R +++ b/R_code/07-tmle3.R @@ -183,3 +183,4 @@ ist_data <- fread( "master/data/ist_sample.csv" ) ) + diff --git a/R_code/08-tmle3mopttx.R b/R_code/08-tmle3mopttx.R index 5131c87..8862460 100644 --- a/R_code/08-tmle3mopttx.R +++ b/R_code/08-tmle3mopttx.R @@ -297,14 +297,15 @@ b_learner <- create_mv_learners(learners = learners) learner_list <- list(Y = Q_learner, A = g_learner, B = b_learner) -## ----spec_init_WASH----------------------------------------------------------- -# initialize a tmle specification -tmle_spec <- tmle3_mopttx_blip_revere( - V = c("momedu", "floor", "asset_refrig"), type = "blip2", - learners = learner_list, maximize = TRUE, complex = TRUE, - realistic = FALSE -) +## ----spec_init_WASH, eval=FALSE----------------------------------------------- +## # initialize a tmle specification +## tmle_spec <- tmle3_mopttx_blip_revere( +## V = c("momedu", "floor", "asset_refrig"), type = "blip2", +## learners = learner_list, maximize = TRUE, complex = TRUE, +## realistic = FALSE +## ) +## +## # fit the TML estimator +## fit <- tmle3(tmle_spec, data = washb_data, node_list, learner_list) +## fit -# fit the TML estimator -fit <- tmle3(tmle_spec, data = washb_data, node_list, learner_list) -fit diff --git a/R_code/09-tmle3shift.R b/R_code/09-tmle3shift.R index b2ec9aa..1c94cea 100644 --- a/R_code/09-tmle3shift.R +++ b/R_code/09-tmle3shift.R @@ -169,3 +169,4 @@ learner_list <- list(Y = sl_reg_lrnr, A = cv_hose_hal_lrnr) ## ----fit_tmle_wrapper_washb_shift, message=FALSE, warning=FALSE, eval=FALSE---- ## washb_tmle_fit <- tmle3(washb_vim_spec, washb_data, node_list, learner_list) ## washb_tmle_fit + diff --git a/R_code/10-tmle3mediate.R b/R_code/10-tmle3mediate.R index e69de29..8b13789 100644 --- a/R_code/10-tmle3mediate.R +++ b/R_code/10-tmle3mediate.R @@ -0,0 +1 @@ + diff --git a/R_code/11-tmle3survival.R b/R_code/11-tmle3survival.R index d2cdcac..47c4b7d 100644 --- a/R_code/11-tmle3survival.R +++ b/R_code/11-tmle3survival.R @@ -77,3 +77,4 @@ tmle_fit_manual <- fit_tmle3( ## ----------------------------------------------------------------------------- print(tmle_fit_manual) + diff --git a/R_code/99-primer_R6.R b/R_code/99-primer_R6.R index e69de29..8b13789 100644 --- a/R_code/99-primer_R6.R +++ b/R_code/99-primer_R6.R @@ -0,0 +1 @@ +