Skip to content

Commit

Permalink
the sampler functions can now take a seed
Browse files Browse the repository at this point in the history
  • Loading branch information
gavinsimpson committed Mar 6, 2024
1 parent 64fe38c commit bb46307
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions R/samplers.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
#' to singular.
#' @param draws matrix; user supplied posterior draws to be used when
#' `method = "user"`.
#' @param seed numeric; the random seed to use. If `NULL`, a random seed will
#' be generated without affecting the current state of R's RNG.
#' @param ... arguments passed to methods.
#'
#' @export
Expand All @@ -56,27 +58,33 @@

#' @export
#' @rdname post_draws
#' @importFrom withr with_seed with_preserve_seed
`post_draws.default` <- function(
model, n,
method = c("gaussian", "mh", "inla", "user"), mu = NULL, sigma = NULL,
n_cores = 1L, burnin = 1000, thin = 1, t_df = 40, rw_scale = 0.25,
index = NULL, frequentist = FALSE, unconditional = FALSE,
parametrized = TRUE, mvn_method = c("mvnfast", "mgcv"), draws = NULL, ...) {
parametrized = TRUE, mvn_method = c("mvnfast", "mgcv"), draws = NULL,
seed = NULL, ...) {
if (is.null(seed)) {
seed <- with_preserve_seed(runif(1))
}
# what posterior sampling are we using
method <- match.arg(method)
mvn_method <- match.arg(mvn_method)
betas <- switch(method,
"gaussian" = gaussian_draws(
"gaussian" = with_seed(seed, gaussian_draws(
model = model, n = n,
n_cores = n_cores, index = index, frequentist = frequentist,
unconditional = unconditional, parametrized = parametrized,
mvn_method = mvn_method, ...
),
"mh" = mh_draws(
)),
"mh" = with_seed(seed, mh_draws(
n = n, model = model, burnin = burnin,
thin = thin, t_df = t_df, rw_scale = rw_scale, index = index, ...
),
"inla" = .NotYetImplemented(),
)),
"inla" = stop("'method = \"inla\"' is not yet implemented.",
call. = FALSE),
"user" = user_draws(model = model, draws = draws, ...)
)
betas
Expand All @@ -91,25 +99,30 @@

#' @export
#' @rdname post_draws
#' @importFrom withr with_seed with_preserve_seed
`generate_draws.gam` <- function(
model, n, method = c("gaussian", "mh", "inla"),
mu = NULL, sigma = NULL, n_cores = 1L, burnin = 1000, thin = 1, t_df = 40,
rw_scale = 0.25, index = NULL, frequentist = FALSE, unconditional = FALSE,
mvn_method = c("mvnfast", "mgcv"), ...) {
mvn_method = c("mvnfast", "mgcv"), seed = NULL, ...) {
if (is.null(seed)) {
seed <- with_preserve_seed(runif(1))
}
# what posterior sampling are we using
method <- match.arg(method)
mvn_method <- match.arg(mvn_method)
betas <- switch(method,
"gaussian" = gaussian_draws(
"gaussian" = with_seed(seed, gaussian_draws(
model = model, n = n,
n_cores = n_cores, index = index, frequentist = frequentist,
unconditional = unconditional, mvn_method = mvn_method, ...
),
"mh" = mh_draws(
)),
"mh" = with_seed(seed, mh_draws(
n = n, model = model, burnin = burnin,
thin = thin, t_df = t_df, rw_scale = rw_scale, index = index
),
"inla" = .NotYetImplemented()
thin = thin, t_df = t_df, rw_scale = rw_scale, index = index, ...
)),
"inla" = stop("'method = \"inla\"' is not yet implemented.",
call. = FALSE)
)
betas
}
Expand Down

0 comments on commit bb46307

Please sign in to comment.