diff --git a/DESCRIPTION b/DESCRIPTION index c6e3b4e8..6e396245 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: bmm Title: Easy and Accesible Bayesian Measurement Models using 'brms' -Version: 0.3.4.9000 +Version: 0.3.5.9000 Authors@R: c( person("Vencislav", "Popov", , "vencislav.popov@gmail.com", role = c("aut", "cre", "cph")), person("Gidon", "Frischkorn", , "gidon.frischkorn@psychologie.uzh.ch", role = c("aut", "cph"))) diff --git a/NAMESPACE b/NAMESPACE index 4c4b7e17..9ee3baa7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -12,6 +12,8 @@ S3method(check_data,vwm) S3method(check_formula,bmmmodel) S3method(check_formula,default) S3method(check_formula,nontargets) +S3method(check_model,bmmmodel) +S3method(check_model,default) S3method(configure_model,IMMabc) S3method(configure_model,IMMbsc) S3method(configure_model,IMMfull) @@ -25,6 +27,7 @@ S3method(postprocess_brm,bmmmodel) S3method(postprocess_brm,default) S3method(postprocess_brm,sdmSimple) S3method(postprocess_brm,vwm) +S3method(print,message) S3method(rhs_vars,bmmformula) S3method(rhs_vars,formula) S3method(update,bmmfit) diff --git a/NEWS.md b/NEWS.md index b3264b97..49fde50b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -7,6 +7,7 @@ * add postprocessing methods for sdmSimple to allow for pp_check(), conditional_effects and bridgesampling usage with the model (#30) * add informed default priors for all models. You can always use the `get_model_prior()` function to see the default priors for a model * add a new function `set_default_prior` for developers, which allows them to more easily set default priors on new models regardless of the user-specified formula +* you can now specify variables for models via regular expressions rather than character vectors [#102] ### Bug fixes * fix a bug in the mixture3p and IMM models which caused an error when intercept was not supressed and set size was used as predictor diff --git a/R/bmm_model_IMM.R b/R/bmm_model_IMM.R index babaa50e..bd2151f1 100644 --- a/R/bmm_model_IMM.R +++ b/R/bmm_model_IMM.R @@ -2,7 +2,7 @@ # MODELS #### #############################################################################! -.model_IMMabc <- function(resp_err, nt_features, setsize, ...) { +.model_IMMabc <- function(resp_err, nt_features, setsize, regex = FALSE, ...) { out <- list( resp_vars = nlist(resp_err), other_vars = nlist(nt_features, setsize), @@ -29,11 +29,13 @@ )), void_mu = FALSE ) + attr(out, "regex") <- regex + attr(out, "regex_vars") <- c('nt_features') # variables that can be specified via regular expression class(out) <- c("bmmmodel", "vwm","nontargets","IMMabc") out } -.model_IMMbsc <- function(resp_err, nt_features, nt_distances, setsize, ...) { +.model_IMMbsc <- function(resp_err, nt_features, nt_distances, setsize, regex = FALSE, ...) { out <- list( resp_vars = nlist(resp_err), other_vars = nlist(nt_features, nt_distances, setsize), @@ -60,11 +62,14 @@ )), void_mu = FALSE ) + attr(out, "regex") <- regex + # variables that can be specified via regular expression + attr(out, "regex_vars") <- c('nt_features', 'nt_distances') class(out) <- c("bmmmodel","vwm","nontargets","IMMspatial","IMMbsc") out } -.model_IMMfull <- function(resp_err, nt_features, nt_distances, setsize, ...) { +.model_IMMfull <- function(resp_err, nt_features, nt_distances, setsize, regex = FALSE, ...) { out <- list( resp_vars = nlist(resp_err), other_vars = nlist(nt_features, nt_distances, setsize), @@ -92,6 +97,9 @@ )), void_mu = FALSE ) + attr(out, "regex") <- regex + # variables that can be specified via regular expression + attr(out, "regex_vars") <- c('nt_features', 'nt_distances') class(out) <- c("bmmmodel","vwm","nontargets","IMMspatial","IMMfull") out } @@ -114,20 +122,71 @@ #' #' - b = "Background activation (internally fixed to 0)" #' -#' @param resp_err The name of the variable in the provided dataset containing the -#' response error. The response Error should code the response relative to the to-be-recalled -#' target in radians. You can transform the response error in degrees to radian using the `deg2rad` function. -#' @param nt_features A character vector with the names of the non-target variables. -#' The non_target variables should be in radians and be centered relative to the -#' target. -#' @param nt_distances A vector of names of the columns containing the distances of -#' non-target items to the target item. Only necessary for the `IMMbsc` and `IMMfull` models +#' @param resp_err The name of the variable in the provided dataset containing +#' the response error. The response Error should code the response relative to +#' the to-be-recalled target in radians. You can transform the response error +#' in degrees to radian using the `deg2rad` function. +#' @param nt_features A character vector with the names of the non-target +#' variables. The non_target variables should be in radians and be centered +#' relative to the target. Alternatively, if regex=TRUE, a regular +#' expression can be used to match the non-target feature columns in the +#' dataset. +#' @param nt_distances A vector of names of the columns containing the distances +#' of non-target items to the target item. Alternatively, if regex=TRUE, a regular +#' expression can be used to match the non-target distances columns in the +#' dataset. Only necessary for the `IMMbsc` and `IMMfull` models. #' @param setsize Name of the column containing the set size variable (if #' setsize varies) or a numeric value for the setsize, if the setsize is #' fixed. +#' @param regex Logical. If TRUE, the `nt_features` and `nt_distances` arguments +#' are interpreted as a regular expression to match the non-target feature +#' columns in the dataset. #' @param ... used internally for testing, ignore it #' @return An object of class `bmmmodel` #' @keywords bmmmodel +#' @examples +#' \dontrun{ +#' # load data +#' data <- OberauerLin_2017 +#' +#' # define formula +#' ff <- bmmformula( +#' kappa ~ 0 + set_size, +#' c ~ 0 + set_size, +#' a ~ 0 + set_size, +#' s ~ 0 + set_size +#' ) +#' +#' # specify the full IMM model with explicit column names for non-target features and distances +#' model1 <- IMMfull(resp_err = "dev_rad", +#' nt_features = paste0('col_nt',1:7), +#' nt_distances = paste0('dist_nt',1:7), +#' setsize = 'set_size') +#' +#' # fit the model +#' fit <- fit_model(formula = ff, +#' data = data, +#' model = model1, +#' parallel = T, +#' iter = 500, +#' backend = 'cmdstanr') +#' +#' # alternatively specify the IMM model with a regular expression to match non-target features +#' # this is equivalent to the previous call, but more concise +#' model2 <- IMMfull(resp_err = "dev_rad", +#' nt_features = 'col_nt', +#' nt_distances = 'dist_nt', +#' setsize = 'set_size', +#' regex = TRUE) +#' +#' # fit the model +#' fit <- fit_model(formula = ff, +#' data = data, +#' model = model2, +#' parallel=T, +#' iter = 500, +#' backend='cmdstanr') +#'} #' @export IMMfull <- .model_IMMfull diff --git a/R/bmm_model_mixture3p.R b/R/bmm_model_mixture3p.R index 34f23641..18d19df6 100644 --- a/R/bmm_model_mixture3p.R +++ b/R/bmm_model_mixture3p.R @@ -2,7 +2,7 @@ # MODELS #### #############################################################################! -.model_mixture3p <- function(resp_err, nt_features, setsize, ...) { +.model_mixture3p <- function(resp_err, nt_features, setsize, regex = FALSE, ...) { out <- list( resp_vars = nlist(resp_err), other_vars = nlist(nt_features, setsize), @@ -30,6 +30,8 @@ )), void_mu = FALSE ) + attr(out, "regex") <- regex + attr(out, "regex_vars") <- c('nt_features') # variables that can be specified via regular expression class(out) = c("bmmmodel", "vwm", "nontargets", "mixture3p") out } @@ -43,11 +45,15 @@ #' the to-be-recalled target in radians. You can transform the response error #' in degrees to radians using the `deg2rad` function. #' @param nt_features A character vector with the names of the non-target -#' feature values. The non_target feature values should be in radians and centered -#' relative to the target. +#' feature values. The non_target feature values should be in radians and +#' centered relative to the target. Alternatively, if regex=TRUE, a regular +#' expression can be used to match the non-target feature columns in the +#' dataset. #' @param setsize Name of the column containing the set size variable (if #' setsize varies) or a numeric value for the setsize, if the setsize is #' fixed. +#' @param regex Logical. If TRUE, the `nt_features` argument is interpreted as +#' a regular expression to match the non-target feature columns in the dataset. #' @param ... used internally for testing, ignore it #' @return An object of class `bmmmodel` #' @keywords bmmmodel @@ -69,13 +75,25 @@ #' thetant ~ 1 #' ) #' -#' # specify the 3-parameter model -#' model <- mixture3p(resp_err = "y", nt_features = paste0('nt',1:3,'_loc'), setsize = 4) +#' # specify the 3-parameter model with explicit column names for non-target features +#' model1 <- mixture3p(resp_err = "y", nt_features = paste0('nt',1:3,'_loc'), setsize = 4) #' #' # fit the model #' fit <- fit_model(formula = ff, #' data = dat, -#' model = model, +#' model = model1, +#' parallel=T, +#' iter = 500, +#' backend='cmdstanr') +#' +#' # alternatively specify the 3-parameter model with a regular expression to match non-target features +#' # this is equivalent to the previous call, but more concise +#' model2 <- mixture3p(resp_err = "y", nt_features = "nt.*_loc", setsize = 4, regex = TRUE) +#' +#' # fit the model +#' fit <- fit_model(formula = ff, +#' data = dat, +#' model = model2, #' parallel=T, #' iter = 500, #' backend='cmdstanr') diff --git a/R/fit_model.R b/R/fit_model.R index 1b4364d0..3d53de90 100644 --- a/R/fit_model.R +++ b/R/fit_model.R @@ -89,7 +89,7 @@ fit_model <- function(formula, data, model, parallel = FALSE, chains = 4, opts <- configure_options(nlist(parallel, chains, sort_data, silent)) # check model, formula and data, and transform data if necessary - model <- check_model(model) + model <- check_model(model, data) data <- check_data(model, data, formula) formula <- check_formula(model, data, formula) diff --git a/R/helpers-data.R b/R/helpers-data.R index 775107ee..f03e21ac 100644 --- a/R/helpers-data.R +++ b/R/helpers-data.R @@ -276,7 +276,7 @@ rad2deg <- function(rad){ get_standata <- function(formula, data, model, prior=NULL, ...) { # check model, formula and data, and transform data if necessary - model <- check_model(model) + model <- check_model(model, data) data <- check_data(model, data, formula) formula <- check_formula(model, data, formula) diff --git a/R/helpers-model.R b/R/helpers-model.R index b2e5e358..84b77a74 100644 --- a/R/helpers-model.R +++ b/R/helpers-model.R @@ -104,6 +104,84 @@ configure_model <- function(model, data, formula) { UseMethod("configure_model") } +#############################################################################! +# CHECK_MODEL methods #### +#############################################################################! +#' Checks if the model is supported, and returns the model function +#' @param model the model argument supplied by the user +#' @param data the data argument supplied by the user +#' @return A list generated by a model function of type .model_* +#' @keywords internal, developer +check_model <- function(model, data=NULL) { + UseMethod("check_model") +} + +#' @export +check_model.default <- function(model, data) { + bmm_models <- supported_models(print_call=FALSE) + if (is.function(model)) { + fun_name <- as.character(substitute(model)) + if (fun_name %in% bmm_models) { + stop2("Did you forget to provide the required arguments to the model function?\n", + "See ?", fun_name, " for details on properly specifying the model argument\n\n") + } + } + if(!is_supported_bmmmodel(model)) { + stop2("You provided an object of class `", class(model), "` to the model argument. ", + "The model argument should be a `bmmmodel` function.\n", + "You can see the list of supported models by running `supported_models()`\n\n", + supported_models()) + } + model +} + +#' @export +check_model.bmmmodel <- function(model, data = NULL) { + model <- replace_regex_variables(model, data) + NextMethod("check_model") +} + + + +# check if the user has provided a regular expression for any model variables and +# replace the regular expression with the actual variables +replace_regex_variables <- function(model, data) { + regex <- isTRUE(attr(model, "regex")) + regex_vars <- attr(model, "regex_vars") + + # check if the regex transformation has already been applied (e.g., if + # updating a previously fit model) + regex_applied <- isTRUE(attr(model, "regex_applied")) + + if (!regex_applied) { + data_cols <- names(data) + # save original user-provided variables + user_vars <- c(model$resp_vars, model$other_vars) + attr(model, "user_vars") <- user_vars + + if (regex && length(regex_vars) > 0) { + for (var in regex_vars) { + if (var %in% names(model$other_vars)) { + model$other_vars[[var]] <- get_variables(model$other_vars[[var]], + data_cols, + regex) + } + if (var %in% model$resp_vars) { + model$resp_vars[[var]] <- get_variables(model$resp_vars[[var]], + data_cols, + regex) + } + } + attr(model, "regex_applied") <- regex + } + } + + model +} + + + + #############################################################################! # HELPER FUNCTIONS #### #############################################################################! @@ -122,22 +200,27 @@ configure_model <- function(model, data, formula) { supported_models <- function(print_call=TRUE) { supported_models <- lsp("bmm", pattern = "^\\.model_") supported_models <- sub("^\\.model_", "", supported_models) - if (print_call) { - out <- "The following models are supported:\n\n" - for (model in supported_models) { - args <- methods::formalArgs(get(model)) - args <- args[!args %in% c("...")] - args <- collapse_comma(args) - args <- gsub("'", "", args) - out <- paste0(out, '- `', model,'(',args,')`', "\n", sep='') - } - out <- paste0(out, "\nType `?modelname` to get information about a specific model, e.g. `?IMMfull`\n") - cat(gsub("`", " ", out)) - return(invisible(out)) + if (!print_call) { + return(supported_models) + } + + out <- "The following models are supported:\n\n" + for (model in supported_models) { + args <- methods::formalArgs(get(model)) + args <- args[!args %in% c("...")] + args <- collapse_comma(args) + args <- gsub("'", "", args) + out <- paste0(out, '- `', model,'(',args,')`', "\n", sep='') } - supported_models + out <- paste0(out, "\nType `?modelname` to get information about a specific model, e.g. `?IMMfull`\n") + out <- gsub("`", " ", out) + class(out) <- "message" + out } + + + #' @title Generate a markdown list of the measurement models available in `bmm` #' @description Used internally to automatically populate information in the README file #' @return Markdown code for printing the list of measurement models available in `bmm` @@ -214,20 +297,6 @@ model_info.bmmmodel <- function(model, components = 'all') { collapse(info_all[components]) } -#' Checks if the model is supported, and returns the model function -#' @param model A string with the name of the model supplied by the user -#' @return A list generated by a model function of type .model_* -#' @keywords internal, developer -check_model <- function(model) { - model_label <- class(model)[length(class(model))] - ok_models <- supported_models(print_call=FALSE) - if (not_in(model_label, ok_models)) { - stop(model_label, " is not a supported model. Supported ", - "models are:\n", collapse_comma(ok_models)) - } - model -} - #' @param model A string with the name of the model supplied by the user @@ -581,7 +650,7 @@ use_model_template <- function(model_name, get_stancode <- function(formula, data, model, prior=NULL, ...) { # check model, formula and data, and transform data if necessary - model <- check_model(model) + model <- check_model(model,data) data <- check_data(model, data, formula) formula <- check_formula(model, data, formula) diff --git a/R/helpers-prior.R b/R/helpers-prior.R index e5b63dc4..be1049de 100644 --- a/R/helpers-prior.R +++ b/R/helpers-prior.R @@ -65,7 +65,7 @@ combine_prior <- function(prior1, prior2) { #' } #' get_model_prior <- function(formula, data, model, ...) { - model <- check_model(model) + model <- check_model(model, data) data <- check_data(model, data, formula) formula <- check_formula(model, data, formula) config_args <- configure_model(model, data, formula) diff --git a/R/utils.R b/R/utils.R index 0d15cf4e..dcf056a5 100644 --- a/R/utils.R +++ b/R/utils.R @@ -226,6 +226,15 @@ is_try_warning <- function(x) { inherits(x, "warning") } +is_bmmmodel <- function(x) { + inherits(x, "bmmmodel") +} + +is_supported_bmmmodel <- function(x) { + valid_models <- supported_models(print_call = FALSE) + is_bmmmodel(x) && inherits(x, valid_models) +} + as_numeric_vector <- function(x) { out <- tryCatch(as.numeric(as.character(x)), warning = function(w) w) if (is_try_warning(out)) { @@ -308,3 +317,27 @@ order_data_query <- function(model, data, formula) { #' `save_pars`. For details see ?brms::save_pars. #' @export save_pars <- brms::save_pars + + +# custom method form printing nicely formatted character values via cat instead of print +#' @export +print.message <- function(x, ...) { + cat(x, ...) +} + + +# returns either x, or all variables that match the regular expression x +# @param x character vector or regular expression +# @param all_variables character vector of all variables within which to search +# @param regex logical. If TRUE, x is treated as a regular expression +get_variables <- function(x, all_variables, regex = FALSE) { + if (regex) { + variables <- all_variables[grep(x, all_variables)] + if (length(variables) == 0) { + stop2("No variables found that match the regular expression '", x, "'") + } + return(variables) + } + x +} + diff --git a/man/IMM.Rd b/man/IMM.Rd index 363836c7..93dc2bc2 100644 --- a/man/IMM.Rd +++ b/man/IMM.Rd @@ -7,28 +7,37 @@ \alias{IMMabc} \title{Interference measurement model by Oberauer and Lin (2017).} \usage{ -IMMfull(resp_err, nt_features, nt_distances, setsize, ...) +IMMfull(resp_err, nt_features, nt_distances, setsize, regex = FALSE, ...) -IMMbsc(resp_err, nt_features, nt_distances, setsize, ...) +IMMbsc(resp_err, nt_features, nt_distances, setsize, regex = FALSE, ...) -IMMabc(resp_err, nt_features, setsize, ...) +IMMabc(resp_err, nt_features, setsize, regex = FALSE, ...) } \arguments{ -\item{resp_err}{The name of the variable in the provided dataset containing the -response error. The response Error should code the response relative to the to-be-recalled -target in radians. You can transform the response error in degrees to radian using the \code{deg2rad} function.} +\item{resp_err}{The name of the variable in the provided dataset containing +the response error. The response Error should code the response relative to +the to-be-recalled target in radians. You can transform the response error +in degrees to radian using the \code{deg2rad} function.} -\item{nt_features}{A character vector with the names of the non-target variables. -The non_target variables should be in radians and be centered relative to the -target.} +\item{nt_features}{A character vector with the names of the non-target +variables. The non_target variables should be in radians and be centered +relative to the target. Alternatively, if regex=TRUE, a regular +expression can be used to match the non-target feature columns in the +dataset.} -\item{nt_distances}{A vector of names of the columns containing the distances of -non-target items to the target item. Only necessary for the \code{IMMbsc} and \code{IMMfull} models} +\item{nt_distances}{A vector of names of the columns containing the distances +of non-target items to the target item. Alternatively, if regex=TRUE, a regular +expression can be used to match the non-target distances columns in the +dataset. Only necessary for the \code{IMMbsc} and \code{IMMfull} models.} \item{setsize}{Name of the column containing the set size variable (if setsize varies) or a numeric value for the setsize, if the setsize is fixed.} +\item{regex}{Logical. If TRUE, the \code{nt_features} and \code{nt_distances} arguments +are interpreted as a regular expression to match the non-target feature +columns in the dataset.} + \item{...}{used internally for testing, ignore it} } \value{ @@ -118,4 +127,48 @@ included in the model formula. The parameter is: } } } +\examples{ +\dontrun{ +# load data +data <- OberauerLin_2017 + +# define formula +ff <- bmmformula( + kappa ~ 0 + set_size, + c ~ 0 + set_size, + a ~ 0 + set_size, + s ~ 0 + set_size +) + +# specify the full IMM model with explicit column names for non-target features and distances +model1 <- IMMfull(resp_err = "dev_rad", + nt_features = paste0('col_nt',1:7), + nt_distances = paste0('dist_nt',1:7), + setsize = 'set_size') + +# fit the model +fit <- fit_model(formula = ff, + data = data, + model = model1, + parallel = T, + iter = 500, + backend = 'cmdstanr') + +# alternatively specify the IMM model with a regular expression to match non-target features +# this is equivalent to the previous call, but more concise +model2 <- IMMfull(resp_err = "dev_rad", + nt_features = 'col_nt', + nt_distances = 'dist_nt', + setsize = 'set_size', + regex = TRUE) + +# fit the model +fit <- fit_model(formula = ff, + data = data, + model = model2, + parallel=T, + iter = 500, + backend='cmdstanr') +} +} \keyword{bmmmodel} diff --git a/man/check_model.Rd b/man/check_model.Rd index dbacce23..f8fdc66c 100644 --- a/man/check_model.Rd +++ b/man/check_model.Rd @@ -4,10 +4,12 @@ \alias{check_model} \title{Checks if the model is supported, and returns the model function} \usage{ -check_model(model) +check_model(model, data = NULL) } \arguments{ -\item{model}{A string with the name of the model supplied by the user} +\item{model}{the model argument supplied by the user} + +\item{data}{the data argument supplied by the user} } \value{ A list generated by a model function of type .model_* diff --git a/man/fit_model.Rd b/man/fit_model.Rd index 41a70d4e..671f9b4a 100644 --- a/man/fit_model.Rd +++ b/man/fit_model.Rd @@ -74,15 +74,15 @@ model. \details{ The following models are supported: \itemize{ -\item \code{IMMabc(resp_err, nt_features, setsize)} -\item \code{IMMbsc(resp_err, nt_features, nt_distances, setsize)} -\item \code{IMMfull(resp_err, nt_features, nt_distances, setsize)} -\item \code{mixture2p(resp_err)} -\item \code{mixture3p(resp_err, nt_features, setsize)} -\item \code{sdmSimple(resp_err)} +\item IMMabc(resp_err, nt_features, setsize, regex) +\item IMMbsc(resp_err, nt_features, nt_distances, setsize, regex) +\item IMMfull(resp_err, nt_features, nt_distances, setsize, regex) +\item mixture2p(resp_err) +\item mixture3p(resp_err, nt_features, setsize, regex) +\item sdmSimple(resp_err) } -Type \code{?modelname} to get information about a specific model, e.g. \code{?IMMfull} +Type ?modelname to get information about a specific model, e.g. ?IMMfull Type \code{help(package=bmm)} for a full list of available help topics. } diff --git a/man/get_model_prior.Rd b/man/get_model_prior.Rd index 6530b7b3..387739c5 100644 --- a/man/get_model_prior.Rd +++ b/man/get_model_prior.Rd @@ -40,15 +40,15 @@ used if no user-specified priors were passed to the \code{\link[=fit_model]{fit_ \details{ The following models are supported: \itemize{ -\item \code{IMMabc(resp_err, nt_features, setsize)} -\item \code{IMMbsc(resp_err, nt_features, nt_distances, setsize)} -\item \code{IMMfull(resp_err, nt_features, nt_distances, setsize)} -\item \code{mixture2p(resp_err)} -\item \code{mixture3p(resp_err, nt_features, setsize)} -\item \code{sdmSimple(resp_err)} +\item IMMabc(resp_err, nt_features, setsize, regex) +\item IMMbsc(resp_err, nt_features, nt_distances, setsize, regex) +\item IMMfull(resp_err, nt_features, nt_distances, setsize, regex) +\item mixture2p(resp_err) +\item mixture3p(resp_err, nt_features, setsize, regex) +\item sdmSimple(resp_err) } -Type \code{?modelname} to get information about a specific model, e.g. \code{?IMMfull} +Type ?modelname to get information about a specific model, e.g. ?IMMfull Type \code{help(package=bmm)} for a full list of available help topics. } diff --git a/man/mixture3p.Rd b/man/mixture3p.Rd index e393a95b..6c49b7cd 100644 --- a/man/mixture3p.Rd +++ b/man/mixture3p.Rd @@ -4,7 +4,7 @@ \alias{mixture3p} \title{Three-parameter mixture model by Bays et al (2009).} \usage{ -mixture3p(resp_err, nt_features, setsize, ...) +mixture3p(resp_err, nt_features, setsize, regex = FALSE, ...) } \arguments{ \item{resp_err}{The name of the variable in the dataset containing @@ -13,13 +13,18 @@ the to-be-recalled target in radians. You can transform the response error in degrees to radians using the \code{deg2rad} function.} \item{nt_features}{A character vector with the names of the non-target -feature values. The non_target feature values should be in radians and centered -relative to the target.} +feature values. The non_target feature values should be in radians and +centered relative to the target. Alternatively, if regex=TRUE, a regular +expression can be used to match the non-target feature columns in the +dataset.} \item{setsize}{Name of the column containing the set size variable (if setsize varies) or a numeric value for the setsize, if the setsize is fixed.} +\item{regex}{Logical. If TRUE, the \code{nt_features} argument is interpreted as +a regular expression to match the non-target feature columns in the dataset.} + \item{...}{used internally for testing, ignore it} } \value{ @@ -72,13 +77,25 @@ ff <- bmmformula( thetant ~ 1 ) -# specify the 3-parameter model -model <- mixture3p(resp_err = "y", nt_features = paste0('nt',1:3,'_loc'), setsize = 4) +# specify the 3-parameter model with explicit column names for non-target features +model1 <- mixture3p(resp_err = "y", nt_features = paste0('nt',1:3,'_loc'), setsize = 4) + +# fit the model +fit <- fit_model(formula = ff, + data = dat, + model = model1, + parallel=T, + iter = 500, + backend='cmdstanr') + +# alternatively specify the 3-parameter model with a regular expression to match non-target features +# this is equivalent to the previous call, but more concise +model2 <- mixture3p(resp_err = "y", nt_features = "nt.*_loc", setsize = 4, regex = TRUE) # fit the model fit <- fit_model(formula = ff, data = dat, - model = model, + model = model2, parallel=T, iter = 500, backend='cmdstanr') diff --git a/tests/testthat/test-fit_model.R b/tests/testthat/test-fit_model.R index 36f0135a..9b61eaae 100644 --- a/tests/testthat/test-fit_model.R +++ b/tests/testthat/test-fit_model.R @@ -69,7 +69,7 @@ test_that('Available models produce expected errors', { test_args <- lapply(args_list, function(x) {NULL}) model <- brms::do_call(model, test_args) expect_error(fit_model(bmmformula(kappa~1), model=model, backend="mock", mock_fit=1, rename=FALSE), - "Data must be specified using the 'data' argument.") + 'argument "data" is missing, with no default') } diff --git a/tests/testthat/test-helpers-model.R b/tests/testthat/test-helpers-model.R index 1c102e4c..f9778911 100644 --- a/tests/testthat/test-helpers-model.R +++ b/tests/testthat/test-helpers-model.R @@ -21,6 +21,52 @@ test_that("check_model() refuses invalid models and accepts valid models", { } }) +test_that("check_model() works with regular expressions", { + dat <- OberauerLin_2017 + models1 <- list(mixture3p("dev_rad", + nt_features = paste0('col_nt',1:7), + setsize = 'set_size'), + IMMfull('dev_rad', + nt_features = paste0('col_nt',1:7), + nt_distances = paste0('dist_nt',1:7), + setsize = 'set_size'), + IMMbsc('dev_rad', + nt_features = paste0('col_nt',1:7), + nt_distances = paste0('dist_nt',1:7), + setsize = 'set_size'), + IMMabc('dev_rad', + nt_features = paste0('col_nt',1:7), + setsize = 'set_size') + ) + models2 <- list(mixture3p("dev_rad", + nt_features = 'col_nt', + setsize = 'set_size', + regex = TRUE), + IMMfull('dev_rad', + nt_features = 'col_nt', + nt_distances = 'dist_nt', + setsize = 'set_size', + regex = TRUE), + IMMbsc('dev_rad', + nt_features = 'col_nt', + nt_distances = 'dist_nt', + setsize = 'set_size', + regex = TRUE), + IMMabc('dev_rad', + nt_features = 'col_nt', + setsize = 'set_size', + regex = TRUE) + ) + + for (i in 1:length(models1)) { + check1 <- check_model(models1[[i]], dat) + check2 <- check_model(models2[[i]], dat) + attributes(check1) <- NULL + attributes(check2) <- NULL + expect_equal(check1, check2) + } +}) + test_that("use_model_template() prevents duplicate models", { skip_on_cran() okmodels <- supported_models(print_call=FALSE) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 9b5de0be..3374cd91 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -16,3 +16,21 @@ test_that("empty dots don't crash the function", { out <- combine_args(nlist(config_args)) expect_equal(out, list(formula = 'a', family = 'b', prior = 'c', data = 'd', stanvars = 'e', init = 1)) }) + + +test_that("get_variables works", { + expect_equal(get_variables('a', c('a', 'b', 'c')), 'a') + expect_equal(get_variables('a', c('a', 'b', 'c'), regex = TRUE), 'a') + expect_equal(get_variables('a', c('a', 'b', 'c'), regex = FALSE), 'a') + expect_equal(get_variables('a|b', c('a', 'b', 'c'), regex = TRUE), c('a', 'b')) + expect_equal(get_variables('abc', c('abc1', 'abc2', 'abc3', 'other'), regex = TRUE), + c('abc1', 'abc2', 'abc3')) + expect_equal(get_variables('^abc', c('abc1', 'abc2', 'abc3', 'other_abc4'), regex = TRUE), + c('abc1', 'abc2', 'abc3')) + expect_equal(get_variables('abc$', c('nt1_abc', 'nt2_abc', 'nt3_abc', 'other_abc4'), regex = TRUE), + c('nt1_abc', 'nt2_abc', 'nt3_abc')) + expect_equal(get_variables('nt.*_abc', c('nt1_abc', 'nt2_abc', 'nt3_abc', 'other_abc4'), regex = TRUE), + c('nt1_abc', 'nt2_abc', 'nt3_abc')) + expect_equal(get_variables('a|b', c('a', 'b', 'c'), regex = FALSE), 'a|b') + expect_error(get_variables('d', c('a', 'b', 'c'), regex = TRUE)) +}) diff --git a/vignettes/IMM.Rmd b/vignettes/IMM.Rmd index a2b21974..dc8bfd5d 100644 --- a/vignettes/IMM.Rmd +++ b/vignettes/IMM.Rmd @@ -98,7 +98,7 @@ library(bmm) ## Generating simulated data -Should you already have a data set you want to fit, you can skip this section. Alternatively, you can use data provided with the package (add reference to data) or generate data using the random generation function provided in the `bmm` package. +Should you already have a data set you want to fit, you can skip this section. Alternatively, you can use data provided with the package (see `data(package='bmm')`) or generate data using the random generation function provided in the `bmm` package. ```{r} # set seed for reproducibility @@ -172,11 +172,22 @@ Then, we can specify the model that we want to estimate. This includes specifyin ```{r} model <- IMMfull(resp_err = "resp_err", - nt_features = paste0("color_item",2:setsize), + nt_features = paste0("color_item",2:5), setsize = setsize, - nt_distances = paste0("dist_item",2:setsize)) + nt_distances = paste0("dist_item",2:5)) ``` +In the above example we specified all column names for the non_targets explicitely via `paste0('color_item',2:5)`. Alternatively, you can use a regular expression to match the non-target feature columns in the dataset. For example, you can specify the model a few different ways via regular expressions: + +```{r} +model <- IMMfull(resp_err = "resp_err", + nt_features = "color_item[2-5]", + setsize = setsize, + nt_distances = "dist_item[2-5]", + regex = TRUE) +``` + + Finally, we can fit the model by passing all the relevant arguments to the `fit_model` function: ``` r @@ -247,10 +258,12 @@ as.data.frame(draws) %>% ggplot(aes(value, par)) + tidybayes::stat_halfeyeh(normalize = "groups") + geom_point(data = data.frame(par = colnames(draws), - value = c(Cs,As,Ss,kappas)), + value = c(kappas, As, Cs, Ss)), aes(value,par), color = "red", shape = "diamond", size = 2.5) + scale_x_continuous(lim=c(0,20)) + +colnames(draws) ``` # References diff --git a/vignettes/mixture_models.Rmd b/vignettes/mixture_models.Rmd index d8654af6..0d9e2a49 100644 --- a/vignettes/mixture_models.Rmd +++ b/vignettes/mixture_models.Rmd @@ -300,4 +300,19 @@ fit3p <- fit_model( The rest of the analysis is the same as for the 2-parameter model. We can inspect the model fit, extract the parameter estimates, and visualize the posterior distributions. +In the above example we specified all column names for the non_targets explicitely via `paste0('non_target_',1:5)`. Alternatively, you can use a regular expression to match the non-target feature columns in the dataset. This is useful when the non-target feature columns are named in a consistent way, e.g. `non_target_1`, `non_target_2`, `non_target_3`, etc. For example, you can specify the model a few different ways via regular expressions: + +```{r} +model <- mixture3p(resp_err = "error", + nt_features = "non_target_[1-5]", + setsize = 'set_size', + regex = TRUE) +model <- mixture3p(resp_err = "error", + nt_features = "non_target_", + setsize = 'set_size', + regex = TRUE) +``` + + + # References