Skip to content

Commit

Permalink
Restructure function and internal clean-up (#151)
Browse files Browse the repository at this point in the history
* fix warning about get_model_prior deprecated
* cleanup bmmfit, make restructure into a method
* remove bmm$fit_args
* add functions for resetting formula environments to avoid storing unnecessary objects
  • Loading branch information
venpopov authored Mar 8, 2024
1 parent 9c34fef commit 804e103
Show file tree
Hide file tree
Showing 16 changed files with 216 additions and 80 deletions.
5 changes: 2 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: bmm
Title: Easy and Accesible Bayesian Measurement Models using 'brms'
Version: 0.4.3.9000
Version: 0.4.4.9000
Authors@R: c(
person("Vencislav", "Popov", , "[email protected]", role = c("aut", "cre", "cph")),
person("Gidon", "Frischkorn", , "[email protected]", role = c("aut", "cph")),
Expand Down Expand Up @@ -43,8 +43,7 @@ Imports:
stats,
matrixStats,
crayon,
methods,
assertthat
methods
URL: https://github.com/venpopov/bmm, https://venpopov.github.io/bmm/
BugReports: https://github.com/venpopov/bmm/issues
Additional_repositories:
Expand Down
8 changes: 7 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ S3method(postprocess_brm,default)
S3method(postprocess_brm,sdmSimple)
S3method(print,bmmsummary)
S3method(print,message)
S3method(reset_env,bmmfit)
S3method(reset_env,bmmformula)
S3method(reset_env,brmsfamily)
S3method(reset_env,brmsformula)
S3method(reset_env,formula)
S3method(revert_postprocess_brm,default)
S3method(revert_postprocess_brm,sdmSimple)
S3method(rhs_vars,bmmformula)
Expand Down Expand Up @@ -84,6 +89,7 @@ export(qmixture3p)
export(qsdm)
export(rIMM)
export(rad2deg)
export(restructure_bmm)
export(revert_postprocess_brm)
export(rmixture2p)
export(rmixture3p)
Expand All @@ -96,8 +102,8 @@ export(supported_models)
export(use_model_template)
export(wrap)
import(stats)
importFrom(assertthat,assert_that)
importFrom(brms,stancode)
importFrom(brms,standata)
importFrom(glue,glue)
importFrom(magrittr,"%>%")
importFrom(utils,packageVersion)
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* add a custom summary() method for bmm models (#144)
* add a global options bmm.summary_backend to control the backend used for the summary() method (choices are "bmm" and "brms")
* deprecate get_model_prior(), get_stancode() and get_standata(). These functions will be removed in future versions of the package. Due to [recent changes](https://github.com/paul-buerkner/brms/pull/1604) in *brms* version 2.20.14, you can now use the *brms* functions `default_prior`, `stancode` and `standata` directly with *bmm* models (alternatively, their older aliases, "get_prior", "make_stancode", "make_standata").
* function restructure() now allows to apply methods introduced in newer bmm versions to bmmfit objects created by older bmm versions

# bmm 0.4.0

Expand Down
59 changes: 29 additions & 30 deletions R/helpers-model.R
Original file line number Diff line number Diff line change
Expand Up @@ -427,47 +427,48 @@ use_model_template <- function(model_name,
"# ?postprocess_brm for details\n\n")


model_object <- glue::glue(".model_<<model_name>> <- function(resp_var1 = NULL, required_arg1 = NULL, required_arg2 = NULL, ...) {\n",
" out <- list(\n",
" resp_vars = nlist(resp_var1),\n",
" other_vars = nlist(required_arg1, required_arg2),\n",
" info = list(\n",
" domain = '',\n",
" task = '',\n",
" name = '',\n",
" citation = '',\n",
" version = '',\n",
" requirements = '',\n",
" parameters = list(),\n",
" fixed_parameters = list()\n",
" ),\n",
" void_mu = FALSE\n",
" )\n",
" class(out) <- c('bmmmodel', '<<model_name>>')\n",
" out\n",
"}\n\n",
.open = "<<", .close = ">>")
model_object <- glue(".model_<<model_name>> <- function(resp_var1 = NULL, required_arg1 = NULL, required_arg2 = NULL, links = NULL, ...) {\n",
" out <- list(\n",
" resp_vars = nlist(resp_var1),\n",
" other_vars = nlist(required_arg1, required_arg2),\n",
" domain = '',\n",
" task = '',\n",
" name = '',\n",
" citation = '',\n",
" version = '',\n",
" requirements = '',\n",
" parameters = list(),\n",
" links = list(),\n",
" fixed_parameters = list()\n",
" void_mu = FALSE\n",
" )\n",
" class(out) <- c('bmmmodel', '<<model_name>>')\n",
" out$links[names(links)] <- links\n",
" out\n",
"}\n\n",
.open = "<<", .close = ">>")

user_facing_alias <- glue::glue("# user facing alias\n",
"# information in the title and details sections will be filled in\n",
"# automatically based on the information in the .model_<<model_name>>()$info\n \n",
"#' @title `r .model_<<model_name>>()$name`\n",
"#' @name Model Name",
"#' @details `r model_info(model_<<model_name>>())`\n",
"#' @details `r model_info(.model_<<model_name>>())`\n",
"#' @param resp_var1 A description of the response variable\n",
"#' @param required_arg1 A description of the required argument\n",
"#' @param required_arg2 A description of the required argument\n",
"#' @param links A list of links for the parameters.",
"#' @param ... used internally for testing, ignore it\n",
"#' @return An object of class `bmmmodel`\n",
"#' @export\n",
"#' @examples\n",
"#' \\dontrun{\n",
"#' # put a full example here (see 'R/bmm_model_mixture3p.R' for an example)\n",
"#' }\n",
"<<model_name>> <- function(resp_var1, required_arg1, required_arg2, ...) {\n",
"<<model_name>> <- function(resp_var1, required_arg1, required_arg2, links = NULL, ...) {\n",
" stop_missing_args()\n",
" .model_<<model_name>>(resp_var1 = resp_var1, required_arg1 = required_arg1,",
" required_arg2 = required_arg2, ...)\n",
" required_arg2 = required_arg2, links = links, ...)\n",
"}\n\n",
.open = "<<", .close = ">>")

Expand All @@ -479,8 +480,7 @@ use_model_template <- function(model_name,
" # check the data (required)\n\n\n",
" # compute any necessary transformations (optional)\n\n",
" # save some variables as attributes of the data for later use (optional)\n\n",
" data = NextMethod('check_data')\n\n",
" return(data)\n",
" NextMethod('check_data')\n\n",
"}\n\n",
.open = "<<", .close = ">>")

Expand All @@ -505,7 +505,7 @@ use_model_template <- function(model_name,
" brms_formula <- brms_formula + brms::lf(pform)\n",
" }\n",
" }\n\n",
" return(brms_formula)\n",
" brms_formula\n",
"}\n\n",
.open = "<<", .close = ">>")

Expand Down Expand Up @@ -551,9 +551,9 @@ use_model_template <- function(model_name,
}

if (custom_family) {
out_template <- " out <- nlist(formula, data, family, prior, stanvars)\n"
out_template <- " nlist(formula, data, family, prior, stanvars)\n"
} else {
out_template <- " out <- nlist(formula, data, family, prior)\n"
out_template <- " nlist(formula, data, family, prior)\n"
}


Expand All @@ -574,14 +574,13 @@ use_model_template <- function(model_name,
" prior <- NULL\n\n",
" # return the list\n",
out_template,
" return(out)\n",
"}\n\n",
.open = "<<", .close = ">>")

postprocess_brm_method <- glue::glue("#' @export\n",
"postprocess_brm.<<model_name>> <- function(model, fit) {\n",
" # any required postprocessing (if none, delete this section)\n\n",
" return(fit)\n",
" fit\n",
"}\n\n",
.open = "<<", .close = ">>")

Expand Down
10 changes: 4 additions & 6 deletions R/helpers-postprocess.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@ postprocess_brm <- function(model, fit, ...) {
postprocess_brm.bmmmodel <- function(model, fit, ...) {
dots <- list(...)
class(fit) <- c('bmmfit','brmsfit')
fit$bmm$fit_args <- dots$fit_args
fit$version$bmm <- utils::packageVersion('bmm')
fit$bmm$model <- model
fit$bmm$user_formula <- dots$user_formula
fit$bmm$configure_opts <- dots$configure_opts
fit$bmm <- nlist(model, user_formula = dots$user_formula, configure_opts = dots$configure_opts)
attr(fit$data, 'data_name') <- attr(dots$fit_args$data, 'data_name')

NextMethod('postprocess_brm')
fit <- NextMethod('postprocess_brm')
# clean up environments stored in the fit object
reset_env(fit)
}

#' @export
Expand Down
13 changes: 8 additions & 5 deletions R/helpers-prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#'
#' @name get_model_prior
#'
#' @seealso [supported_models()], \code{\link[brms:get_prior]{brms::get_prior()}}.
#' @seealso [supported_models()], \code{\link[brms:get_prior]{brms::get_prior()}}.
#'
#' @keywords extract_info
#'
Expand All @@ -44,10 +44,13 @@
#' }
#' @export
get_model_prior <- function(object, data, model, formula = object, ...) {
if (utils::packageVersion('brms') >= "2.20.14") {
message("get_model_prior is deprecated. Please use get_prior() or default_prior()")
} else {
message("get_model_prior is deprecated. Please use get_prior() instead.")
fcall <- as.character(match.call()[1])
if (fcall == "get_model_prior") {
if (utils::packageVersion('brms') >= "2.20.14") {
message("get_model_prior is deprecated. Please use get_prior() or default_prior()")
} else {
message("get_model_prior is deprecated. Please use get_prior() instead.")
}
}
if (missing(object) && !missing(formula)) {
warning2("The 'formula' argument is deprecated for consistency with brms (>= 2.20.14).",
Expand Down
64 changes: 57 additions & 7 deletions R/restructure.R
Original file line number Diff line number Diff line change
@@ -1,14 +1,35 @@
#' @importFrom assertthat assert_that
restructure.bmm <- function(x) {
assert_that(is_bmmfit(x) | !is.null(x$version$bmm), msg = "Please provide a bmmfit object")
#' Restructure Old \code{bmmfit} Objects
#'
#' Restructure old \code{bmmfit} objects to work with
#' the latest \pkg{bmm} version. This function is called
#' internally when applying post-processing methods.
#'
#' @param x An object of class \code{bmmfit}.
#' @param ... Currently ignored.
#'
#' @return A \code{bmmfit} object compatible with the latest version
#' of \pkg{bmm} and \pkg{brms}.
#' @keywords transform
#' @export
#' @importFrom utils packageVersion
restructure_bmm <- function(x, ...) {
version <- x$version$bmm
if (is.null(version)) {
version <- as.package_version('0.1.1')
version <- as.package_version('0.2.1')
x$version$bmm <- version
}
if (!inherits(x, 'bmmfit')) {
class(x) <- c('bmmfit', class(x))
}
current_version <- utils::packageVersion('bmm')
current_version <- packageVersion('bmm')
restr_version <- restructure_version.bmm(x)

if (restr_version >= current_version) {
if (packageVersion("brms") >= "2.20.15") {
x <- NextMethod('restructure')
} else {
x <- brms::restructure(x)
}
return(x)
}

Expand All @@ -25,8 +46,17 @@ restructure.bmm <- function(x) {
x$bmm$user_formula <- assign_nl(x$bmm$user_formula)
}

if (restr_version < "0.4.4") {
x$bmm$fit_args <- NULL
}

x$version$bmm_restructure <- current_version
brms::restructure(x)
if (packageVersion("brms") >= "2.20.15") {
x <- NextMethod('restructure')
} else {
x <- brms::restructure(x)
}
x
}

restructure_version.bmm <- function(x) {
Expand Down Expand Up @@ -56,6 +86,26 @@ add_links.bmmmodel <- function(x) {
}

add_bmm_info <- function(x) {
# TODO:
env <- x$family$env
if (is.null(env)) {
stop2("Unable to restructure the object for use with the latest version of bmm. Please refit.")
}
pforms <- env$formula$pforms
names(pforms) <- NULL
user_formula <- brms::do_call("bmf", pforms)
model = env$model
model$resp_vars <- list(resp_err = env$formula$resp)
model$other_vars <- list()
if (inherits(model, 'sdmSimple')) {
model$info$parameters$mu <- glue('Location parameter of the SDM distribution \\
(in radians; by default fixed internally to 0)')
} else {
model$info$parameters$mu1 = glue(
"Location parameter of the von Mises distribution for memory responses \\
(in radians). Fixed internally to 0 by default."
)
}

x$bmm <- nlist(model, user_formula)
x
}
7 changes: 5 additions & 2 deletions R/summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
#' options(bmm.color_summary = FALSE) or bmm_options(color_summary = FALSE)
#' @export
summary.bmmfit <- function(object, priors = FALSE, prob = 0.95, robust = FALSE, mc_se = FALSE, ..., backend = 'bmm') {
object <- restructure.bmm(object)
if (packageVersion('brms') < '2.20.15') {
object <- restructure_bmm(object)
} else {
object <- brms::restructure(object)
}
backend <- match.arg(backend, c('bmm', 'brms'))

# get summary object from brms, since it contains a lot of necessary information:
Expand All @@ -20,7 +24,6 @@ summary.bmmfit <- function(object, priors = FALSE, prob = 0.95, robust = FALSE,
out <- rename_mu_smry(out, get_mu_pars(object))

# get the bmm specific information
bmmargs <- object$bmm$fit_args
bmmmodel <- object$bmm$model
bmmform <- object$bmm$user_formula

Expand Down
9 changes: 6 additions & 3 deletions R/update.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ update.bmmfit <- function(object, formula., newdata = NULL, recompile = NULL, ..
stop2("You cannot update with a different model.\n",
"If you want to use a different model, please use `fit_model()` instead.")
}
object <- restructure.bmm(object)
if (packageVersion('brms') < '2.20.15') {
object <- restructure_bmm(object)
} else {
object <- brms::restructure(object)
}

fit_args <- object$bmm$fit_args
model <- object$bmm$model
old_user_formula <- object$bmm$user_formula
olddata <- object$data
Expand Down Expand Up @@ -75,7 +78,7 @@ update.bmmfit <- function(object, formula., newdata = NULL, recompile = NULL, ..
new_fit_args <- combine_args(nlist(config_args, dots))

# construct the new formula and data only if they have changed
if (!identical(new_fit_args$formula, fit_args$formula)) {
if (!identical(new_fit_args$formula, object$formula)) {
formula. <- new_fit_args$formula
}
if (!identical(new_fit_args$data, olddata)) {
Expand Down
Loading

0 comments on commit 804e103

Please sign in to comment.