From bb463076b27d6f574a43a5e01fc18adc1a6f5081 Mon Sep 17 00:00:00 2001 From: Gavin Simpson Date: Wed, 6 Mar 2024 17:26:23 +0100 Subject: [PATCH] the sampler functions can now take a seed --- R/samplers.R | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/R/samplers.R b/R/samplers.R index cbfff7dcd..407215aba 100644 --- a/R/samplers.R +++ b/R/samplers.R @@ -47,6 +47,8 @@ #' to singular. #' @param draws matrix; user supplied posterior draws to be used when #' `method = "user"`. +#' @param seed numeric; the random seed to use. If `NULL`, a random seed will +#' be generated without affecting the current state of R's RNG. #' @param ... arguments passed to methods. #' #' @export @@ -56,27 +58,33 @@ #' @export #' @rdname post_draws +#' @importFrom withr with_seed with_preserve_seed `post_draws.default` <- function( model, n, method = c("gaussian", "mh", "inla", "user"), mu = NULL, sigma = NULL, n_cores = 1L, burnin = 1000, thin = 1, t_df = 40, rw_scale = 0.25, index = NULL, frequentist = FALSE, unconditional = FALSE, - parametrized = TRUE, mvn_method = c("mvnfast", "mgcv"), draws = NULL, ...) { + parametrized = TRUE, mvn_method = c("mvnfast", "mgcv"), draws = NULL, + seed = NULL, ...) { + if (is.null(seed)) { + seed <- with_preserve_seed(runif(1)) + } # what posterior sampling are we using method <- match.arg(method) mvn_method <- match.arg(mvn_method) betas <- switch(method, - "gaussian" = gaussian_draws( + "gaussian" = with_seed(seed, gaussian_draws( model = model, n = n, n_cores = n_cores, index = index, frequentist = frequentist, unconditional = unconditional, parametrized = parametrized, mvn_method = mvn_method, ... - ), - "mh" = mh_draws( + )), + "mh" = with_seed(seed, mh_draws( n = n, model = model, burnin = burnin, thin = thin, t_df = t_df, rw_scale = rw_scale, index = index, ... - ), - "inla" = .NotYetImplemented(), + )), + "inla" = stop("'method = \"inla\"' is not yet implemented.", + call. = FALSE), "user" = user_draws(model = model, draws = draws, ...) ) betas @@ -91,25 +99,30 @@ #' @export #' @rdname post_draws +#' @importFrom withr with_seed with_preserve_seed `generate_draws.gam` <- function( model, n, method = c("gaussian", "mh", "inla"), mu = NULL, sigma = NULL, n_cores = 1L, burnin = 1000, thin = 1, t_df = 40, rw_scale = 0.25, index = NULL, frequentist = FALSE, unconditional = FALSE, - mvn_method = c("mvnfast", "mgcv"), ...) { + mvn_method = c("mvnfast", "mgcv"), seed = NULL, ...) { + if (is.null(seed)) { + seed <- with_preserve_seed(runif(1)) + } # what posterior sampling are we using method <- match.arg(method) mvn_method <- match.arg(mvn_method) betas <- switch(method, - "gaussian" = gaussian_draws( + "gaussian" = with_seed(seed, gaussian_draws( model = model, n = n, n_cores = n_cores, index = index, frequentist = frequentist, unconditional = unconditional, mvn_method = mvn_method, ... - ), - "mh" = mh_draws( + )), + "mh" = with_seed(seed, mh_draws( n = n, model = model, burnin = burnin, - thin = thin, t_df = t_df, rw_scale = rw_scale, index = index - ), - "inla" = .NotYetImplemented() + thin = thin, t_df = t_df, rw_scale = rw_scale, index = index, ... + )), + "inla" = stop("'method = \"inla\"' is not yet implemented.", + call. = FALSE) ) betas }