Skip to content

Commit

Permalink
Merge branch 'develop' into feature/108-implement-sdt-models
Browse files Browse the repository at this point in the history
  • Loading branch information
GidonFrischkorn committed Feb 21, 2024
2 parents 3b9c0bc + a89723c commit e1274e1
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 47 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ export(rad2deg)
export(rmixture2p)
export(rmixture3p)
export(rsdm)
export(save_pars)
export(sdmSimple)
export(set_default_prior)
export(softmax)
export(softmaxinv)
export(supported_models)
export(use_model_template)
export(wrap)
importFrom(brms,save_pars)
importFrom(magrittr,"%>%")
1 change: 1 addition & 0 deletions R/bmm-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
"_PACKAGE"

## usethis namespace: start
#' @importFrom brms save_pars
## usethis namespace: end
NULL
109 changes: 62 additions & 47 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ NULL
#' softmaxinv(softmax(5:7))
softmax <- function(eta, lambda = 1) {
stopifnot(requireNamespace("matrixStats", quietly = TRUE))
m <- length(eta)+1
DEN <- matrixStats::logSumExp(c(lambda*eta, 0))
LSOFT <- c(lambda*eta, 0) - DEN
m <- length(eta) + 1
DEN <- matrixStats::logSumExp(c(lambda * eta, 0))
LSOFT <- c(lambda * eta, 0) - DEN
exp(LSOFT)
}

Expand All @@ -53,7 +53,7 @@ softmax <- function(eta, lambda = 1) {
softmaxinv <- function(p, lambda = 1) {
m <- length(p)
if (m > 1) {
return((log(p) - log(p[m]))[1:(m-1)]/lambda)
return((log(p) - log(p[m]))[1:(m - 1)] / lambda)
}
numeric(0)
}
Expand All @@ -70,25 +70,26 @@ softmaxinv <- function(p, lambda = 1) {
#' propagated to the user environment.
#' @keywords internal
#' @returns A list of options to pass to brm()
configure_options <- function(opts, env=parent.frame()) {
configure_options <- function(opts, env = parent.frame()) {
if (opts$parallel) {
cores = parallel::detectCores()
if (opts$chains > parallel::detectCores()) {
cores <- parallel::detectCores()
if (opts$chains > parallel::detectCores()) {
opts$chains <- parallel::detectCores()
}
} else {
cores = NULL
cores <- NULL
}
withr::local_options(
list(
mc.cores = cores,
bmm.silent = opts$silent,
bmm.sort_data = opts$sort_data
),
.local_envir=env)
.local_envir = env
)

# return only options that can be passed to brms/rstan/cmdstanr
exclude_args <- c('parallel', 'sort_data')
exclude_args <- c("parallel", "sort_data")
opts[not_in(names(opts), exclude_args)]
}

Expand All @@ -108,17 +109,17 @@ not_in_list <- function(key, list) {
#' wrappers to construct a brms nlf and lf formulas from multiple string arguments
#' @param ... string parts of the formula separated by commas
#' @examples
#' kappa_nts <- paste0('kappa_nt', 1:4)
#' glue_nlf(kappa_nts[i], ' ~ kappa') ## same as brms::nlf(kappa_nt1 ~ kappa)
#' kappa_nts <- paste0("kappa_nt", 1:4)
#' glue_nlf(kappa_nts[i], " ~ kappa") ## same as brms::nlf(kappa_nt1 ~ kappa)
#' @noRd
glue_nlf <- function(...) {
dots = list(...)
dots <- list(...)
brms::nlf(stats::as.formula(collapse(...)))
}

# like glue_nlf but for lf formulas
glue_lf <- function(...) {
dots = list(...)
dots <- list(...)
brms::lf(stats::as.formula(collapse(...)))
}

Expand All @@ -130,7 +131,7 @@ glue_lf <- function(...) {
call_brm <- function(fit_args) {
fit <- brms::do_call(brms::brm, fit_args)
fit$bmm_fit_args <- fit_args
class(fit) <- c('bmmfit','brmsfit')
class(fit) <- c("bmmfit", "brmsfit")
fit
}

Expand All @@ -145,10 +146,10 @@ combine_args <- function(args) {
return(c(config_args, opts))
}
for (i in names(dots)) {
if (not_in(i, c('family'))) {
if (not_in(i, c("family"))) {
config_args[[i]] <- dots[[i]]
} else {
stop('You cannot provide a family argument to fit_model. Please use the model argument instead.')
stop("You cannot provide a family argument to fit_model. Please use the model argument instead.")
}
}
c(config_args, opts)
Expand All @@ -157,7 +158,7 @@ combine_args <- function(args) {


message2 <- function(...) {
silent <- getOption('bmm.silent', 1)
silent <- getOption("bmm.silent", 1)
if (silent < 2) {
message(...)
}
Expand All @@ -166,7 +167,7 @@ message2 <- function(...) {


# function to ensure proper reading of stan files
read_lines2 <- function (con) {
read_lines2 <- function(con) {
lines <- readLines(con, n = -1L, warn = FALSE)
paste(lines, collapse = "\n")
}
Expand All @@ -175,14 +176,14 @@ read_lines2 <- function (con) {
# for testing purposes
install_and_load_bmm_version <- function(version) {
if ("package:bmm" %in% search()) {
detach("package:bmm", unload=TRUE)
detach("package:bmm", unload = TRUE)
}
path <- paste0(.libPaths()[1], "/bmm-", version)
if (!dir.exists(path) || length(list.files(path)) == 0 || length(list.files(paste0(path, "/bmm"))) == 0) {
dir.create(path)
remotes::install_github(paste0("venpopov/bmm@",version), lib=path)
remotes::install_github(paste0("venpopov/bmm@", version), lib = path)
}
library(bmm, lib.loc=path)
library(bmm, lib.loc = path)
}


Expand All @@ -207,8 +208,8 @@ fit_info.brmsfit <- function(fit, what) {
fit_attr <- attributes(fit$fit)
metadata <- fit_attr$metadata
switch(what,
time = metadata$time$chains,
time_mean = colMeans(metadata$time$chains),
time = metadata$time$chains,
time_mean = colMeans(metadata$time$chains),
)
}

Expand Down Expand Up @@ -248,24 +249,32 @@ stop_quietly <- function() {
# ordered by the predictors, and if not, it suggests to the user to sort the data
order_data_query <- function(model, data, formula) {
sort_data <- getOption("bmm.sort_data", NULL)
if(is.null(sort_data) & !is_data_ordered(data, formula)) {
message("\n\nData is not ordered by predictors.\nYou can speed up the model ",
"estimation up to several times (!) by ordering the data by all your ",
"predictor columns.\n\n")
if (is.null(sort_data) & !is_data_ordered(data, formula)) {
message(
"\n\nData is not ordered by predictors.\nYou can speed up the model ",
"estimation up to several times (!) by ordering the data by all your ",
"predictor columns.\n\n"
)
caution_msg <- paste(strwrap("* caution: if you chose Option 2, you need to be careful
when using brms postprocessing methods that rely on the data order, such as
generating predictions. Assuming you assigned the result of fit_model to a
variable called `fit`, you can extract the sorted data from the fitted object
with:\n\n data_sorted <- fit$fit_args$data", width=80), collapse = "\n")
with:\n\n data_sorted <- fit$fit_args$data", width = 80), collapse = "\n")
caution_msg <- crayon::red(caution_msg)

if(interactive()) {
var <- utils::menu(c("Yes (note: you will receive code to sort your data)",
paste0("Let bmm sort the data for you and continue with the faster model fitting ",
crayon::red("(*)")),
paste0("No, I want to continue with the slower estimation\n\n", caution_msg, collapse = "\n")),
title="Do you want to stop and sort your data? (y/n): ")
if(var == 1) {
if (interactive()) {
var <- utils::menu(
c(
"Yes (note: you will receive code to sort your data)",
paste0(
"Let bmm sort the data for you and continue with the faster model fitting ",
crayon::red("(*)")
),
paste0("No, I want to continue with the slower estimation\n\n", caution_msg, collapse = "\n")
),
title = "Do you want to stop and sort your data? (y/n): "
)
if (var == 1) {
message("Please sort your data by all predictors and then re-run the model.")
data_name <- attr(data, "data_name")
if (is.null(data_name)) {
Expand All @@ -274,26 +283,27 @@ order_data_query <- function(model, data, formula) {
message("To sort your data, use the following code:\n\n")
message(crayon::green("library(dplyr)"))
message(crayon::green(data_name, "_sorted <- ", data_name, " %>% arrange(",
paste(rhs_vars(formula), collapse = ", "),
")\n\n",
sep=""))
paste(rhs_vars(formula), collapse = ", "),
")\n\n",
sep = ""
))
message("Then re-run the model with the newly sorted data.")
stop_quietly()
} else if (var == 2) {
message("Your data has been sorted by the following predictors: ", paste(rhs_vars(formula), collapse = ", "),'\n')
message("Your data has been sorted by the following predictors: ", paste(rhs_vars(formula), collapse = ", "), "\n")
preds <- rhs_vars(formula)
data <- dplyr::arrange_at(data, preds)
}
}
} else if (isTRUE(sort_data)) {
preds <- rhs_vars(formula)
data <- dplyr::arrange_at(data, preds)
message("\nYour data has been sorted by the following predictors: ", paste(rhs_vars(formula), collapse = ", "),'\n')
message("\nYour data has been sorted by the following predictors: ", paste(rhs_vars(formula), collapse = ", "), "\n")
caution_msg <- paste(strwrap("* caution: you have set `sort_data=TRUE`. You need to be careful
when using brms postprocessing methods that rely on the data order, such as
generating predictions. Assuming you assigned the result of fit_model to a
variable called `fit`, you can extract the sorted data from the fitted object
with:\n\n data_sorted <- fit$fit_args$data", width=80), collapse = "\n")
with:\n\n data_sorted <- fit$fit_args$data", width = 80), collapse = "\n")
caution_msg <- crayon::red(caution_msg)
message(caution_msg)
}
Expand All @@ -316,15 +326,20 @@ aggregate_data <- function(model, formula, data) {
}

#' @export
aggregate_data.bmmmodel <- function(model, formula, data){
aggregate_data.bmmmodel <- function(model, formula, data) {
NextMethod("aggregate_data")
}

#' @export
aggregate_data.default <- function(model, formula, data){
aggregate_data.default <- function(model, formula, data) {
data
}




#' @inherit brms::save_pars title params return
#' @description Thin wrapper around [brms::save_pars()]. When calling
#' [fit_model] with additional information to save parameters you can use this
#' function to pass information about saving parameter draws to `brms` without
#' having to load `brms`. Alternatively, you can also load `brms` and call
#' `save_pars`. For details see ?brms::save_pars.
#' @export
save_pars <- brms::save_pars
43 changes: 43 additions & 0 deletions man/save_pars.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit e1274e1

Please sign in to comment.