diff --git a/DESCRIPTION b/DESCRIPTION index e0e2b682..7aa6589b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: bmm Title: Easy and Accesible Bayesian Measurement Models using 'brms' -Version: 0.4.3.9000 +Version: 0.4.4.9000 Authors@R: c( person("Vencislav", "Popov", , "vencislav.popov@gmail.com", role = c("aut", "cre", "cph")), person("Gidon", "Frischkorn", , "gidon.frischkorn@psychologie.uzh.ch", role = c("aut", "cph")), @@ -43,8 +43,7 @@ Imports: stats, matrixStats, crayon, - methods, - assertthat + methods URL: https://github.com/venpopov/bmm, https://venpopov.github.io/bmm/ BugReports: https://github.com/venpopov/bmm/issues Additional_repositories: diff --git a/NAMESPACE b/NAMESPACE index eed889e8..15d59be1 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -36,6 +36,11 @@ S3method(postprocess_brm,default) S3method(postprocess_brm,sdmSimple) S3method(print,bmmsummary) S3method(print,message) +S3method(reset_env,bmmfit) +S3method(reset_env,bmmformula) +S3method(reset_env,brmsfamily) +S3method(reset_env,brmsformula) +S3method(reset_env,formula) S3method(revert_postprocess_brm,default) S3method(revert_postprocess_brm,sdmSimple) S3method(rhs_vars,bmmformula) @@ -84,6 +89,7 @@ export(qmixture3p) export(qsdm) export(rIMM) export(rad2deg) +export(restructure_bmm) export(revert_postprocess_brm) export(rmixture2p) export(rmixture3p) @@ -96,8 +102,8 @@ export(supported_models) export(use_model_template) export(wrap) import(stats) -importFrom(assertthat,assert_that) importFrom(brms,stancode) importFrom(brms,standata) importFrom(glue,glue) importFrom(magrittr,"%>%") +importFrom(utils,packageVersion) diff --git a/NEWS.md b/NEWS.md index 7460ff16..709d7673 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,7 @@ * add a custom summary() method for bmm models (#144) * add a global options bmm.summary_backend to control the backend used for the summary() method (choices are "bmm" and "brms") * deprecate get_model_prior(), get_stancode() and get_standata(). These functions will be removed in future versions of the package. Due to [recent changes](https://github.com/paul-buerkner/brms/pull/1604) in *brms* version 2.20.14, you can now use the *brms* functions `default_prior`, `stancode` and `standata` directly with *bmm* models (alternatively, their older aliases, "get_prior", "make_stancode", "make_standata"). +* function restructure() now allows to apply methods introduced in newer bmm versions to bmmfit objects created by older bmm versions # bmm 0.4.0 diff --git a/R/helpers-model.R b/R/helpers-model.R index 72987b94..bf9db383 100644 --- a/R/helpers-model.R +++ b/R/helpers-model.R @@ -427,36 +427,37 @@ use_model_template <- function(model_name, "# ?postprocess_brm for details\n\n") - model_object <- glue::glue(".model_<> <- function(resp_var1 = NULL, required_arg1 = NULL, required_arg2 = NULL, ...) {\n", - " out <- list(\n", - " resp_vars = nlist(resp_var1),\n", - " other_vars = nlist(required_arg1, required_arg2),\n", - " info = list(\n", - " domain = '',\n", - " task = '',\n", - " name = '',\n", - " citation = '',\n", - " version = '',\n", - " requirements = '',\n", - " parameters = list(),\n", - " fixed_parameters = list()\n", - " ),\n", - " void_mu = FALSE\n", - " )\n", - " class(out) <- c('bmmmodel', '<>')\n", - " out\n", - "}\n\n", - .open = "<<", .close = ">>") + model_object <- glue(".model_<> <- function(resp_var1 = NULL, required_arg1 = NULL, required_arg2 = NULL, links = NULL, ...) {\n", + " out <- list(\n", + " resp_vars = nlist(resp_var1),\n", + " other_vars = nlist(required_arg1, required_arg2),\n", + " domain = '',\n", + " task = '',\n", + " name = '',\n", + " citation = '',\n", + " version = '',\n", + " requirements = '',\n", + " parameters = list(),\n", + " links = list(),\n", + " fixed_parameters = list()\n", + " void_mu = FALSE\n", + " )\n", + " class(out) <- c('bmmmodel', '<>')\n", + " out$links[names(links)] <- links\n", + " out\n", + "}\n\n", + .open = "<<", .close = ">>") user_facing_alias <- glue::glue("# user facing alias\n", "# information in the title and details sections will be filled in\n", "# automatically based on the information in the .model_<>()$info\n \n", "#' @title `r .model_<>()$name`\n", "#' @name Model Name", - "#' @details `r model_info(model_<>())`\n", + "#' @details `r model_info(.model_<>())`\n", "#' @param resp_var1 A description of the response variable\n", "#' @param required_arg1 A description of the required argument\n", "#' @param required_arg2 A description of the required argument\n", + "#' @param links A list of links for the parameters.", "#' @param ... used internally for testing, ignore it\n", "#' @return An object of class `bmmmodel`\n", "#' @export\n", @@ -464,10 +465,10 @@ use_model_template <- function(model_name, "#' \\dontrun{\n", "#' # put a full example here (see 'R/bmm_model_mixture3p.R' for an example)\n", "#' }\n", - "<> <- function(resp_var1, required_arg1, required_arg2, ...) {\n", + "<> <- function(resp_var1, required_arg1, required_arg2, links = NULL, ...) {\n", " stop_missing_args()\n", " .model_<>(resp_var1 = resp_var1, required_arg1 = required_arg1,", - " required_arg2 = required_arg2, ...)\n", + " required_arg2 = required_arg2, links = links, ...)\n", "}\n\n", .open = "<<", .close = ">>") @@ -479,8 +480,7 @@ use_model_template <- function(model_name, " # check the data (required)\n\n\n", " # compute any necessary transformations (optional)\n\n", " # save some variables as attributes of the data for later use (optional)\n\n", - " data = NextMethod('check_data')\n\n", - " return(data)\n", + " NextMethod('check_data')\n\n", "}\n\n", .open = "<<", .close = ">>") @@ -505,7 +505,7 @@ use_model_template <- function(model_name, " brms_formula <- brms_formula + brms::lf(pform)\n", " }\n", " }\n\n", - " return(brms_formula)\n", + " brms_formula\n", "}\n\n", .open = "<<", .close = ">>") @@ -551,9 +551,9 @@ use_model_template <- function(model_name, } if (custom_family) { - out_template <- " out <- nlist(formula, data, family, prior, stanvars)\n" + out_template <- " nlist(formula, data, family, prior, stanvars)\n" } else { - out_template <- " out <- nlist(formula, data, family, prior)\n" + out_template <- " nlist(formula, data, family, prior)\n" } @@ -574,14 +574,13 @@ use_model_template <- function(model_name, " prior <- NULL\n\n", " # return the list\n", out_template, - " return(out)\n", "}\n\n", .open = "<<", .close = ">>") postprocess_brm_method <- glue::glue("#' @export\n", "postprocess_brm.<> <- function(model, fit) {\n", " # any required postprocessing (if none, delete this section)\n\n", - " return(fit)\n", + " fit\n", "}\n\n", .open = "<<", .close = ">>") diff --git a/R/helpers-postprocess.R b/R/helpers-postprocess.R index bfd37bd5..f316ff93 100644 --- a/R/helpers-postprocess.R +++ b/R/helpers-postprocess.R @@ -25,14 +25,12 @@ postprocess_brm <- function(model, fit, ...) { postprocess_brm.bmmmodel <- function(model, fit, ...) { dots <- list(...) class(fit) <- c('bmmfit','brmsfit') - fit$bmm$fit_args <- dots$fit_args fit$version$bmm <- utils::packageVersion('bmm') - fit$bmm$model <- model - fit$bmm$user_formula <- dots$user_formula - fit$bmm$configure_opts <- dots$configure_opts + fit$bmm <- nlist(model, user_formula = dots$user_formula, configure_opts = dots$configure_opts) attr(fit$data, 'data_name') <- attr(dots$fit_args$data, 'data_name') - - NextMethod('postprocess_brm') + fit <- NextMethod('postprocess_brm') + # clean up environments stored in the fit object + reset_env(fit) } #' @export diff --git a/R/helpers-prior.R b/R/helpers-prior.R index 236bbbe9..40ca6d7a 100644 --- a/R/helpers-prior.R +++ b/R/helpers-prior.R @@ -25,7 +25,7 @@ #' #' @name get_model_prior #' -#' @seealso [supported_models()], \code{\link[brms:get_prior]{brms::get_prior()}}. +#' @seealso [supported_models()], \code{\link[brms:get_prior]{brms::get_prior()}}. #' #' @keywords extract_info #' @@ -44,10 +44,13 @@ #' } #' @export get_model_prior <- function(object, data, model, formula = object, ...) { - if (utils::packageVersion('brms') >= "2.20.14") { - message("get_model_prior is deprecated. Please use get_prior() or default_prior()") - } else { - message("get_model_prior is deprecated. Please use get_prior() instead.") + fcall <- as.character(match.call()[1]) + if (fcall == "get_model_prior") { + if (utils::packageVersion('brms') >= "2.20.14") { + message("get_model_prior is deprecated. Please use get_prior() or default_prior()") + } else { + message("get_model_prior is deprecated. Please use get_prior() instead.") + } } if (missing(object) && !missing(formula)) { warning2("The 'formula' argument is deprecated for consistency with brms (>= 2.20.14).", diff --git a/R/restructure.R b/R/restructure.R index e6d6cca5..1d4ad72e 100644 --- a/R/restructure.R +++ b/R/restructure.R @@ -1,14 +1,35 @@ -#' @importFrom assertthat assert_that -restructure.bmm <- function(x) { - assert_that(is_bmmfit(x) | !is.null(x$version$bmm), msg = "Please provide a bmmfit object") +#' Restructure Old \code{bmmfit} Objects +#' +#' Restructure old \code{bmmfit} objects to work with +#' the latest \pkg{bmm} version. This function is called +#' internally when applying post-processing methods. +#' +#' @param x An object of class \code{bmmfit}. +#' @param ... Currently ignored. +#' +#' @return A \code{bmmfit} object compatible with the latest version +#' of \pkg{bmm} and \pkg{brms}. +#' @keywords transform +#' @export +#' @importFrom utils packageVersion +restructure_bmm <- function(x, ...) { version <- x$version$bmm if (is.null(version)) { - version <- as.package_version('0.1.1') + version <- as.package_version('0.2.1') + x$version$bmm <- version + } + if (!inherits(x, 'bmmfit')) { + class(x) <- c('bmmfit', class(x)) } - current_version <- utils::packageVersion('bmm') + current_version <- packageVersion('bmm') restr_version <- restructure_version.bmm(x) if (restr_version >= current_version) { + if (packageVersion("brms") >= "2.20.15") { + x <- NextMethod('restructure') + } else { + x <- brms::restructure(x) + } return(x) } @@ -25,8 +46,17 @@ restructure.bmm <- function(x) { x$bmm$user_formula <- assign_nl(x$bmm$user_formula) } + if (restr_version < "0.4.4") { + x$bmm$fit_args <- NULL + } + x$version$bmm_restructure <- current_version - brms::restructure(x) + if (packageVersion("brms") >= "2.20.15") { + x <- NextMethod('restructure') + } else { + x <- brms::restructure(x) + } + x } restructure_version.bmm <- function(x) { @@ -56,6 +86,26 @@ add_links.bmmmodel <- function(x) { } add_bmm_info <- function(x) { - # TODO: + env <- x$family$env + if (is.null(env)) { + stop2("Unable to restructure the object for use with the latest version of bmm. Please refit.") + } + pforms <- env$formula$pforms + names(pforms) <- NULL + user_formula <- brms::do_call("bmf", pforms) + model = env$model + model$resp_vars <- list(resp_err = env$formula$resp) + model$other_vars <- list() + if (inherits(model, 'sdmSimple')) { + model$info$parameters$mu <- glue('Location parameter of the SDM distribution \\ + (in radians; by default fixed internally to 0)') + } else { + model$info$parameters$mu1 = glue( + "Location parameter of the von Mises distribution for memory responses \\ + (in radians). Fixed internally to 0 by default." + ) + } + + x$bmm <- nlist(model, user_formula) x } diff --git a/R/summary.R b/R/summary.R index d467e90f..e34d1516 100644 --- a/R/summary.R +++ b/R/summary.R @@ -8,7 +8,11 @@ #' options(bmm.color_summary = FALSE) or bmm_options(color_summary = FALSE) #' @export summary.bmmfit <- function(object, priors = FALSE, prob = 0.95, robust = FALSE, mc_se = FALSE, ..., backend = 'bmm') { - object <- restructure.bmm(object) + if (packageVersion('brms') < '2.20.15') { + object <- restructure_bmm(object) + } else { + object <- brms::restructure(object) + } backend <- match.arg(backend, c('bmm', 'brms')) # get summary object from brms, since it contains a lot of necessary information: @@ -20,7 +24,6 @@ summary.bmmfit <- function(object, priors = FALSE, prob = 0.95, robust = FALSE, out <- rename_mu_smry(out, get_mu_pars(object)) # get the bmm specific information - bmmargs <- object$bmm$fit_args bmmmodel <- object$bmm$model bmmform <- object$bmm$user_formula diff --git a/R/update.R b/R/update.R index 9651b73d..4984c4a1 100644 --- a/R/update.R +++ b/R/update.R @@ -32,9 +32,12 @@ update.bmmfit <- function(object, formula., newdata = NULL, recompile = NULL, .. stop2("You cannot update with a different model.\n", "If you want to use a different model, please use `fit_model()` instead.") } - object <- restructure.bmm(object) + if (packageVersion('brms') < '2.20.15') { + object <- restructure_bmm(object) + } else { + object <- brms::restructure(object) + } - fit_args <- object$bmm$fit_args model <- object$bmm$model old_user_formula <- object$bmm$user_formula olddata <- object$data @@ -75,7 +78,7 @@ update.bmmfit <- function(object, formula., newdata = NULL, recompile = NULL, .. new_fit_args <- combine_args(nlist(config_args, dots)) # construct the new formula and data only if they have changed - if (!identical(new_fit_args$formula, fit_args$formula)) { + if (!identical(new_fit_args$formula, object$formula)) { formula. <- new_fit_args$formula } if (!identical(new_fit_args$data, olddata)) { diff --git a/R/utils.R b/R/utils.R index cf9aba2b..cc969205 100644 --- a/R/utils.R +++ b/R/utils.R @@ -519,3 +519,64 @@ tryCatch2 <- function(expr, capture = FALSE) { } +# resets the environments stored within an objects +reset_env <- function(object, env = NULL, ...) { + UseMethod("reset_env") +} + +#' @export +reset_env.bmmfit <- function(object, env = NULL, ...) { + if (is.null(env)) { + env <- globalenv() + } + object$formula <- reset_env(object$formula, env) + object$family <- reset_env(object$family, env) + object$bmm$user_formula <- reset_env(object$bmm$user_formula, env) + object +} + +#' @export +reset_env.bmmformula <- function(object, env = NULL, ...) { + if (is.null(env)) { + env <- globalenv() + } + for (par in names(object)) { + object[[par]] <- reset_env(object[[par]], env) + } + object +} + +#' @export +reset_env.brmsformula <- function(object, env = NULL, ...) { + if (is.null(env)) { + env <- globalenv() + } + object$formula <- reset_env(object$formula, env) + for (par in names(object$pforms)) { + object$pforms[[par]] <- reset_env(object$pforms[[par]], env) + } + if (!is.null(object$family)) { + object$family <- reset_env(object$family, env) + } + object +} + +#' @export +reset_env.formula <- function(object, env = NULL, ...) { + if (is.null(env)) { + env <- globalenv() + } + environment(object) <- env + object +} + +#' @export +reset_env.brmsfamily <- function(object, env = NULL, ...) { + if (is.null(env)) { + env <- globalenv() + } + if (!is.null(object$env)) { + object$env <- env + } + object +} diff --git a/R/zzz.R b/R/zzz.R index 8e42749d..78cb5ae5 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -1,8 +1,10 @@ .onLoad <- function(libname, pkgname) { suppressMessages(bmm_options(reset_options = TRUE)) - if (utils::packageVersion('brms') >= '2.20.14') { + if (utils::packageVersion('brms') >= '2.20.15') { registerS3method("default_prior", "bmmformula", get_model_prior, envir = asNamespace("brms")) + registerS3method("restructure", "bmmfit", restructure_bmm, + envir = asNamespace("brms")) } } diff --git a/_pkgdown.yml b/_pkgdown.yml index 162273ed..a0d28db7 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -50,7 +50,7 @@ reference: desc: "Functions for special distributions" - contents: - has_keyword("distribution") - - title: "Data and parameter transformations" + - title: "Data, model and parameter transformations" desc: "Utility functions for transforming data and parameters" - contents: - has_keyword("transform") diff --git a/man/restructure_bmm.Rd b/man/restructure_bmm.Rd new file mode 100644 index 00000000..edf21f42 --- /dev/null +++ b/man/restructure_bmm.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/restructure.R +\name{restructure_bmm} +\alias{restructure_bmm} +\title{Restructure Old \code{bmmfit} Objects} +\usage{ +restructure_bmm(x, ...) +} +\arguments{ +\item{x}{An object of class \code{bmmfit}.} + +\item{...}{Currently ignored.} +} +\value{ +A \code{bmmfit} object compatible with the latest version +of \pkg{bmm} and \pkg{brms}. +} +\description{ +Restructure old \code{bmmfit} objects to work with +the latest \pkg{bmm} version. This function is called +internally when applying post-processing methods. +} +\keyword{transform} diff --git a/tests/testthat/test-fit_model.R b/tests/testthat/test-fit_model.R index 1e70182b..347493a5 100644 --- a/tests/testthat/test-fit_model.R +++ b/tests/testthat/test-fit_model.R @@ -13,8 +13,7 @@ test_that('Available mock models run without errors',{ f <- bmmformula(kappa ~ 1, thetat ~ 1) mock_fit <- fit_model(f, dat, mixture2p(resp_err = "resp_err"), backend="mock", mock_fit=1, rename=FALSE) expect_equal(mock_fit$fit, 1) - expect_type(mock_fit$bmm$fit_args, "list") - expect_equal(names(mock_fit$bmm$fit_args[1:4]), c("formula", "data", "family", "prior")) + expect_type(mock_fit$bmm, "list") # three-parameter model mock fit f <- bmmformula(kappa ~ 1, thetat ~ 1, thetant ~ 1) @@ -22,8 +21,7 @@ test_that('Available mock models run without errors',{ nt_features = paste0('Item',2:3,'_rel')), backend="mock", mock_fit=1, rename=FALSE) expect_equal(mock_fit$fit, 1) - expect_type(mock_fit$bmm$fit_args, "list") - expect_equal(names(mock_fit$bmm$fit_args[1:4]), c("formula", "data", "family", "prior")) + expect_type(mock_fit$bmm, "list") # IMMabc model mock fit f <- bmmformula(kappa ~ 1, c ~ 1, a ~ 1) @@ -31,23 +29,20 @@ test_that('Available mock models run without errors',{ nt_features = paste0('Item',2:3,'_rel')), backend="mock", mock_fit=1, rename=FALSE) expect_equal(mock_fit$fit, 1) - expect_type(mock_fit$bmm$fit_args, "list") - expect_equal(names(mock_fit$bmm$fit_args[1:4]), c("formula", "data", "family", "prior")) + expect_type(mock_fit$bmm, "list") # IMMbsc model mock fit f <- bmmformula(kappa ~ 1, c ~ 1, s ~ 1) mock_fit <- fit_model(f, dat, IMMbsc(resp_err = "resp_err", setsize=3, nt_features = paste0('Item',2:3,'_rel'), nt_distances=paste0('spaD',2:3)), backend="mock", mock_fit=1, rename=FALSE) expect_equal(mock_fit$fit, 1) - expect_type(mock_fit$bmm$fit_args, "list") - expect_equal(names(mock_fit$bmm$fit_args[1:4]), c("formula", "data", "family", "prior")) + expect_type(mock_fit$bmm, "list") # IMMfull model mock fit f <- bmmformula(kappa ~ 1, c ~ 1, a ~ 1, s ~ 1) mock_fit <- fit_model(f, dat, IMMfull(resp_err = "resp_err", setsize=3, nt_features = paste0('Item',2:3,'_rel'), nt_distances=paste0('spaD',2:3)), backend="mock", mock_fit=1, rename=FALSE) expect_equal(mock_fit$fit, 1) - expect_type(mock_fit$bmm$fit_args, "list") - expect_equal(names(mock_fit$bmm$fit_args[1:4]), c("formula", "data", "family", "prior")) + expect_type(mock_fit$bmm, "list") }) test_that('Available models produce expected errors', { diff --git a/tests/testthat/test-restructure.R b/tests/testthat/test-restructure.R index e072324a..caacda23 100644 --- a/tests/testthat/test-restructure.R +++ b/tests/testthat/test-restructure.R @@ -11,14 +11,8 @@ test_that("restructure works", { path <- test_path() file <- file.path(path, "../internal/ref_fits", "20240215_v0.2.1_mixture2p_seed-365_6ae900f5a4.rds") old_fit <- readRDS(file) - # TODO: this should be part of the restructure - old_fit$bmm$model <- structure(list(), class = c("bmmmodel", 'mixture2p')) - new_fit <- restructure.bmm(old_fit) + class(old_fit) <- c("bmmfit", class(old_fit)) + new_fit <- restructure_bmm(old_fit) expect_equal(new_fit$bmm$model$links,.model_mixture2p()$links) }) - -x <- c( '- The response vairable should be in radians and - represent the angular error relative to the target\n - - The non-target features should be in radians and be ') - diff --git a/tests/testthat/test-update.R b/tests/testthat/test-update.R index 3b54e8e3..c80750ac 100644 --- a/tests/testthat/test-update.R +++ b/tests/testthat/test-update.R @@ -6,7 +6,7 @@ test_that('update.bmmfit works', { # formula is replaced up <- update(fit1, formula. = bmf(c ~ 1, kappa ~ 1), testmode = TRUE) expect_true(is(up, "bmmfit")) - expect_equal(up$bmm$fit_args$formula$pforms$c, c ~ 1, ignore_attr = TRUE) + expect_equal(up$bmm$user_formula$c, c ~ 1, ignore_attr = TRUE) # data is replaced, old formula is kept new_data <- data @@ -15,13 +15,12 @@ test_that('update.bmmfit works', { testmode = TRUE) expect_true(is(up, "bmmfit")) expect_equal(attr(up$data, "data_name"), "new_data") - expect_equal(up$bmm$fit_args$formula$pforms$c, c ~ 0 + set_size, ignore_formula_env=T, ignore_attr = TRUE) + expect_equal(up$bmm$user_formula$c, c ~ 0 + set_size, ignore_formula_env=T, ignore_attr = TRUE) # prior is replaced up <- update(fit1, formula. = bmf(c ~ 1, kappa ~ 1), testmode = TRUE, prior = brms::set_prior("normal(0,0.1)", class="Intercept", dpar='kappa')) expect_true(is(up, "bmmfit")) - expect_equal(up$bmm$fit_args$prior$prior[3], "normal(0,0.1)") # refuse to change model expect_error(update(fit1, model = mixture2p(resp_err = "dev_rad")),