Skip to content

Commit

Permalink
added multiple-predictions vignette article
Browse files Browse the repository at this point in the history
  • Loading branch information
seanrmeyer committed Apr 16, 2022
1 parent 1424c87 commit c5f68bc
Show file tree
Hide file tree
Showing 3 changed files with 321 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
inst/doc
/doc/
/Meta/
.DS_Store
289 changes: 289 additions & 0 deletions vignettes/articles/multiple-predictions.Rmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
---
title: "single-prediction"
---

```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>"
)
```

# install packages

```{r eval=FALSE}
install.packages("snakecase")
install.packages("h2o")
install.packages("remotes")
remotes::install_github('ML4LHS/gpmodels')
```

# setup

```{r setup}
library(tidyr)
library(dplyr)
library(readr)
library(magrittr)
library(snakecase)
library(gpmodels)
library(lubridate)
library(ggplot2)
```

https://data.world/informatics-edu/synthea-synthetic-patient-informationus

https://synthea.mitre.org/downloads

# Import data

```{r}
temp <- tempfile()
download.file("https://synthetichealth.github.io/synthea-sample-data/downloads/synthea_sample_data_csv_apr2020.zip", temp)
unzip(temp)
```

# Format data

```{r}
patients = read_csv("./csv/patients.csv") %>%
janitor::clean_names() %>%
# mutate(description = to_snake_case(description)) %>%
select(patient = id, birthdate, deathdate, marital, race, ethnicity, gender)
patients %>% str()
conditions = read_csv("./csv/conditions.csv") %>%
janitor::clean_names() %>%
mutate(description = to_snake_case(description), category = "condition", value = 1) %>%
select(patient, encounter, time = start, category, description = description)
conditions %>% str()
observations = read_csv("./csv/observations.csv") %>%
janitor::clean_names() %>%
unite(col = "description", c("description", "units")) %>%
mutate(description = to_snake_case(description), category = "observation") %>%
select(patient, encounter, time = date, category, description, value)
observations %>% str()
encounters = read_csv("./csv/encounters.csv") %>%
janitor::clean_names() %>%
filter(encounterclass == "inpatient") %>%
mutate(description = to_snake_case(description)) %>%
select(patient, encounter = id, start, stop, encounterclass, encounter_descr = description)
```

# Define static outcome

> A static outcome must be defined and joined with the fixed data
```{r}
outcome_pne = conditions %>%
filter(grepl("pneumonia", description)) %>%
group_by(encounter) %>%
slice_min(time) %>%
ungroup() %>%
select(patient, encounter, pneumonia_time = time)
outcome_pne %>% str()
```

# Build cohort and demographic data (fixed data)

```{r}
fixed_data = patients %>%
inner_join(encounters %>% filter(encounterclass == "inpatient") %>% unique() %>% arrange(desc(start))) %>%
left_join(outcome_pne %>% select(patient, pneumonia_time)) %>%
mutate(pneumonia = if_else(!is.na(pneumonia_time), "PNA", "No_PNA") %>% factor(levels = c("PNA", "No_PNA"))) %>%
mutate(pneumonia = if_else(!is.na(pneumonia_time), 1, 0, 0) %>% as.factor()) %>%
mutate(pred_start = start, pred_stop = coalesce(deathdate, pneumonia_time, stop)) %>%
select(-c(start, stop)) %>%
filter(!is.na(pred_start), # prediction start must be present
pred_start <= pred_stop) %>% # prediction start should never be greater than prediction stop
distinct(across(c(encounter)), .keep_all = TRUE)
```

# Temporal data

```{r}
temporal_data = bind_rows(conditions, observations) %>%
select(-encounter) %>%
inner_join(fixed_data %>% select(patient, encounter)) %>%
mutate(time = as_datetime(time))
```

# Define data model

> start predictions: admission
> end predictions: discharge
> prediction interval: 6 hours
```{r}
tf = gpmodels::time_frame(fixed_data = fixed_data,
temporal_data = temporal_data, # timestamped data
fixed_id = 'encounter',
fixed_start = 'pred_start',
fixed_end = 'pred_stop',
temporal_id = 'encounter',
temporal_time = 'time',
temporal_variable = 'description',
temporal_value = ,'value',
temporal_category = 'category',
step = hours(6), # how often to make predictions
max_length = days(7), # maximum time limit to make predictions (e.g. 7 days inpatient)
output_folder = './gpmodels_out/',
create_folder = TRUE,
chunk_size = 2000)
```

```{r}
tf$fixed_data %>%
janitor::tabyl(pneumonia)
```

```{r}
tf = tf %>%
pre_dummy_code()
```

# Define custom functions

```{r}
presence = function (x) { if_else(length(x) > 0, 1, 0, 0) }
```

# Define predictors

```{r}
future::plan('multisession')
tf %>%
add_baseline_predictors(category = "condition",
lookback = months(12),
window = months(3),
stats = c(presence = presence))
tf %>%
add_baseline_predictors(category = "observation",
lookback = months(12),
# window = months(3),
stats = c(min = min,
max = max,
median = median))
tf %>%
add_rolling_predictors(category = "observation",
lookback = hours(12),
window = hours(6),
stats = c(min = min,
max = max,
median = median,
last = dplyr::last))
```


# Define rolling outcome

> pneumonia in the next 24 hours
```{r}
tf %>%
add_rolling_outcomes(variables = "pneumonia",
lookahead = hours(24),
stats = c(presence = presence))
```

# Randomly split patients into groups for cross-validation

```{r}
set.seed(1234)
folds = tf$fixed_data %>%
select(patient) %>%
distinct() %>%
group_by(patient) %>%
mutate(fold = sample.int(6, patient %>% length(), replace = TRUE))
folds %>% write_csv(file.path("folds.csv"))
folds %>% janitor::tabyl(fold)
```

# Combine prepped data

```{r}
gpm_prepped = tf %>%
combine_output() %>%
write_csv("prepped_data.csv")
```

# Build model

```{r}
response = "pneumonia"
predictors = setdiff(names(gpm_prepped), c(response,"outcome_pneumonia_presence_24", gpm_prepped %>% select(patient:pred_stop) %>% names()))
predictors %>% length()
# predictors
library(h2o)
h2o.init(nthreads=-1) # multiprocessing support
h2o.removeAll() # Clean slate - just in case the cluster was already running
```

# Prep modeling datasets

```{r}
training = gpm_prepped %>% inner_join(folds %>% filter(fold %in% 1:4)) %>%
mutate(pneumonia = pneumonia %>% as.factor())
validation = gpm_prepped %>% inner_join(folds %>% filter(fold == 5)) %>%
mutate(pneumonia = pneumonia %>% as.factor()) %>% select(-fold)
test = gpm_prepped %>% inner_join(folds %>% filter(fold == 6)) %>%
mutate(pneumonia = pneumonia %>% as.factor()) %>% select(-fold)
```

# Build and evaluate model

```{r}
gbm_pna_static = h2o.gbm(x = predictors,
y = response,
training_frame = training %>% as.h2o(),
fold_column = "fold",
validation_frame = validation %>% as.h2o(),
ntrees = 200,
learn_rate = 0.05,
learn_rate_annealing = 0.99,
score_tree_interval = 10,
stopping_rounds = 5,
stopping_tolerance = 1e-5,
stopping_metric = "AUC",
seed = 1234,
keep_cross_validation_predictions = TRUE)
gbm_pna_static %>%
h2o.performance(newdata = test %>% as.h2o()) %>%
h2o.auc() %>%
round(3)
```

```{r}
gbm_pna_static_importance = h2o.varimp(gbm_pna_static) %>%
arrange(desc(relative_importance))
gbm_pna_static_importance %>%
as.data.frame() %>%
head(20) %>%
mutate(relative_importance = relative_importance %>% round(2)) %>%
select(variable, relative_importance) %>%
arrange(desc(relative_importance)) %>%
ggplot(aes(x = reorder(variable, relative_importance), y = relative_importance)) +
geom_bar(stat = "identity") +
coord_flip() +
xlab("")
```

# References

Jason Walonoski, Mark Kramer, Joseph Nichols, Andre Quina, Chris Moesel, Dylan Hall, Carlton Duffett, Kudakwashe Dube, Thomas Gallagher, Scott McLachlan, Synthea: An approach, method, and software mechanism for generating synthetic patients and the synthetic electronic health care record, Journal of the American Medical Informatics Association, Volume 25, Issue 3, March 2018, Pages 230–238, https://doi.org/10.1093/jamia/ocx079
44 changes: 31 additions & 13 deletions vignettes/articles/single-prediction.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ temporal_data = bind_rows(conditions, observations) %>%

> start predictions: admission
> end predictions: discharge
> prediction interval: 6 hours
> prediction interval: 1 month (arbitrary for single predictions)
```{r}
tf = gpmodels::time_frame(fixed_data = fixed_data,
Expand All @@ -131,8 +131,8 @@ tf = gpmodels::time_frame(fixed_data = fixed_data,
temporal_variable = 'description',
temporal_value = ,'value',
temporal_category = 'category',
step = hours(6), # how often to make predictions
max_length = days(7), # maximum time limit to make predictions (e.g. 7 days inpatient)
step = months(1), # how often to make predictions
# max_length = days(7), # maximum time limit to make predictions (e.g. 7 days inpatient)
output_folder = './gpmodels_out/',
create_folder = TRUE,
chunk_size = 2000)
Expand All @@ -159,9 +159,20 @@ presence = function (x) { if_else(length(x) > 0, 1, 0, 0) }
```{r}
future::plan('multisession')
tf %>% add_baseline_predictors(category = "condition", lookback = months(12), window = months(3), stats = c(presence = presence))
tf %>% add_baseline_predictors(category = "observation", lookback = months(12), window = months(3), stats = c(min = min, max = max, median = median, last = dplyr::last))
tf %>%
add_baseline_predictors(category = "condition",
lookback = months(12),
window = months(3),
stats = c(presence = presence))
tf %>%
add_baseline_predictors(category = "observation",
lookback = months(12),
window = months(3),
stats = c(min = min,
max = max,
median = median,
last = dplyr::last))
```


Expand All @@ -170,7 +181,10 @@ tf %>% add_baseline_predictors(category = "observation", lookback = months(12),
> pneumonia in the next 24 hours
```{r}
tf %>% add_rolling_outcomes(variables = "pneumonia", lookahead = hours(24), stats = c(presence = presence))
tf %>%
add_rolling_outcomes(variables = "pneumonia",
lookahead = hours(24),
stats = c(presence = presence))
```

# Randomly split patients into groups for cross-validation
Expand All @@ -196,6 +210,11 @@ gpm_prepped = tf %>%
write_csv("prepped_data.csv")
```

```{r}
gpm_prepped %>% head()
```


# Build model

```{r}
Expand Down Expand Up @@ -248,16 +267,15 @@ gbm_pna_static %>%

```{r}
gbm_pna_static_importance = h2o.varimp(gbm_pna_static) %>%
arrange(desc(relative_importance)) %>%
rename(pc_relative_importance = relative_importance, pc_scaled_importance = scaled_importance, pc_percentage = percentage)
arrange(desc(relative_importance))
gbm_pna_static_importance %>%
as.data.frame() %>%
head(20) %>%
mutate(pc_relative_importance = pc_relative_importance %>% round(2)) %>%
select(variable, pc_relative_importance) %>%
arrange(desc(pc_relative_importance)) %>%
ggplot(aes(x = reorder(variable, pc_relative_importance), y = pc_relative_importance)) +
mutate(relative_importance = relative_importance %>% round(2)) %>%
select(variable, relative_importance) %>%
arrange(desc(relative_importance)) %>%
ggplot(aes(x = reorder(variable, relative_importance), y = relative_importance)) +
geom_bar(stat = "identity") +
coord_flip() +
xlab("")
Expand Down

0 comments on commit c5f68bc

Please sign in to comment.