diff --git a/inst/stan/functions/observation_model.stan b/inst/stan/functions/observation_model.stan index 52f41ce52..8cf25c14a 100644 --- a/inst/stan/functions/observation_model.stan +++ b/inst/stan/functions/observation_model.stan @@ -189,6 +189,29 @@ vector report_log_lik(array[] int cases, vector reports, return(log_lik); } + +/** + * Custom safe version of the negative binomial sampler + * + * This function generates random samples of the negative binomial distribution + * whilst avoiding numerical overflows. + * + * @param mu Real value ofr mean mu. + * @param phi Real value for phi. + * + * @return A random sample from the negative binomial distribution. + */ +int neg_binomial_2_safe_rng(real mu, real phi) { + if (mu < 1e-8) { + return(0); + } else if (phi > 1e4) { + return(poisson_rng(mu > 1e8 ? 1e8 : mu)); + } else { + real gamma_rate = gamma_rng(phi, phi / mu); + return(poisson_rng(gamma_rate > 1e8 ? 1e8 : gamma_rate)); + } +} + /** * Generate random samples of reported cases * @@ -209,16 +232,7 @@ array[] int report_rng(vector reports, real rep_phi, int model_type) { } for (s in 1:t) { - if (reports[s] < 1e-8) { - sampled_reports[s] = 0; - } else { - // defer to poisson if phi is large, to avoid overflow - if (dispersion > 1e4) { - sampled_reports[s] = poisson_rng(reports[s] > 1e8 ? 1e8 : reports[s]); - } else { - sampled_reports[s] = neg_binomial_2_rng(reports[s] > 1e8 ? 1e8 : reports[s], dispersion); - } - } + sampled_reports[s] = neg_binomial_2_safe_rng(reports[s], dispersion); } return(sampled_reports); }