Skip to content

Commit

Permalink
add file argument to bmm()
Browse files Browse the repository at this point in the history
  • Loading branch information
venpopov committed Mar 28, 2024
1 parent 21706c5 commit 6cfb7f9
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 7 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ Imports:
stats,
matrixStats,
crayon,
methods
methods,
fs
URL: https://github.com/venpopov/bmm, https://venpopov.github.io/bmm/
BugReports: https://github.com/venpopov/bmm/issues
Additional_repositories:
Expand Down
30 changes: 27 additions & 3 deletions R/bmm.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@
#' "rstan" or "cmdstanr". If NULL (the default), "cmdstanr" will be used if
#' the cmdstanr package is installed, otherwise "rstan" will be used. You can
#' set the default backend using global `options(brms.backend = "rstan"/"cmdstanr")`
#' @param file Either `NULL` or a character string. If a string, the fitted
#' model object is saved via [saveRDS] in a file named after the string. The
#' `.rds extension is added automatically. If the file already exists, `bmm`
#' will load and return the saved model object. Unless you specify the
#' `file_refit` argument as well, the existing files won't be overwritten, you
#' have to manually remove the file in order to refit and save the model under
#' an existing file name. The file name is stored in the `bmmfit` object for
#' later usage. If the directory of the file does not exist, it will be created.
#' @param file_compress Logical or a character string, specifying one of the
#' compression algorithms supported by [saveRDS] when saving
#' the fitted model object.
#' @param file_refit Logical. Modifies when the fit stored via the `file` argument is
#' re-used. Can be set globally for the current \R session via the
#' `"bmm.file_refit"` option (see [options]). If `TRUE` (the default), the
#' model is re-used if the file exists. If `FALSE`, the model is re-fitted. Note
#' that unlike `brms`, there is no "on_change" option
#' @param ... Further arguments passed to [brms::brm()] or Stan. See the
#' description of [brms::brm()] for more details
#'
Expand Down Expand Up @@ -89,10 +105,15 @@ bmm <- function(formula, data, model,
prior = NULL,
sort_data = getOption('bmm.sort_data', "check"),
silent = getOption('bmm.silent', 1),
backend = getOption('brms.backend', NULL), ...) {
backend = getOption('brms.backend', NULL),
file = NULL, file_compress = TRUE,
file_refit = getOption('bmm.file_refit', FALSE), ...) {
deprecated_args(...)
dots <- list(...)

x <- read_bmmfit(file, file_refit)
if (!is.null(x)) return(x)

# set temporary global options and return modified arguments for brms
configure_opts <- nlist(sort_data, silent, backend, parallel = dots$parallel,
cores = dots$cores)
Expand All @@ -116,8 +137,11 @@ bmm <- function(formula, data, model,
fit <- call_brm(fit_args)

# model post-processing
postprocess_brm(model, fit, fit_args = fit_args, user_formula = user_formula,
configure_opts = configure_opts)
fit <- postprocess_brm(model, fit, fit_args = fit_args, user_formula = user_formula,
configure_opts = configure_opts)

# save the fitted model object if !is.null
save_bmmfit(fit, file, compress = file_compress)
}


Expand Down
54 changes: 52 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,8 @@ identical.formula <- function(x, y, ...) {
#' printed in color. **Default: TRUE**
#' @param reset_options logical. If TRUE, the options will be reset to their
#' default values **Default: FALSE**
#' @param file_refit logical. If TRUE, bmm() will refit the model even if the
#' file argument is specified. **Default: FALSE**
#' @details The `bmm_options` function is used to view or change the current bmm
#' options. If no arguments are provided, the function will return the current
#' options. If arguments are provided, the function will change the options
Expand Down Expand Up @@ -493,7 +495,7 @@ identical.formula <- function(x, y, ...) {
#'
#' @export
bmm_options <- function(sort_data, parallel, default_priors, silent,
color_summary, reset_options = FALSE) {
color_summary, file_refit, reset_options = FALSE) {
opts <- ls()
stopif(!missing(sort_data) && sort_data != "check" && !is.logical(sort_data),
"sort_data must be one of TRUE, FALSE, or 'check'")
Expand All @@ -505,14 +507,17 @@ bmm_options <- function(sort_data, parallel, default_priors, silent,
"silent must be one of 0, 1, or 2")
stopif(!missing(color_summary) && !is.logical(color_summary),
"color_summary must be a logical value")
stopif(!missing(file_refit) && !is.logical(file_refit),
"file_refit must be a logical value")

# set default options if function is called for the first time or if reset_options is TRUE
if (reset_options) {
options(bmm.sort_data = "check",
bmm.parallel = FALSE,
bmm.default_priors = TRUE,
bmm.silent = 1,
bmm.color_summary = TRUE)
bmm.color_summary = TRUE,
bmm.file_refit = FALSE)
}

# change options if arguments are provided. get argument name and loop over non-missing arguments
Expand All @@ -529,6 +534,7 @@ bmm_options <- function(sort_data, parallel, default_priors, silent,
"\n parallel = ", getOption("bmm.parallel"),
"\n default_priors = ", getOption("bmm.default_priors"),
"\n silent = ", getOption("bmm.silent"),
"\n file_refit = ", getOption("bmm.file_refit"),
"\n color_summary = ", getOption("bmm.color_summary"), "\n")),
"For more information on these options or how to change them, see help(bmm_options).\n")
invisible(old_op)
Expand Down Expand Up @@ -663,6 +669,50 @@ deprecated_args <- function(...) {
See `help("brm")` for more information.')
}


read_bmmfit <- function(file, file_refit) {
file <- check_rds_file(file)
if (is.null(file) || file_refit) {
return(NULL)
}
dir <- dirname(file)
dir <- try(fs::dir_create(dir))
stopif(is_try_error(dir), "Cannot create directory for file.")

out <- suppressWarnings(try(readRDS(file), silent = TRUE))
if (!is_try_error(out)) {
if (!is_bmmfit(out)) {
stop2("Object loaded via 'file' is not of class 'bmmfit'.")
}
out$file <- file
} else {
out <- NULL
}
out
}

save_bmmfit <- function(x, file, compress) {
file <- check_rds_file(file)
x$file <- file
if (!is.null(file)) {
saveRDS(x, file, compress = compress)
}
x
}

check_rds_file <- function(file) {
if (is.null(file)) {
return(NULL)
}
stopif(!is.character(file), "'file' must be a character string.")
stopif(length(file) > 1, "'file' must be a single character string.")
ext <- fs::path_ext(file)
if (ext != "rds") {
file <- paste0(file, ".rds")
}
file
}

`%||%` <- function(a, b) {
if (!is.null(a)) a else b
}
2 changes: 1 addition & 1 deletion man/SDM.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions man/bmm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions man/bmm_options.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 53 additions & 0 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,56 @@ test_that("bmm_options works", {
options(op)
expect_equal(getOption('bmm.sort_data'), TRUE)
})


test_that("check_rds_file works", {
good_files <- list('a.rds', 'abc/a.rds', 'a', 'abc/a', 'a.M')
bad_files <- list(1, mean, c('a','b'), TRUE)

for (f in good_files) {
expect_silent(res <- check_rds_file(f))
expect_equal(fs::path_ext(res), 'rds')
}

for (f in bad_files) {
expect_error(check_rds_file(f))
}

expect_null(check_rds_file(NULL))
})

test_that("read_bmmfit works", {
mock_fit <- bmm(bmf(c~1, kappa ~ 1), oberauer_lin_2017, sdm('dev_rad'),
backend = "mock", mock_fit = 1, rename = F)
file <- tempfile()
mock_fit$file <- paste0(file, '.rds')
saveRDS(mock_fit, paste0(file, '.rds'))
expect_equal(read_bmmfit(file, FALSE), mock_fit, ignore_function_env = TRUE,
ignore_formula_env = TRUE)

x = 1
saveRDS(x, paste0(file, '.rds'))
expect_error(read_bmmfit(file, FALSE), "not of class 'bmmfit'")
})

test_that("save_bmmfit works", {
file <- tempfile()
mock_fit <- bmm(bmf(c~1, kappa ~ 1), oberauer_lin_2017, sdm('dev_rad'),
backend = "mock", mock_fit = 1, rename = F,
file = file)
rds_file <- paste0(file, '.rds')
expect_true(file.exists(rds_file))
expect_equal(readRDS(rds_file), mock_fit, ignore_function_env = TRUE,
ignore_formula_env = TRUE)

mock_fit2 <- bmm(bmf(c~1, kappa ~ 1), oberauer_lin_2017, sdm('dev_rad'),
backend = "mock", mock_fit = 2, rename = F,
file = file)
expect_equal(mock_fit, mock_fit2)

# they should not be the same if file_refit = TRUE
mock_fit3 <- bmm(bmf(c~1, kappa ~ 1), oberauer_lin_2017, sdm('dev_rad'),
backend = "mock", mock_fit = 3, rename = F,
file = file, file_refit = TRUE)
expect_error(expect_equal(mock_fit, mock_fit3))
})

0 comments on commit 6cfb7f9

Please sign in to comment.