Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cmdstanr option #537

Merged
merged 31 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5d708bd
add cmdstan backend
sbfnk Jan 30, 2024
76c267e
add cmdstanr model
sbfnk Jan 30, 2024
277921c
generalise extract functions
sbfnk Jan 31, 2024
f6eda7a
create general fit_model function
sbfnk Jan 31, 2024
547ea8a
initial values for cmdstanr
sbfnk Jan 31, 2024
4d5f088
move stan model creation to create func
sbfnk Jan 31, 2024
db81caf
make dist_fit cmdstanr ready
sbfnk Jan 31, 2024
32e2cc8
make estimate_infections cmdstanr ready
sbfnk Jan 31, 2024
19fadcc
make estimate_secondary cmdstanr ready
sbfnk Jan 31, 2024
4eb4726
make estimate_truncation cmdstanr ready
sbfnk Jan 31, 2024
311d433
make simulate_infections cmdstanr ready
sbfnk Jan 31, 2024
6569e68
gitignore for binaries
sbfnk Jan 31, 2024
4ae5948
add cmdstanr as suggest
sbfnk Jan 31, 2024
028b499
make forecast_secondary cmdstanr ready
sbfnk Feb 1, 2024
f5b03bd
update stanargs test
sbfnk Feb 1, 2024
1cca4ca
make simulations work with updated options
sbfnk Feb 1, 2024
c8b5e09
add globals
sbfnk Feb 1, 2024
d7064ab
tests for cmdstanr backend
sbfnk Feb 1, 2024
ccb004f
update actions
sbfnk Feb 2, 2024
a808431
updates in response to lintr
sbfnk Feb 2, 2024
02dc4bf
don't use future_lapply for cmdstanr
sbfnk Feb 4, 2024
7c7112b
backend-specific success criteria
sbfnk Feb 5, 2024
df76685
use epinowcast action for installing cmdstan
sbfnk Feb 6, 2024
0032367
improve .gitignore for compiled stan files
sbfnk Feb 6, 2024
91a6b8e
deactivate testing on windows for now
sbfnk Feb 12, 2024
cac3811
Revert "use epinowcast action for installing cmdstan"
sbfnk Feb 13, 2024
6593f5a
Apply suggestions from code review
sbfnk Feb 14, 2024
6b60299
match arguments in `stan_model`
sbfnk Feb 14, 2024
efba507
don't match args$method but explictly stop instead
sbfnk Feb 14, 2024
b594c3c
put choices in argument
sbfnk Feb 14, 2024
b1c54a4
render documentation for stan_model
sbfnk Feb 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions .github/workflows/R-CMD-as-cran-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,18 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
needs: check
dependencies: NA
extra-packages: |
rcmdcheck
stan-dev/cmdstanr
testthat

- name: Install cmdstan
sbfnk marked this conversation as resolved.
Show resolved Hide resolved
run: |
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
cmdstanr::install_cmdstan(cores = 2, quiet = TRUE)
shell: Rscript {0}

- uses: r-lib/actions/check-r-package@v2
with:
upload-snapshots: true
upload-snapshots: true
23 changes: 14 additions & 9 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,6 @@ jobs:
steps:
- uses: actions/checkout@v4

- name: Install cmdstan Linux system dependencies
if: runner.os == 'Linux'
run: |
sudo apt-get update
sudo apt-get install -y libcurl4-openssl-dev || true
sudo apt-get install -y openmpi-bin openmpi-common libopenmpi-dev || true
sudo apt-get install -y libpng-dev || true
- uses: r-lib/actions/setup-pandoc@v2

- uses: r-lib/actions/setup-r@v2
Expand All @@ -65,8 +58,20 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
needs: check
dependencies: NA
extra-packages: |
dplyr
rmarkdown
rcmdcheck
stan-dev/cmdstanr
testthat

- name: Install cmdstan
if: runner.os != 'Windows'
run: |
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
cmdstanr::install_cmdstan(cores = 2, quiet = TRUE)
shell: Rscript {0}

- uses: r-lib/actions/check-r-package@v2
with:
Expand Down
5 changes: 3 additions & 2 deletions .github/workflows/lint-only-changed-files.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
dependencies: NA
extra-packages: |
stan-dev/cmdstanr
any::gh
any::lintr
any::purrr
needs: check

- name: Add lintr options
run: |
Expand All @@ -44,4 +45,4 @@ jobs:
lintr::lint_package(exclusions = exclusions_list)
shell: Rscript {0}
env:
LINTR_ERROR_ON_LINT: true
LINTR_ERROR_ON_LINT: true
4 changes: 3 additions & 1 deletion .github/workflows/synthetic-validation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
dependencies: NA
extra-packages: |
here
dplyr
tidyr
scoringutils
Expand All @@ -48,4 +50,4 @@ jobs:
with:
name: fits
retention-days: 5
path: synthetic.rds
path: synthetic.rds
15 changes: 12 additions & 3 deletions .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,18 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::covr
needs: coverage
dependencies: NA
extra-packages: |
covr
stan-dev/cmdstanr
testthat

- name: Install cmdstan
run: |
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
cmdstanr::install_cmdstan(cores = 2, quiet = TRUE)
shell: Rscript {0}

- name: Test coverage
run: covr::codecov(quiet = FALSE)
shell: Rscript {0}
shell: Rscript {0}
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,8 @@ vignettes/results

# unused figures
man/figures/*.png

# exclude compiled stan files
inst/stan/*
!inst/stan/*/
!inst/stan/*.stan
6 changes: 5 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ Imports:
lubridate,
methods,
patchwork,
posterior,
progressr,
purrr,
R.utils (>= 2.0.0),
Expand All @@ -122,6 +123,7 @@ Imports:
truncnorm,
utils
Suggests:
cmdstanr,
covr,
dplyr,
here,
Expand All @@ -143,13 +145,15 @@ LinkingTo:
RcppParallel (>= 5.0.1),
rstan (>= 2.26.0),
StanHeaders (>= 2.26.0)
Additional_repositories:
https://mc-stan.org/r-packages/
Biarch: true
Config/testthat/edition: 3
Encoding: UTF-8
Language: en-GB
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
NeedsCompilation: yes
SystemRequirements: GNU make
C++17
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ export(estimates_by_report_date)
export(expose_stan_fns)
export(extract_CrIs)
export(extract_inits)
export(extract_samples)
export(extract_stan_param)
export(fix_dist)
export(forecast_infections)
Expand All @@ -66,6 +67,7 @@ export(make_conf)
export(map_prob_change)
export(obs_opts)
export(opts_list)
export(package_model)
export(plot_estimates)
export(plot_summary)
export(regional_epinow)
Expand All @@ -91,6 +93,8 @@ export(setup_target_folder)
export(simulate_infections)
export(simulate_secondary)
export(stan_opts)
export(stan_sampling_opts)
export(stan_vb_opts)
export(summarise_key_measures)
export(summarise_results)
export(trunc_opts)
Expand Down Expand Up @@ -133,6 +137,7 @@ importFrom(data.table,rbindlist)
importFrom(data.table,setDT)
importFrom(data.table,setDTthreads)
importFrom(data.table,setcolorder)
importFrom(data.table,setkey)
importFrom(data.table,setnames)
importFrom(data.table,setorder)
importFrom(data.table,setorderv)
Expand Down Expand Up @@ -184,6 +189,7 @@ importFrom(lifecycle,deprecate_warn)
importFrom(lubridate,days)
importFrom(lubridate,wday)
importFrom(patchwork,plot_layout)
importFrom(posterior,mcse_mean)
importFrom(progressr,progressor)
importFrom(progressr,with_progress)
importFrom(purrr,compact)
Expand Down
31 changes: 29 additions & 2 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -630,8 +630,14 @@ create_initial_conditions <- function(data) {
#'
#' @param data A list of stan data as created by [create_stan_data()]
#'
#' @param init Initial conditions passed to `{rstan}`. Defaults to "random" but
#' can also be a function (as supplied by [create_initial_conditions()]).
#' @param init Initial conditions passed to `{rstan}`. Defaults to "random"
#' (initial values randomly drawn between -2 and 2) but can also be a
#' function (as supplied by [create_initial_conditions()]).
#'
#' @param model Character, name of the model for which arguments are
#' to be created.
#' @param fixed_param Logical, defaults to `FALSE`. Should arguments be
#' created to sample from fixed parameters (used by simulation functions).
#'
#' @param verbose Logical, defaults to `FALSE`. Should verbose progress
#' messages be returned.
Expand All @@ -650,7 +656,28 @@ create_initial_conditions <- function(data) {
create_stan_args <- function(stan = stan_opts(),
data = NULL,
init = "random",
model = "estimate_infections",
fixed_param = FALSE,
verbose = FALSE) {
if (fixed_param) {
if (stan$backend == "rstan") {
stan$algorithm <- "Fixed_param"
} else if (stan$backend == "cmdstanr") {
stan$fixed_param <- TRUE
stan$adapt_delta <- NULL
stan$max_treedepth <- NULL
}
}
## generate stan model
if (is.null(stan$object)) {
stan$object <- stan_model(stan$backend, model)
stan$backend <- NULL
}
# cmdstanr doesn't have an init = "random" argument
if (is.character(init) && init == "random" &&
inherits(stan$object, "CmdStanModel")) {
init <- 2
}
# set up shared default arguments
args <- list(
data = data,
Expand Down
37 changes: 22 additions & 15 deletions R/dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@
#' @return A stan fit of an interval censored distribution
#' @author Sam Abbott
#' @export
#' @inheritParams stan_opts
#' @examples
#' \donttest{
#' # integer adjusted exponential model
Expand All @@ -221,7 +222,8 @@
#' )
#' }
dist_fit <- function(values = NULL, samples = 1000, cores = 1,
chains = 2, dist = "exp", verbose = FALSE) {
chains = 2, dist = "exp", verbose = FALSE,
backend = "rstan") {
if (samples < 1000) {
samples <- 1000
warning(sprintf("%s %s", "`samples` must be at least 1000.",
Expand All @@ -244,7 +246,7 @@
par_sigma = numeric(0)
)

model <- stanmodels$dist_fit
model <- stan_model(backend, "dist_fit")

Check warning on line 249 in R/dist.R

View check run for this annotation

Codecov / codecov/patch

R/dist.R#L249

Added line #L249 was not covered by tests

if (dist == "exp") {
data$dist <- 0
Expand All @@ -268,16 +270,21 @@
}

# fit model
fit <- rstan::sampling(
model,
data = data,
iter = samples + 1000,
warmup = 1000,
control = list(adapt_delta = adapt_delta),
chains = chains,
cores = cores,
refresh = ifelse(verbose, 50, 0)
args <- create_stan_args(
stan = stan_opts(
model,
samples = samples,
warmup = 1000,
control = list(adapt_delta = adapt_delta),
chains = chains,
cores = cores,
backend = backend

Check warning on line 281 in R/dist.R

View check run for this annotation

Codecov / codecov/patch

R/dist.R#L273-L281

Added lines #L273 - L281 were not covered by tests
),
data = data, verbose = verbose, model = "dist_fit"

Check warning on line 283 in R/dist.R

View check run for this annotation

Codecov / codecov/patch

R/dist.R#L283

Added line #L283 was not covered by tests
)

fit <- fit_model(args, id = "dist_fit")

Check warning on line 286 in R/dist.R

View check run for this annotation

Codecov / codecov/patch

R/dist.R#L286

Added line #L286 was not covered by tests

return(fit)
}

Expand Down Expand Up @@ -533,11 +540,11 @@

out <- list()
if (dist == "lognormal") {
out$mean_samples <- sample(rstan::extract(fit)$mu, samples)
out$sd_samples <- sample(rstan::extract(fit)$sigma, samples)
out$mean_samples <- sample(extract(fit)$mu, samples)
out$sd_samples <- sample(extract(fit)$sigma, samples)

Check warning on line 544 in R/dist.R

View check run for this annotation

Codecov / codecov/patch

R/dist.R#L543-L544

Added lines #L543 - L544 were not covered by tests
} else if (dist == "gamma") {
alpha_samples <- sample(rstan::extract(fit)$alpha, samples)
beta_samples <- sample(rstan::extract(fit)$beta, samples)
alpha_samples <- sample(extract(fit)$alpha, samples)
beta_samples <- sample(extract(fit)$beta, samples)

Check warning on line 547 in R/dist.R

View check run for this annotation

Codecov / codecov/patch

R/dist.R#L546-L547

Added lines #L546 - L547 were not covered by tests
out$mean_samples <- alpha_samples / beta_samples
out$sd_samples <- sqrt(alpha_samples) / beta_samples
}
Expand Down
Loading
Loading