From 6cfb7f9466c2e5209a235aea6051b6f32b282d60 Mon Sep 17 00:00:00 2001 From: Ven Popov Date: Thu, 28 Mar 2024 20:07:02 +0100 Subject: [PATCH] add file argument to bmm() --- DESCRIPTION | 3 ++- R/bmm.R | 30 ++++++++++++++++++--- R/utils.R | 54 +++++++++++++++++++++++++++++++++++-- man/SDM.Rd | 2 +- man/bmm.Rd | 18 +++++++++++++ man/bmm_options.Rd | 4 +++ tests/testthat/test-utils.R | 53 ++++++++++++++++++++++++++++++++++++ 7 files changed, 157 insertions(+), 7 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 0e3d76a2..2bf0853c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -44,7 +44,8 @@ Imports: stats, matrixStats, crayon, - methods + methods, + fs URL: https://github.com/venpopov/bmm, https://venpopov.github.io/bmm/ BugReports: https://github.com/venpopov/bmm/issues Additional_repositories: diff --git a/R/bmm.R b/R/bmm.R index 1dd6d62f..8a56b6b1 100644 --- a/R/bmm.R +++ b/R/bmm.R @@ -34,6 +34,22 @@ #' "rstan" or "cmdstanr". If NULL (the default), "cmdstanr" will be used if #' the cmdstanr package is installed, otherwise "rstan" will be used. You can #' set the default backend using global `options(brms.backend = "rstan"/"cmdstanr")` +#' @param file Either `NULL` or a character string. If a string, the fitted +#' model object is saved via [saveRDS] in a file named after the string. The +#' `.rds extension is added automatically. If the file already exists, `bmm` +#' will load and return the saved model object. Unless you specify the +#' `file_refit` argument as well, the existing files won't be overwritten, you +#' have to manually remove the file in order to refit and save the model under +#' an existing file name. The file name is stored in the `bmmfit` object for +#' later usage. If the directory of the file does not exist, it will be created. +#' @param file_compress Logical or a character string, specifying one of the +#' compression algorithms supported by [saveRDS] when saving +#' the fitted model object. +#' @param file_refit Logical. Modifies when the fit stored via the `file` argument is +#' re-used. Can be set globally for the current \R session via the +#' `"bmm.file_refit"` option (see [options]). If `TRUE` (the default), the +#' model is re-used if the file exists. If `FALSE`, the model is re-fitted. Note +#' that unlike `brms`, there is no "on_change" option #' @param ... Further arguments passed to [brms::brm()] or Stan. See the #' description of [brms::brm()] for more details #' @@ -89,10 +105,15 @@ bmm <- function(formula, data, model, prior = NULL, sort_data = getOption('bmm.sort_data', "check"), silent = getOption('bmm.silent', 1), - backend = getOption('brms.backend', NULL), ...) { + backend = getOption('brms.backend', NULL), + file = NULL, file_compress = TRUE, + file_refit = getOption('bmm.file_refit', FALSE), ...) { deprecated_args(...) dots <- list(...) + x <- read_bmmfit(file, file_refit) + if (!is.null(x)) return(x) + # set temporary global options and return modified arguments for brms configure_opts <- nlist(sort_data, silent, backend, parallel = dots$parallel, cores = dots$cores) @@ -116,8 +137,11 @@ bmm <- function(formula, data, model, fit <- call_brm(fit_args) # model post-processing - postprocess_brm(model, fit, fit_args = fit_args, user_formula = user_formula, - configure_opts = configure_opts) + fit <- postprocess_brm(model, fit, fit_args = fit_args, user_formula = user_formula, + configure_opts = configure_opts) + + # save the fitted model object if !is.null + save_bmmfit(fit, file, compress = file_compress) } diff --git a/R/utils.R b/R/utils.R index a303eae0..4a68be01 100644 --- a/R/utils.R +++ b/R/utils.R @@ -456,6 +456,8 @@ identical.formula <- function(x, y, ...) { #' printed in color. **Default: TRUE** #' @param reset_options logical. If TRUE, the options will be reset to their #' default values **Default: FALSE** +#' @param file_refit logical. If TRUE, bmm() will refit the model even if the +#' file argument is specified. **Default: FALSE** #' @details The `bmm_options` function is used to view or change the current bmm #' options. If no arguments are provided, the function will return the current #' options. If arguments are provided, the function will change the options @@ -493,7 +495,7 @@ identical.formula <- function(x, y, ...) { #' #' @export bmm_options <- function(sort_data, parallel, default_priors, silent, - color_summary, reset_options = FALSE) { + color_summary, file_refit, reset_options = FALSE) { opts <- ls() stopif(!missing(sort_data) && sort_data != "check" && !is.logical(sort_data), "sort_data must be one of TRUE, FALSE, or 'check'") @@ -505,6 +507,8 @@ bmm_options <- function(sort_data, parallel, default_priors, silent, "silent must be one of 0, 1, or 2") stopif(!missing(color_summary) && !is.logical(color_summary), "color_summary must be a logical value") + stopif(!missing(file_refit) && !is.logical(file_refit), + "file_refit must be a logical value") # set default options if function is called for the first time or if reset_options is TRUE if (reset_options) { @@ -512,7 +516,8 @@ bmm_options <- function(sort_data, parallel, default_priors, silent, bmm.parallel = FALSE, bmm.default_priors = TRUE, bmm.silent = 1, - bmm.color_summary = TRUE) + bmm.color_summary = TRUE, + bmm.file_refit = FALSE) } # change options if arguments are provided. get argument name and loop over non-missing arguments @@ -529,6 +534,7 @@ bmm_options <- function(sort_data, parallel, default_priors, silent, "\n parallel = ", getOption("bmm.parallel"), "\n default_priors = ", getOption("bmm.default_priors"), "\n silent = ", getOption("bmm.silent"), + "\n file_refit = ", getOption("bmm.file_refit"), "\n color_summary = ", getOption("bmm.color_summary"), "\n")), "For more information on these options or how to change them, see help(bmm_options).\n") invisible(old_op) @@ -663,6 +669,50 @@ deprecated_args <- function(...) { See `help("brm")` for more information.') } + +read_bmmfit <- function(file, file_refit) { + file <- check_rds_file(file) + if (is.null(file) || file_refit) { + return(NULL) + } + dir <- dirname(file) + dir <- try(fs::dir_create(dir)) + stopif(is_try_error(dir), "Cannot create directory for file.") + + out <- suppressWarnings(try(readRDS(file), silent = TRUE)) + if (!is_try_error(out)) { + if (!is_bmmfit(out)) { + stop2("Object loaded via 'file' is not of class 'bmmfit'.") + } + out$file <- file + } else { + out <- NULL + } + out +} + +save_bmmfit <- function(x, file, compress) { + file <- check_rds_file(file) + x$file <- file + if (!is.null(file)) { + saveRDS(x, file, compress = compress) + } + x +} + +check_rds_file <- function(file) { + if (is.null(file)) { + return(NULL) + } + stopif(!is.character(file), "'file' must be a character string.") + stopif(length(file) > 1, "'file' must be a single character string.") + ext <- fs::path_ext(file) + if (ext != "rds") { + file <- paste0(file, ".rds") + } + file +} + `%||%` <- function(a, b) { if (!is.null(a)) a else b } diff --git a/man/SDM.Rd b/man/SDM.Rd index 4102095e..a93e214b 100644 --- a/man/SDM.Rd +++ b/man/SDM.Rd @@ -40,7 +40,7 @@ and how to use it. * \strong{Domain:} Visual working memory \itemize{ \item Oberauer, K. (2023). Measurement models for visual working memory - A factorial model comparison. Psychological Review, 130(3), 841-852 } -\item \strong{Version:} Simple (no non-targets) +\item \strong{Version:} simple \item \strong{Requirements:} \itemize{ \item The response variable should be in radians and represent the angular error relative to the target diff --git a/man/bmm.Rd b/man/bmm.Rd index 0bde6b2c..decb4252 100644 --- a/man/bmm.Rd +++ b/man/bmm.Rd @@ -13,6 +13,9 @@ bmm( sort_data = getOption("bmm.sort_data", "check"), silent = getOption("bmm.silent", 1), backend = getOption("brms.backend", NULL), + file = NULL, + file_compress = TRUE, + file_refit = getOption("bmm.file_refit", FALSE), ... ) @@ -65,6 +68,21 @@ additional progress bars.} the cmdstanr package is installed, otherwise "rstan" will be used. You can set the default backend using global \code{options(brms.backend = "rstan"/"cmdstanr")}} +\item{file}{Either \code{NULL} or a character string. If a string, the fitted +model object is saved via \link{saveRDS} in a file named after the string. The +\verb{.rds extension is added automatically. If the file already exists, }bmm\verb{will load and return the saved model object. Unless you specify the}file_refit\verb{argument as well, the existing files won't be overwritten, you have to manually remove the file in order to refit and save the model under an existing file name. The file name is stored in the}bmmfit` object for +later usage. If the directory of the file does not exist, it will be created.} + +\item{file_compress}{Logical or a character string, specifying one of the +compression algorithms supported by \link{saveRDS} when saving +the fitted model object.} + +\item{file_refit}{Logical. Modifies when the fit stored via the \code{file} argument is +re-used. Can be set globally for the current \R session via the +\code{"bmm.file_refit"} option (see \link{options}). If \code{TRUE} (the default), the +model is re-used if the file exists. If \code{FALSE}, the model is re-fitted. Note +that unlike \code{brms}, there is no "on_change" option} + \item{...}{Further arguments passed to \code{\link[brms:brm]{brms::brm()}} or Stan. See the description of \code{\link[brms:brm]{brms::brm()}} for more details} } diff --git a/man/bmm_options.Rd b/man/bmm_options.Rd index e9e5640e..cce0aed1 100644 --- a/man/bmm_options.Rd +++ b/man/bmm_options.Rd @@ -10,6 +10,7 @@ bmm_options( default_priors, silent, color_summary, + file_refit, reset_options = FALSE ) } @@ -36,6 +37,9 @@ still printed. \strong{Default: 1}} \item{color_summary}{logical. If TRUE, the summary of the model will be printed in color. \strong{Default: TRUE}} +\item{file_refit}{logical. If TRUE, bmm() will refit the model even if the +file argument is specified. \strong{Default: FALSE}} + \item{reset_options}{logical. If TRUE, the options will be reset to their default values \strong{Default: FALSE}} } diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 801f04b5..511f4b9a 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -52,3 +52,56 @@ test_that("bmm_options works", { options(op) expect_equal(getOption('bmm.sort_data'), TRUE) }) + + +test_that("check_rds_file works", { + good_files <- list('a.rds', 'abc/a.rds', 'a', 'abc/a', 'a.M') + bad_files <- list(1, mean, c('a','b'), TRUE) + + for (f in good_files) { + expect_silent(res <- check_rds_file(f)) + expect_equal(fs::path_ext(res), 'rds') + } + + for (f in bad_files) { + expect_error(check_rds_file(f)) + } + + expect_null(check_rds_file(NULL)) +}) + +test_that("read_bmmfit works", { + mock_fit <- bmm(bmf(c~1, kappa ~ 1), oberauer_lin_2017, sdm('dev_rad'), + backend = "mock", mock_fit = 1, rename = F) + file <- tempfile() + mock_fit$file <- paste0(file, '.rds') + saveRDS(mock_fit, paste0(file, '.rds')) + expect_equal(read_bmmfit(file, FALSE), mock_fit, ignore_function_env = TRUE, + ignore_formula_env = TRUE) + + x = 1 + saveRDS(x, paste0(file, '.rds')) + expect_error(read_bmmfit(file, FALSE), "not of class 'bmmfit'") +}) + +test_that("save_bmmfit works", { + file <- tempfile() + mock_fit <- bmm(bmf(c~1, kappa ~ 1), oberauer_lin_2017, sdm('dev_rad'), + backend = "mock", mock_fit = 1, rename = F, + file = file) + rds_file <- paste0(file, '.rds') + expect_true(file.exists(rds_file)) + expect_equal(readRDS(rds_file), mock_fit, ignore_function_env = TRUE, + ignore_formula_env = TRUE) + + mock_fit2 <- bmm(bmf(c~1, kappa ~ 1), oberauer_lin_2017, sdm('dev_rad'), + backend = "mock", mock_fit = 2, rename = F, + file = file) + expect_equal(mock_fit, mock_fit2) + + # they should not be the same if file_refit = TRUE + mock_fit3 <- bmm(bmf(c~1, kappa ~ 1), oberauer_lin_2017, sdm('dev_rad'), + backend = "mock", mock_fit = 3, rename = F, + file = file, file_refit = TRUE) + expect_error(expect_equal(mock_fit, mock_fit3)) +})