Skip to content

Commit

Permalink
estimate r from initial R (#923)
Browse files Browse the repository at this point in the history
* estimate r from initial R

* remove unneeded argument

* update tests

* update naming to be more consistent

* don't use sundials solver

* update snapshot

* estimate initial infections within model

* update simulation snapshot

* another snapshot update

* doc update

* update snapshot again

* try manual approach

* relax prior

* scale with early cases

* update snapshots

* fabs -> abs

* pass cases

* ensure mean init is >=1

* move to transformed data

* add source

* update news item

* fix simulations

* update tests

* update snapshots

* temporarily remove additional repos

* Revert "temporarily remove additional repos"

This reverts commit c62be58.

* touchstone: don't upgrade

* stabilise initial guess

* Revert "touchstone: don't upgrade"

This reverts commit d70e4e5.

* update sim snapshot

* try max instead of correction

* remove/rename

* version 2: testing for speed

* Revert "version 2: testing for speed"

This reverts commit 6a36058.
  • Loading branch information
sbfnk authored Jan 23, 2025
1 parent 439cf62 commit 9376b06
Show file tree
Hide file tree
Showing 13 changed files with 126 additions and 214 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and #890 and reviewed by @seabbs.
- The Gaussian Process lengthscale is now scaled internally by half the length of the time series. By @sbfnk in #890 and reviewed by @seabbs.
- A bug was fixed where `plot.dist_spec()` wasn't throwing an informative error due to an incomplete check for the max of the specified delay. By @jamesmbaazam in #858 and reviewed by @.
- Updated the early dynamics calculation to use the full linear model if available. Also changed the prior for initial infections to be approximately Poisson and the initial growth rate to the point estimate of the initial growth rate scaled linearly by the estimated initial infections term. By @sbfnk in #903 and reviewed by @seabbs and @SamuelBrand1
- Updated the early dynamics calculation to estimate growth from the initial reproduction number instead of a separate linear model. Also changed the prior calculation for initial infections to be a scaling factor of early case numbers adjusted by the growth estimate, instead of a true number of initial infections. By @sbfnk in #923 (with initial exploration in #903) and reviewed by @seabbs and @SamuelBrand1.

## Package changes

Expand Down
57 changes: 1 addition & 56 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -440,54 +440,6 @@ create_forecast_data <- function(forecast = forecast_opts(), data) {
return(data)
}

#' Calculate prior infections and fit early growth
#'
#' @description Calculates the prior infections and growth rate based on the
#' first week's data. A log-linear model of cases against time is fitted to
#' the non-missing observations among the first `min(7, seeding_time)` cases;
#' its intercept and slope are used as the estimates of (log) initial
#' infections and initial growth, respectively. If the model is not fitted
#' (a seeding time of one or less, fewer than two observations, or an error
#' from [stats::lm()], e.g. because of zero counts), growth is set to zero and
#' initial infections to the log of the mean of the early cases, or to zero
#' where that value is not finite.
#'
#' @param cases Numeric vector; the case counts from the input data.
#' @inheritParams create_stan_data
#' @return A list containing `initial_infections_estimate` and
#' `initial_growth_estimate`.
#' @keywords internal
estimate_early_dynamics <- function(cases, seeding_time) {
  # Restrict to (at most) the first week of the seeding period and drop
  # missing observations, keeping their original time index.
  n_initial <- min(7, seeding_time, length(cases))
  early_cases <- cases[seq_len(n_initial)]
  observed <- !is.na(early_cases)
  initial_period <- data.frame(
    confirm = early_cases[observed],
    t = (seq_len(n_initial) - 1)[observed]
  )

  # Defaults used whenever the log-linear model is not (or cannot be) fitted
  initial_infections_estimate <- 0
  initial_growth_estimate <- 0

  # Calculate initial infections and growth estimate
  if (seeding_time > 1 && nrow(initial_period) > 1) {
    # lm() errors if log(confirm) is not finite (e.g. zero counts); treat
    # this as "no fit" rather than failing.
    log_linear_estimate <- tryCatch(
      stats::lm(log(confirm) ~ t, data = initial_period),
      error = function(e) NULL
    )
    if (!is.null(log_linear_estimate)) {
      # `[[` drops the coefficient names so plain numbers are returned
      initial_infections_estimate <- log_linear_estimate$coefficients[[1]]
      initial_growth_estimate <- log_linear_estimate$coefficients[[2]]
    }
  }

  # Calculate prior infections from the mean of early cases if no estimate
  # is available from the model
  if (initial_infections_estimate == 0) {
    initial_infections_estimate <- log(
      mean(initial_period$confirm, na.rm = TRUE)
    )
    # log(0) is -Inf and the mean of no observations is NaN; neither is a
    # usable prior mean, so fall back to zero in both cases.
    if (!is.finite(initial_infections_estimate)) {
      initial_infections_estimate <- 0
    }
  }

  list(
    initial_infections_estimate = initial_infections_estimate,
    initial_growth_estimate = initial_growth_estimate
  )
}

#' Create Stan Data Required for estimate_infections
#'
#' @description`r lifecycle::badge("stable")`
Expand Down Expand Up @@ -553,11 +505,6 @@ create_stan_data <- function(data, seeding_time, rt, gp, obs, backcalc,
delay = stan_data$seeding_time, horizon = stan_data$horizon
)
)
# calculate prior infections and fit early growth
stan_data <- c(
stan_data,
estimate_early_dynamics(confirmed_cases, seeding_time)
)
# backcalculation settings
stan_data <- c(stan_data, create_backcalc_data(backcalc))
# gaussian process data
Expand Down Expand Up @@ -639,9 +586,7 @@ create_initial_conditions <- function(data) {
out$eta <- array(numeric(0))
}
if (data$estimate_r == 1) {
out$initial_infections <- array(
rnorm(1, data$initial_infections_estimate, 0.2)
)
out$initial_infections <- array(rnorm(1))
}

if (data$bp_n > 0) {
Expand Down
19 changes: 5 additions & 14 deletions R/simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,25 +105,14 @@ simulate_infections <- function(estimates, R, initial_infections,
if (missing(seeding_time)) {
seeding_time <- sum(max(generation_time))
}
if (seeding_time > 1) {
## estimate initial growth from initial reproduction number if seeding time
## is greater than 1
initial_growth <- (R$R[1] - 1) / mean(generation_time)
## adjust initial infections for initial exponential growth
log_initial_infections <- log(initial_infections) -
(seeding_time - 1) * initial_growth
} else {
initial_growth <- numeric(0)
log_initial_infections <- log(initial_infections)
}

data <- list(
n = 1,
t = nrow(R) + seeding_time,
seeding_time = seeding_time,
future_time = 0,
initial_infections = array(log_initial_infections, dim = c(1, 1)),
initial_growth = array(initial_growth, dim = c(1, length(initial_growth))),
initial_infections = array(log(initial_infections), dim = c(1, 1)),
initial_as_scale = 0,
R = array(R$R, dim = c(1, nrow(R))),
pop = pop
)
Expand Down Expand Up @@ -433,7 +422,9 @@ forecast_infections <- function(estimates,
draws <- map(draws, ~ as.matrix(.[nstart:nend, ]))

## prepare data for stan command
data <- c(list(n = dim(draws$R)[1]), draws, estimates$args)
data <- c(
list(n = dim(draws$R)[1], initial_as_scale = 1), draws, estimates$args
)

## allocate empty parameters
data <- allocate_empty(
Expand Down
2 changes: 0 additions & 2 deletions inst/stan/data/rt.stan
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
int estimate_r; // should the reproduction no be estimated (1 = yes)
real initial_infections_estimate; // point estimate of initial infections
real initial_growth_estimate; // point estimate of initial growth rate
int bp_n; // no of breakpoints (0 = no breakpoints)
array[t - seeding_time] int breakpoints; // when do breakpoints occur
int future_fixed; // is underlying future Rt assumed to be fixed
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/data/simulation_rt.stan
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
array[n, 1] real initial_infections; // initial logged infections
array[n, seeding_time > 1 ? 1 : 0] real initial_growth; //initial growth
int initial_as_scale; // whether to interpret initial infections as scaling

matrix[n, t - seeding_time] R; // reproduction number
int pop; // susceptible population
Expand Down
19 changes: 10 additions & 9 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ transformed data {
delay_types_groups, delay_max, delay_np_pmf_groups
);
}

// initial infections scaling (on the log scale)
real initial_infections_guess = fmax(
0,
log(mean(head(cases, num_elements(cases) > 7 ? 7 : num_elements(cases))))
);
}

parameters {
Expand All @@ -60,12 +66,6 @@ transformed parameters {
vector[ot_h] reports; // estimated reported cases
vector[ot] obs_reports; // observed estimated reported cases
vector[estimate_r * (delay_type_max[gt_id] + 1)] gt_rev_pmf;
array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate

if (num_elements(initial_growth) > 0) {
initial_growth[1] = initial_growth_estimate +
initial_infections_estimate - initial_infections[1];
}

// GP in noise - spectral densities
profile("update gp") {
Expand Down Expand Up @@ -108,8 +108,8 @@ transformed parameters {
params
);
infections = generate_infections(
R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop,
future_time, obs_scale, frac_obs
R, seeding_time, gt_rev_pmf, initial_infections, pop,
future_time, obs_scale, frac_obs, 1
);
}
} else {
Expand Down Expand Up @@ -210,7 +210,8 @@ model {
// priors on Rt
profile("rt lp") {
rt_lp(
initial_infections, bp_effects, bp_sd, bp_n, initial_infections_estimate
initial_infections, bp_effects, bp_sd, bp_n,
cases, initial_infections_guess
);
}
}
Expand Down
25 changes: 15 additions & 10 deletions inst/stan/functions/infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,33 @@ real update_infectiousness(vector infections, vector gt_rev_pmf,
);
return(new_inf);
}

// generate infections by using Rt = Rt-1 * sum(reversed generation time pmf * infections)
vector generate_infections(vector oR, int uot, vector gt_rev_pmf,
array[] real initial_infections, array[] real initial_growth,
int pop, int ht, int obs_scale, real frac_obs) {
vector generate_infections(vector R, int uot, vector gt_rev_pmf,
array[] real initial_infections, int pop, int ht,
int obs_scale, real frac_obs, int initial_as_scale) {
// time indices and storage
int ot = num_elements(oR);
int ot = num_elements(R);
int nht = ot - ht;
int t = ot + uot;
vector[ot] R = oR;
real exp_adj_Rt;
vector[t] infections = rep_vector(0, t);
vector[ot] cum_infections;
vector[ot] infectiousness;
real growth = R_to_r(R[1], gt_rev_pmf, 1e-3);
// Initialise infections using daily growth
infections[1] = exp(initial_infections[1]);
if (obs_scale) {
infections[1] = infections[1] / frac_obs;
if (initial_as_scale) {
infections[1] = exp(initial_infections[1] - growth * uot);
if (obs_scale) {
infections[1] = infections[1] / frac_obs;
}
} else {
infections[1] = exp(initial_infections[1]);
}
if (uot > 1) {
real growth = exp(initial_growth[1]);
real exp_growth = exp(growth);
for (s in 2:uot) {
infections[s] = infections[s - 1] * growth;
infections[s] = infections[s - 1] * exp_growth;
}
}
// calculate cumulative infections
Expand Down
58 changes: 51 additions & 7 deletions inst/stan/functions/rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,64 @@ vector update_Rt(int t, real R0, vector noise, array[] int bps,
/**
* Calculate the log-probability of the reproduction number (Rt) priors
*
* @param initial_infections Array of initial infection values
* @param initial_infections_scale Array of initial infection values
* @param bp_effects Vector of breakpoint effects
* @param bp_sd Array of breakpoint standard deviations
* @param bp_n Number of breakpoints
* @param prior_infections Prior mean for initial infections
*/
void rt_lp(array[] real initial_infections, vector bp_effects,
array[] real bp_sd, int bp_n, real prior_infections) {
void rt_lp(array[] real initial_infections_scale, vector bp_effects,
array[] real bp_sd, int bp_n, array[] int cases,
real initial_infections_guess) {
//breakpoint effects on Rt
if (bp_n > 0) {
bp_sd[1] ~ normal(0, 0.1) T[0,];
bp_effects ~ normal(0, bp_sd[1]);
}
// initial infections
initial_infections ~ normal(prior_infections, sqrt(prior_infections));

initial_infections_scale ~ normal(initial_infections_guess, 2);
}

/**
 * Helper function for calculating r from R using Newton's method
 *
 * Evaluates a single Newton step for the root of
 *   f(r) = R * sum_k pmf[k] * exp(-r * k) - 1,
 * where k = 0, ..., len - 1 runs over the elements of `pmf`. The value
 * returned is f(r) / f'(r), with
 *   f'(r) = -R * sum_k k * pmf[k] * exp(-r * k).
 * This function does not update `r` itself.
 *
 * Code is based on Julia code from
 * https://github.com/CDCgov/Rt-without-renewal/blob/d6344cc6e451e3e6c4188e4984247f890ae60795/EpiAware/test/predictive_checking/fast_approx_for_r.jl
 * under Apache license 2.0.
 *
 * @param R Reproduction number
 * @param r growth rate at which the step is evaluated
 * @param pmf generation time probability mass function (first index: 0)
 * @return The Newton step f(r) / f'(r)
 */
real R_to_r_newton_step(real R, real r, vector pmf) {
int len = num_elements(pmf);
// lags 0, 1, ..., len - 1, one per element of pmf
vector[len] zero_series = linspaced_vector(len, 0, len - 1);
// exp(-r * k) for each lag k
vector[len] exp_r = exp(-r * zero_series);
// numerator is f(r), denominator is f'(r)
real ret = (R * dot_product(pmf, exp_r) - 1) /
(- R * dot_product(pmf .* zero_series, exp_r));
return(ret);
}

/**
 * Estimate the growth rate r from reproduction number R. Used in the model to
 * estimate the initial growth rate using Newton's method.
 *
 * The iteration starts from (R - 1) / (R * mean generation time), bounded
 * below at -1, and repeatedly subtracts the step returned by
 * `R_to_r_newton_step()` until the absolute value of a step is no larger
 * than `abs_tol`.
 *
 * Code is based on Julia code from
 * https://github.com/CDCgov/Rt-without-renewal/blob/d6344cc6e451e3e6c4188e4984247f890ae60795/EpiAware/test/predictive_checking/fast_approx_for_r.jl
 * under Apache license 2.0.
 *
 * @param R reproduction number
 * @param gt_rev_pmf reverse probability mass function of the generation time
 * @param abs_tol absolute tolerance of the solver
 * @return The growth rate r after the final Newton update
 */
real R_to_r(real R, vector gt_rev_pmf, real abs_tol) {
int gt_len = num_elements(gt_rev_pmf);
// undo the reversal so that the first element corresponds to lag 0
vector[gt_len] gt_pmf = reverse(gt_rev_pmf);
// mean of the generation time, with lags counted from 0
real mean_gt = dot_product(gt_pmf, linspaced_vector(gt_len, 0, gt_len - 1));
// starting value for the iteration, bounded below at -1
real r = fmax((R - 1) / (R * mean_gt), -1);
// initialise above the tolerance so that the loop is entered
real step = abs_tol + 1;
// NOTE(review): there is no cap on the number of iterations, so termination
// relies on the Newton steps shrinking below abs_tol - confirm this holds
// for all R and generation times passed in.
while (abs(step) > abs_tol) {
step = R_to_r_newton_step(R, r, gt_pmf);
r -= step;
}

return(r);
}
2 changes: 1 addition & 1 deletion inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ generated quantities {

infections[i] = to_row_vector(generate_infections(
to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i],
initial_growth[i], pop, future_time, obs_scale, frac_obs[i]
pop, future_time, obs_scale, frac_obs[i], initial_as_scale
));

if (delay_id) {
Expand Down
23 changes: 0 additions & 23 deletions man/estimate_early_dynamics.Rd

This file was deleted.

Loading

0 comments on commit 9376b06

Please sign in to comment.