Skip to content

Commit

Permalink
make code changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rachaelvp committed Mar 15, 2021
1 parent 799728d commit 070a6db
Show file tree
Hide file tree
Showing 12 changed files with 45 additions and 33 deletions.
1 change: 1 addition & 0 deletions R_code/01-preface.R
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

1 change: 1 addition & 0 deletions R_code/02-roadmap.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ tidy_dag <- tidy_dagitty(dag)
# visualize DAG
ggdag(tidy_dag) +
theme_dag()

1 change: 1 addition & 0 deletions R_code/03-tlverse.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
## ----installation, eval=FALSE-------------------------------------------------
## install.packages("devtools")
## devtools::install_github("tlverse/tlverse")

1 change: 1 addition & 0 deletions R_code/04-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,4 @@ if (knitr::is_latex_output()) {
} else {
skim(nhefs_data)
}

5 changes: 3 additions & 2 deletions R_code/05-origami.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ washb_data <- fread(
washb_data <- washb_data[1:30, ]
head(washb_data) %>%
kable() %>%
kableExtra::kable_styling(fixed_thead = TRUE) %>%
kable_styling(fixed_thead = TRUE) %>%
scroll_box(width = "100%", height = "300px")


Expand Down Expand Up @@ -122,7 +122,7 @@ covars <- colnames(washb_data)[-which(names(washb_data) == outcome)]

head(washb_data) %>%
kable() %>%
kableExtra::kable_styling(fixed_thead = TRUE) %>%
kable_styling(fixed_thead = TRUE) %>%
scroll_box(width = "100%", height = "300px")


Expand Down Expand Up @@ -285,3 +285,4 @@ mses <- cross_validate(
)
mses$mse
colMeans(mses$mse[, c("arima", "arima2")])

43 changes: 22 additions & 21 deletions R_code/06-sl3.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ washb_data <- fread(
)
head(washb_data) %>%
kable() %>%
kableExtra::kable_styling(fixed_thead = T) %>%
kable_styling(fixed_thead = TRUE) %>%
scroll_box(width = "100%", height = "300px")


Expand Down Expand Up @@ -236,13 +236,13 @@ head(sl_preds)


## ---- plot-predvobs-woohoo, eval=FALSE----------------------------------------
##
##
## # df_plot <- data.frame(Observed = washb_data[["whz"]], Predicted = sl_preds,
## # count = seq(1:nrow(washb_data))
##
##
## # df_plot_melted <- melt(df_plot, id.vars = "count",
## # measure.vars = c("Observed", "Predicted"))
##
##
## # ggplot(df_plot_melted, aes(value, count, color = variable)) + geom_point()


Expand All @@ -262,15 +262,15 @@ CVsl <- CV_lrnr_sl(
)
CVsl %>%
kable(digits = 4) %>%
kableExtra::kable_styling(fixed_thead = TRUE) %>%
kable_styling(fixed_thead = TRUE) %>%
scroll_box(width = "100%", height = "300px")


## ----varimp-------------------------------------------------------------------
washb_varimp <- importance(sl_fit, loss = loss_squared_error, type = "permute")
washb_varimp %>%
kable(digits = 4) %>%
kableExtra::kable_styling(fixed_thead = TRUE) %>%
kable_styling(fixed_thead = TRUE) %>%
scroll_box(width = "100%", height = "300px")


Expand All @@ -295,7 +295,7 @@ chspred <- read_csv(file = db_data, col_names = TRUE)
# take a quick peek
head(chspred) %>%
kable(digits = 4) %>%
kableExtra::kable_styling(fixed_thead = TRUE) %>%
kable_styling(fixed_thead = TRUE) %>%
scroll_box(width = "100%", height = "300px")


Expand Down Expand Up @@ -324,14 +324,14 @@ ist_task_CVsl <- make_sl3_Task(
## "https://raw.githubusercontent.com/benkeser/sllecture/master/chspred.csv"
## )
## chspred <- read_csv(file = db_data, col_names = TRUE)
##
##
## # make task
## chspred_task <- make_sl3_Task(
## data = chspred,
## covariates = head(colnames(chspred), -1),
## outcome = "mi"
## )
##
##
## # make learners
## glm_learner <- Lrnr_glm$new()
## lasso_learner <- Lrnr_glmnet$new(alpha = 1)
Expand All @@ -344,11 +344,11 @@ ist_task_CVsl <- make_sl3_Task(
## ranger_learner <- Lrnr_ranger$new()
## svm_learner <- Lrnr_svm$new()
## xgb_learner <- Lrnr_xgboost$new()
##
##
## # screening
## screen_cor <- make_learner(Lrnr_screener_correlation)
## glm_pipeline <- make_learner(Pipeline, screen_cor, glm_learner)
##
##
## # stack learners together
## stack <- make_learner(
## Stack,
Expand All @@ -357,17 +357,17 @@ ist_task_CVsl <- make_sl3_Task(
## curated_glm_learner, mean_learner, glm_fast_learner,
## ranger_learner, svm_learner, xgb_learner
## )
##
##
## # make and train SL
## sl <- Lrnr_sl$new(
## learners = stack
## )
## sl_fit <- sl$train(chspred_task)
## sl_fit$print()
##
##
## CVsl <- CV_lrnr_sl(sl_fit, chspred_task, loss_loglik_binomial)
## CVsl
##
##
## varimp <- importance(sl_fit, type = "permute")
## varimp %>%
## importance_plot(
Expand All @@ -377,20 +377,20 @@ ist_task_CVsl <- make_sl3_Task(

## ----ex2-key, eval=FALSE------------------------------------------------------
## library(ROCR) # for AUC calculation
##
##
## ist_data <- paste0(
## "https://raw.githubusercontent.com/tlverse/",
## "tlverse-handbook/master/data/ist_sample.csv"
## ) %>% fread()
##
##
## # stack
## ist_task <- make_sl3_Task(
## data = ist_data,
## outcome = "DRSISC",
## covariates = colnames(ist_data)[-which(names(ist_data) == "DRSISC")],
## drop_missing_outcome = TRUE
## )
##
##
## # learner library
## lrn_glm <- Lrnr_glm$new()
## lrn_lasso <- Lrnr_glmnet$new(alpha = 1)
Expand All @@ -415,17 +415,17 @@ ist_task_CVsl <- make_sl3_Task(
## ),
## recursive = TRUE
## )
##
##
## # SL
## sl <- Lrnr_sl$new(learners)
## sl_fit <- sl$train(ist_task)
##
##
## # AUC
## preds <- sl_fit$predict()
## obs <- c(na.omit(ist_data$DRSISC))
## AUC <- performance(prediction(sl_preds, obs), measure = "auc")@y.values[[1]]
## plot(performance(prediction(sl_preds, obs), "tpr", "fpr"))
##
##
## # CVsl
## ist_task_CVsl <- make_sl3_Task(
## data = ist_data,
Expand All @@ -440,7 +440,7 @@ ist_task_CVsl <- make_sl3_Task(
## )
## CVsl <- CV_lrnr_sl(sl_fit, ist_task_CVsl, loss_loglik_binomial)
## CVsl
##
##
## # sl3 variable importance plot
## ist_varimp <- importance(sl_fit, type = "permute")
## ist_varimp %>%
Expand All @@ -451,3 +451,4 @@ ist_task_CVsl <- make_sl3_Task(

## ----ex3-key, eval=FALSE------------------------------------------------------
## # TODO

1 change: 1 addition & 0 deletions R_code/07-tmle3.R
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,4 @@ ist_data <- fread(
"master/data/ist_sample.csv"
)
)

21 changes: 11 additions & 10 deletions R_code/08-tmle3mopttx.R
Original file line number Diff line number Diff line change
Expand Up @@ -297,14 +297,15 @@ b_learner <- create_mv_learners(learners = learners)
learner_list <- list(Y = Q_learner, A = g_learner, B = b_learner)


## ----spec_init_WASH-----------------------------------------------------------
# initialize a tmle specification
tmle_spec <- tmle3_mopttx_blip_revere(
V = c("momedu", "floor", "asset_refrig"), type = "blip2",
learners = learner_list, maximize = TRUE, complex = TRUE,
realistic = FALSE
)
## ----spec_init_WASH, eval=FALSE-----------------------------------------------
## # initialize a tmle specification
## tmle_spec <- tmle3_mopttx_blip_revere(
## V = c("momedu", "floor", "asset_refrig"), type = "blip2",
## learners = learner_list, maximize = TRUE, complex = TRUE,
## realistic = FALSE
## )
##
## # fit the TML estimator
## fit <- tmle3(tmle_spec, data = washb_data, node_list, learner_list)
## fit

# fit the TML estimator
fit <- tmle3(tmle_spec, data = washb_data, node_list, learner_list)
fit
1 change: 1 addition & 0 deletions R_code/09-tmle3shift.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,4 @@ learner_list <- list(Y = sl_reg_lrnr, A = cv_hose_hal_lrnr)
## ----fit_tmle_wrapper_washb_shift, message=FALSE, warning=FALSE, eval=FALSE----
## washb_tmle_fit <- tmle3(washb_vim_spec, washb_data, node_list, learner_list)
## washb_tmle_fit

1 change: 1 addition & 0 deletions R_code/10-tmle3mediate.R
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

1 change: 1 addition & 0 deletions R_code/11-tmle3survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,4 @@ tmle_fit_manual <- fit_tmle3(

## -----------------------------------------------------------------------------
print(tmle_fit_manual)

1 change: 1 addition & 0 deletions R_code/99-primer_R6.R
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

0 comments on commit 070a6db

Please sign in to comment.