Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 757: Vectorise GP stan code #742

Merged
merged 107 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
107 commits
Select commit Hold shift + click to select a range
9c3b8dd
refactor gps along approach used by aki
seabbs Aug 9, 2024
725cbb3
add some skeleton unit tests
seabbs Aug 9, 2024
20f76bf
update model
seabbs Aug 9, 2024
478e834
update data chunk
seabbs Aug 9, 2024
32572a4
add r interface code
seabbs Aug 9, 2024
8ed3295
add some R side tests where missing for gp related code
seabbs Aug 9, 2024
f20f409
get tests passing
seabbs Aug 12, 2024
89de657
fix tests for create_gp_data
seabbs Aug 12, 2024
0fbbcba
fix tests
seabbs Aug 12, 2024
17f9992
update docs
seabbs Aug 12, 2024
5b3c0b4
fix tests warnings to error for uncommon mattern orders
seabbs Aug 12, 2024
602f28c
remove spurious plots
seabbs Aug 12, 2024
1e5eb4a
update EpiNow2 vignette
seabbs Aug 12, 2024
0bf3605
review kkernals
seabbs Aug 12, 2024
4f8c5bc
update and test
seabbs Aug 12, 2024
54161f0
correct scaling of L
seabbs Aug 12, 2024
092a794
rescale lengthscale
seabbs Aug 13, 2024
a4b04a8
change adapt-delta default to 0.9
seabbs Aug 13, 2024
d8a1894
expand inits for GP as causing issues as to close to 1/0
seabbs Aug 13, 2024
ebcfb89
get rid of normalisation and use unormalised lpdf where possible (equiv)
seabbs Aug 13, 2024
4702388
widen optimisation sweep to include delay weight default
seabbs Aug 13, 2024
706b9de
non-center random walk
seabbs Aug 13, 2024
92624fa
tune prior specification
seabbs Aug 13, 2024
657c27d
tune dispersion prior
seabbs Aug 13, 2024
0701b7b
tune phi
seabbs Aug 13, 2024
561bdb3
update vignette
seabbs Aug 13, 2024
a1344d4
get rid of rw change
seabbs Aug 13, 2024
ca14c13
revert Rt
seabbs Aug 13, 2024
16354c0
catch update_rt
seabbs Aug 13, 2024
0da8c90
add news
seabbs Aug 13, 2024
c28af7a
revert vignette changes
seabbs Aug 13, 2024
eb663b3
fix gp_opts tests
seabbs Aug 13, 2024
6247bfa
fix create tests
seabbs Aug 13, 2024
e2f4b13
update gp tests
seabbs Aug 14, 2024
137ecf2
skip tests as required on windows
seabbs Aug 14, 2024
299268a
fix linting
seabbs Aug 14, 2024
1b70596
constrain delay uncertainty
seabbs Aug 14, 2024
128c592
correct gp stan tests
seabbs Aug 14, 2024
6fd326f
fix GP test
seabbs Aug 14, 2024
b9153b6
put the deprecition warning behind a gate
seabbs Aug 14, 2024
e338331
refactor gps along approach used by aki
seabbs Aug 9, 2024
9297afd
add some skeleton unit tests
seabbs Aug 9, 2024
8dec848
update model
seabbs Aug 9, 2024
b503e82
update data chunk
seabbs Aug 9, 2024
e96a4fc
add r interface code
seabbs Aug 9, 2024
3cfc9aa
add some R side tests where missing for gp related code
seabbs Aug 9, 2024
783a766
get tests passing
seabbs Aug 12, 2024
e296a89
fix tests for create_gp_data
seabbs Aug 12, 2024
fc15818
fix tests
seabbs Aug 12, 2024
5edd8ef
update docs
seabbs Aug 12, 2024
282de44
fix tests warnings to error for uncommon mattern orders
seabbs Aug 12, 2024
d333a81
update EpiNow2 vignette
seabbs Aug 12, 2024
9e990ce
review kkernals
seabbs Aug 12, 2024
afc2520
update and test
seabbs Aug 12, 2024
a01c7c4
correct scaling of L
seabbs Aug 12, 2024
19da24c
rescale lengthscale
seabbs Aug 13, 2024
0fb1dc6
change adapt-delta default to 0.9
seabbs Aug 13, 2024
c0d0dbb
expand inits for GP as causing issues as to close to 1/0
seabbs Aug 13, 2024
3ff3cb9
get rid of normalisation and use unormalised lpdf where possible (equiv)
seabbs Aug 13, 2024
9dba492
widen optimisation sweep to include delay weight default
seabbs Aug 13, 2024
7c16cd6
non-center random walk
seabbs Aug 13, 2024
5d3801b
tune prior specification
seabbs Aug 13, 2024
fee0289
tune dispersion prior
seabbs Aug 13, 2024
fea377c
tune phi
seabbs Aug 13, 2024
89cafd6
update vignette
seabbs Aug 13, 2024
190c800
get rid of rw change
seabbs Aug 13, 2024
9bd1989
revert Rt
seabbs Aug 13, 2024
7090af2
catch update_rt
seabbs Aug 13, 2024
da5c1ea
add news
seabbs Aug 13, 2024
38ae947
revert vignette changes
seabbs Aug 13, 2024
0348039
fix gp_opts tests
seabbs Aug 13, 2024
883d6cb
fix create tests
seabbs Aug 13, 2024
f476353
update gp tests
seabbs Aug 14, 2024
1f39fa3
skip tests as required on windows
seabbs Aug 14, 2024
30eb30f
fix linting
seabbs Aug 14, 2024
c6328e4
constrain delay uncertainty
seabbs Aug 14, 2024
de9348e
correct gp stan tests
seabbs Aug 14, 2024
c47fd0f
fix GP test
seabbs Aug 14, 2024
92db47c
put the deprecition warning behind a gate
seabbs Aug 14, 2024
d767118
fix linting
seabbs Aug 14, 2024
239a710
Update NEWS.md
seabbs Aug 14, 2024
40aae17
merge
seabbs Aug 15, 2024
f781c3c
add linear kernel support
seabbs Aug 15, 2024
fa7d2f4
add docs and newa
seabbs Aug 15, 2024
ecf8965
integration tests and minor issues
seabbs Aug 15, 2024
e0c3849
fixes for periodic kernel dimension differences
seabbs Aug 15, 2024
242ccb6
drop linear kernel support
seabbs Aug 19, 2024
4cb3de7
lint space
seabbs Aug 19, 2024
adf9412
catch outstanding linear tests
seabbs Aug 19, 2024
a46bc6e
catch stan tests
seabbs Aug 19, 2024
20292a7
make the eecdf in convolve test less random
seabbs Aug 20, 2024
abe77f6
Update R/create.R
seabbs Aug 27, 2024
c8a8ca4
Merge branch 'main' into vectorise-spectral-density
seabbs Aug 27, 2024
1b73919
Update NEWS.md
seabbs Aug 28, 2024
a42f31d
Update create.R - remove out of date gp type 3 check
seabbs Aug 28, 2024
5c39a86
Update opts.R - remove linear kernel references
seabbs Aug 28, 2024
83f372f
Update R/opts.R
seabbs Aug 28, 2024
93bb8c5
Update opts.R - fix review suggestions
seabbs Aug 28, 2024
1074124
Update opts.R - remove linear reference
seabbs Aug 28, 2024
37be7ff
Update estimate_infections.stan
seabbs Aug 28, 2024
94df110
Update tests/testthat/test-create_gp_data.R
seabbs Aug 28, 2024
09a8abb
Update NEWS.md
seabbs Aug 28, 2024
078d550
Update opts.R
seabbs Aug 28, 2024
71eb3be
Document
actions-user Aug 28, 2024
13f7f8a
Merge branch 'main' into vectorise-spectral-density
seabbs Aug 28, 2024
7ff2590
Merge branch 'main' into vectorise-spectral-density
seabbs Aug 28, 2024
7532c83
Merge branch 'main' into vectorise-spectral-density
seabbs Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ create_backcalc_data <- function(backcalc = backcalc_opts()) {
)
return(data)
}

#' Create Gaussian Process Data
#'
#' @description `r lifecycle::badge("stable")`
Expand Down Expand Up @@ -394,11 +395,13 @@ create_gp_data <- function(gp = gp_opts(), data) {
ls_max = data$t - data$seeding_time - data$horizon,
alpha_sd = gp$alpha_sd,
gp_type = data.table::fcase(
is.infinite(gp$matern_order), 0,
gp$matern_order == 1 / 2, 1,
gp$matern_order == 3 / 2, 2,
default = 3
)
gp$kernel == "se", 0,
gp$kernel == "periodic", 1,
gp$kernel == "matern" || gp$kernel == "ou", 2,
default = 2
),
nu = gp$matern_order,
w0 = gp$w0
)

gp_data <- c(data, gp_data)
Expand Down
50 changes: 24 additions & 26 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the structure of the approximate Gaussian
#' process. Custom settings can be supplied which override the defaults.
#' process. Custom settings can be supplied which override the defaults.
#'
#' @param ls_mean Numeric, defaults to 21 days. The mean of the lognormal
#' length scale.
Expand All @@ -418,22 +418,20 @@
#' @param ls_min Numeric, defaults to 0. The minimum value of the length scale.
#'
#' @param alpha_sd Numeric, defaults to 0.05. The standard deviation of the
#' magnitude parameter of the Gaussian process kernel. Should be approximately
#' magnitude parameter of the Gaussian process kernel. Should be approximately
#' the expected standard deviation of the logged Rt.
#'
#' @param kernel Character string, the type of kernel required. Currently
#' supporting the squared exponential kernel ("se", or "matern" with
#' 'matern_order = Inf'), 3 over 2 oder 5 over 2 Matern kernel ("matern", with
#' `matern_order = 3/2` (default) or `matern_order = 5/2`, respectively), or
#' Orstein-Uhlenbeck ("ou", or "matern" with 'matern_order = 1/2'). Defaulting
#' to the Matérn 3 over 2 kernel for a balance of smoothness and
#' discontinuities.
#' supporting the squared exponential kernel ("se"), periodic kernel
#' ("periodic"), Ornstein-Uhlenbeck kernel ("ou"), and Matern kernel ("matern").
#'
#' @param matern_order Numeric, defaults to 3/2. Order of Matérn Kernel to use.
#' Currently the orders 1/2, 3/2, 5/2 and Inf are supported.
#' Common choices are 1/2, 3/2, and 5/2. If `kernel` is set
#' to "ou", `matern_order` will be automatically set to 1/2. Only used if
#' the kernel is set to "matern".
#'
#' @param matern_type Deprated; Numeric, defaults to 3/2. Order of Matérn Kernel
#' to use. Currently the orders 1/2, 3/2, 5/2 and Inf are supported.
#' @param matern_type Deprecated; Numeric, defaults to 3/2. Order of Matérn Kernel

Check warning on line 433 in R/opts.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/opts.R,line=433,col=81,[line_length_linter] Lines should not be more than 80 characters. This line is 82 characters.
#' to use. Currently the orders 1/2, 3/2, 5/2 and Inf are supported.
#'
#' @param basis_prop Numeric, proportion of time points to use as basis
#' functions. Defaults to 0.2. Decreasing this value results in a decrease in
Expand All @@ -446,6 +444,9 @@
#' approximate Gaussian process. See (Riutort-Mayol et al. 2020
#' <https://arxiv.org/abs/2004.11408>) for advice on updating this default.
#'
#' @param w0 Numeric, defaults to 1.0. Fundamental frequency for periodic
#' kernel. Only used if `kernel` is set to "periodic".
#'
#' @importFrom rlang arg_match
#' @return A `<gp_opts>` object of settings defining the Gaussian process
#' @export
Expand All @@ -462,14 +463,15 @@
ls_min = 0,
ls_max = 60,
alpha_sd = 0.05,
kernel = c("matern", "se", "ou"),
kernel = c("matern", "se", "ou", "periodic"),
matern_order = 3 / 2,
seabbs marked this conversation as resolved.
Show resolved Hide resolved
matern_type) {
matern_type,
w0 = 1.0) {
lifecycle::deprecate_warn(
"1.6.0", "gp_opts(matern_type)", "gp_opts(matern_order)"
)
if (!missing(matern_type)) {
if (!missing(matern_order) && matern_type == matern_order) {
if (!missing(matern_order) && matern_type != matern_order) {
stop(
"Incompatible `matern_order` and `matern_type`. ",
"Use `matern_order` only."
Expand All @@ -480,20 +482,15 @@

kernel <- arg_match(kernel)
if (kernel == "se") {
if (!missing(matern_order) && is.finite(matern_order)) {
stop("Squared exponential kernel must have matern order unset or `Inf`.")
}
matern_order <- Inf
} else if (kernel == "ou") {
if (!missing(matern_order) && matern_order != 1 / 2) {
stop("Ornstein-Uhlenbeck kernel must have matern order unset or `1 / 2`.") ## nolint: nonportable_path_linter
}
matern_order <- 1 / 2
} else if (!(is.infinite(matern_order) ||
matern_order %in% c(1 / 2, 3 / 2, 5 / 2))) {
stop(
"only the Matern kernels of order `1 / 2`, `3 / 2`, `5 / 2` or `Inf` ", ## nolint: nonportable_path_linter
"are currently supported"
} else if (
!(is.infinite(matern_order) || matern_order %in% c(1 / 2, 3 / 2, 5 / 2))
) {
warning(
"Uncommon Matern kernel order. Common orders are `1 / 2`, `3 / 2`,",

Check warning on line 492 in R/opts.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/opts.R,line=492,col=8,[nonportable_path_linter] Use file.path() to construct portable file paths.
" and `5 / 2`"

Check warning on line 493 in R/opts.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/opts.R,line=493,col=8,[nonportable_path_linter] Use file.path() to construct portable file paths.
)
}

Expand All @@ -506,7 +503,8 @@
ls_max = ls_max,
alpha_sd = alpha_sd,
kernel = kernel,
matern_order = matern_order
matern_order = matern_order,
w0 = w0
)

attr(gp, "class") <- c("gp_opts", class(gp))
Expand Down
Binary file removed epinow-epinow-1.png
seabbs marked this conversation as resolved.
Show resolved Hide resolved
Binary file not shown.
Binary file removed epinow-regional_epinow-1.png
Binary file not shown.
Binary file removed epinow-regional_epinow_multiple-1.png
Binary file not shown.
4 changes: 3 additions & 1 deletion inst/stan/data/gaussian_process.stan
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
real<lower=0> ls_min; // Lower bound for the lengthscale
real<lower=0> ls_max; // Upper bound for the lengthscale
real alpha_sd; // standard deviation of the alpha gp kernal parameter
int gp_type; // type of gp, 0 = squared exponential, 1 = 3/2 matern
int gp_type; // type of gp, 0 = squared exponential, 1 = periodic, 2 = Matern
real nu; // smoothness parameter for Matern kernel (used if gp_type = 2)
real w0; // fundamental frequency for periodic kernel (used if gp_type = 1)
int stationary; // is underlying gaussian process first or second order
int fixed; // should a gaussian process be used
44 changes: 28 additions & 16 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ functions {
#include functions/generated_quantities.stan
}


data {
#include data/observations.stan
#include data/delays.stan
Expand All @@ -19,15 +18,15 @@ data {
#include data/observation_model.stan
}

transformed data{
transformed data {
// observations
int ot = t - seeding_time - horizon; // observed time
int ot_h = ot + horizon; // observed time + forecast horizon
// gaussian process
int noise_terms = setup_noise(
ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from
);
matrix[noise_terms, M] PHI = setup_gp(M, L, noise_terms); // basis function
matrix[noise_terms, M] PHI = setup_gp(M, L, noise_terms, gp_type == 1, w0); // basis function
// Rt
real r_logmean = log(r_mean^2 / sqrt(r_sd^2 + r_mean^2));
real r_logsd = sqrt(log(1 + (r_sd^2 / r_mean^2)));
Expand All @@ -41,38 +40,39 @@ transformed data{
}
}

parameters{
parameters {
// gaussian process
array[fixed ? 0 : 1] real<lower = ls_min,upper=ls_max> rho; // length scale of noise GP
array[fixed ? 0 : 1] real<lower = 0> alpha; // scale of of noise GP
array[fixed ? 0 : 1] real<lower = ls_min, upper = ls_max> rho; // length scale of noise GP
array[fixed ? 0 : 1] real<lower = 0> alpha; // scale of noise GP
vector[fixed ? 0 : M] eta; // unconstrained noise
// Rt
vector[estimate_r] log_R; // baseline reproduction number estimate (log)
array[estimate_r] real initial_infections ; // seed infections
array[estimate_r] real initial_infections; // seed infections
array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate
array[bp_n > 0 ? 1 : 0] real<lower = 0> bp_sd; // standard deviation of breakpoint effect
array[bp_n] real bp_effects; // Rt breakpoint effects
// observation model

vector<lower = delay_params_lower>[delay_params_length] delay_params; // delay parameters
simplex[week_effect] day_of_week_simplex;// day of week reporting effect
array[obs_scale_sd > 0 ? 1 : 0] real<lower = 0, upper = 1> frac_obs; // fraction of cases that are ultimately observed
array[model_type] real<lower = 0> rep_phi; // overdispersion of the reporting process
simplex[week_effect] day_of_week_simplex; // day of week reporting effect
array[obs_scale_sd > 0 ? 1 : 0] real<lower = 0, upper = 1> frac_obs; // fraction of cases that are ultimately observed
array[model_type] real<lower = 0> rep_phi; // overdispersion of the reporting process
}

transformed parameters {
vector[fixed ? 0 : noise_terms] noise; // noise generated by the gaussian process
vector[fixed ? 0 : noise_terms] noise; // noise generated by the gaussian process
vector<lower = 0>[estimate_r > 0 ? ot_h : 0] R; // reproduction number
vector[t] infections; // latent infections
vector[ot_h] reports; // estimated reported cases
vector[ot] obs_reports; // observed estimated reported cases
vector[estimate_r * (delay_type_max[gt_id] + 1)] gt_rev_pmf;

// GP in noise - spectral densities
profile("update gp") {
if (!fixed) {
noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type);
noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type, nu);
}
}

// Estimate latent infections
if (estimate_r) {
profile("gt") {
Expand Down Expand Up @@ -102,6 +102,7 @@ transformed parameters {
);
}
}

// convolve from latent infections to mean of observations
if (delay_id) {
vector[delay_type_max[delay_id] + 1] delay_rev_pmf;
Expand All @@ -119,12 +120,14 @@ transformed parameters {
} else {
reports = infections[(seeding_time + 1):t];
}

// weekly reporting effect
if (week_effect > 1) {
profile("day of the week") {
reports = day_of_week_effect(reports, day_of_week, day_of_week_simplex);
}
}

// scaling of reported cases by fraction observed
if (obs_scale) {
profile("scale") {
Expand All @@ -133,6 +136,7 @@ transformed parameters {
);
}
}

// truncate near time cases to observed reports
if (trunc_id) {
vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf;
Expand Down Expand Up @@ -161,13 +165,15 @@ model {
);
}
}
// penalised priors for delay distributions

// penalized priors for delay distributions
profile("delays lp") {
delays_lp(
delay_params, delay_params_mean, delay_params_sd, delay_params_groups,
delay_dist, delay_weight
);
}

if (estimate_r) {
// priors on Rt
profile("rt lp") {
Expand All @@ -177,12 +183,14 @@ model {
);
}
}

// prior observation scaling
if (obs_scale_sd > 0) {
profile("scale lp") {
frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1];
}
}

// observed reports from mean of reports (update likelihood)
if (likelihood) {
profile("report lp") {
Expand All @@ -196,11 +204,12 @@ model {

generated quantities {
array[ot_h] int imputed_reports;
vector[estimate_r > 0 ? 0: ot_h] gen_R;
vector[estimate_r > 0 ? 0 : ot_h] gen_R;
vector[ot_h - 1] r;
vector[return_likelihood ? ot : 0] log_lik;

profile("generated quantities") {
if (estimate_r == 0){
if (estimate_r == 0) {
// sample generation time
vector[delay_params_length] delay_params_sample = to_vector(normal_lb_rng(
delay_params_mean, delay_params_sd, delay_params_lower
Expand All @@ -216,10 +225,13 @@ generated quantities {
infections, seeding_time, sampled_gt_rev_pmf, rt_half_window
);
}

// estimate growth from infections
r = calculate_growth(infections, seeding_time + 1);

// simulate reported cases
imputed_reports = report_rng(reports, rep_phi, model_type);

// log likelihood of model
if (return_likelihood) {
log_lik = report_log_lik(
Expand Down
Loading
Loading