diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index dec6283..a5795e7 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -70,6 +70,7 @@ jobs: remotes waldo fansi + extraDistr - name: Build Cmdstan run: | diff --git a/.github/workflows/pkgdown.yaml b/.github/workflows/pkgdown.yaml index f924600..5b4763a 100644 --- a/.github/workflows/pkgdown.yaml +++ b/.github/workflows/pkgdown.yaml @@ -52,6 +52,7 @@ jobs: remotes waldo fansi + extraDistr any::pkgdown local::. diff --git a/DESCRIPTION b/DESCRIPTION index 34de126..2e59873 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -24,7 +24,6 @@ Suggests: rmarkdown, bookdown, tidybayes, - ggplot2, mixtur, ggthemes, cowplot, @@ -32,7 +31,8 @@ Suggests: remotes, waldo, fansi, - cmdstanr (>= 0.7.0) + cmdstanr (>= 0.7.0), + extraDistr Config/testthat/edition: 3 Imports: magrittr, @@ -45,7 +45,11 @@ Imports: matrixStats, crayon, methods, - fs + fs, + ggplot2, + distributional, + ggdist, + rlang URL: https://github.com/venpopov/bmm, https://venpopov.github.io/bmm/ BugReports: https://github.com/venpopov/bmm/issues Additional_repositories: diff --git a/NAMESPACE b/NAMESPACE index 5b27203..948974c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -5,6 +5,9 @@ S3method("[",bmmformula) S3method("[<-",bmmformula) S3method(add_links,bmmfit) S3method(add_links,bmmodel) +S3method(add_links,brmsprior) +S3method(add_mu,brmsprior) +S3method(add_mu,brmssummary) S3method(bmf2bf,bmmodel) S3method(check_data,bmmodel) S3method(check_data,default) @@ -33,6 +36,10 @@ S3method(configure_prior,mixture3p) S3method(default_prior,bmmformula) S3method(fit_info,brmsfit) S3method(fit_info,brmsfit_list) +S3method(get_links,bmmodel) +S3method(get_links,brmsformula) +S3method(get_links,default) +S3method(get_links,mvbrmsformula) S3method(identical,bmmformula) S3method(identical,brmsformula) S3method(identical,default) @@ -42,6 +49,9 @@ S3method(is_constant,default) S3method(is_nl,bmmformula) S3method(is_nl,default) S3method(model_info,bmmodel) +S3method(plot,brmsprior) +S3method(plot,character) +S3method(plot,distribution) S3method(postprocess_brm,bmmodel) S3method(postprocess_brm,default) S3method(postprocess_brm,sdm) @@ -110,9 +120,13 @@ export(sdm) export(sdmSimple) export(softmax) export(softmaxinv) +export(stat_slab2) export(supported_models) +export(theme_dist) export(use_model_template) +export(vens_options) export(wrap) +import(ggplot2) import(stats) importFrom(brms,default_prior) importFrom(brms,restructure) diff --git a/R/brms-misc.R b/R/brms-misc.R index fd3e4be..d9fa789 100644 --- a/R/brms-misc.R +++ b/R/brms-misc.R @@ -1,3 +1,6 @@ +# Functions in this file are copies of `brms` internal functions that we are +# not allowed to import. All copyrights belong to the original author + # check if x is a try-error resulting from try() is_try_error <- function(x) { inherits(x, "try-error") @@ -282,3 +285,75 @@ rename <- function(x, pattern = NULL, replacement = NULL, } out } + + +# apply a link function +# @param x an array of arbitrary dimension +# @param link character string defining the link +link <- function(x, link) { + switch(link, + identity = x, + log = log(x), + logm1 = brms::logm1(x), + log1p = log1p(x), + inverse = 1 / x, + sqrt = sqrt(x), + "1/mu^2" = 1 / x^2, + tan_half = tan(x / 2), + logit = logit(x), + probit = qnorm(x), + cauchit = qcauchy(x), + cloglog = cloglog(x), + probit_approx = qnorm(x), + softplus = log_expm1(x), + squareplus = (x^2 - 1) / x, + softit = softit(x), + stop2("Link '", link, "' is not supported.") + ) +} + +# apply an inverse link function +# @param x an array of arbitrary dimension +# @param link a character string defining the link +inv_link <- function(x, link) { + switch(link, + identity = x, + log = exp(x), + logm1 = brms::expp1(x), + log1p = expm1(x), + inverse = 1 / x, + sqrt = x^2, + "1/mu^2" = 1 / sqrt(x), + tan_half = 2 * atan(x), + logit = inv_logit(x), + probit = pnorm(x), + cauchit = pcauchy(x), + cloglog = inv_cloglog(x), + probit_approx = pnorm(x), + softplus = log1p_exp(x), + squareplus = (x + sqrt(x^2 + 4)) / 2, + softit = inv_softit(x), + stop2("Link '", link, "' is not supported.") + ) +} + + +conv_cats_dpars <- function (family) { + is_categorical(family) || is_multinomial(family) || is_simplex(family) +} + +is_categorical <- function(family) { + 'categorical' %in% family$specials +} + +is_multinomial <- function(family) { + 'multinomial' %in% family$specials +} + +is_simplex <- function(family) { + 'simplex' %in% family$specials +} + +ulapply <- function (X, FUN, ..., recursive = TRUE, use.names = TRUE) { + unlist(lapply(X, FUN, ...), recursive, use.names) +} diff --git a/R/distributions.R b/R/distributions.R index 834c255..065f453 100644 --- a/R/distributions.R +++ b/R/distributions.R @@ -58,26 +58,26 @@ #' @examples #' # plot the density of the SDM distribution #' x <- seq(-pi,pi,length.out=10000) -#' plot(x,dsdm(x,0,2,3),type="l", xlim=c(-pi,pi),ylim=c(0,1), -#' xlab="Angle error (radians)", -#' ylab="density", -#' main="SDM density") -#' lines(x,dsdm(x,0,9,1),col="red") -#' lines(x,dsdm(x,0,2,8),col="green") -#' legend("topright",c("c=2, kappa=3.0, mu=0", -#' "c=9, kappa=1.0, mu=0", -#' "c=2, kappa=8, mu=1"), -#' col=c("black","red","green"),lty=1, cex=0.8) +#' plot(x, dsdm(x,0,2,3), type="l", xlim = c(-pi,pi), ylim = c(0,1), +#' xlab = "Angle error (radians)", +#' ylab = "density", +#' main = "SDM density") +#' lines(x, dsdm(x,0,9,1), col="red") +#' lines(x, dsdm(x,0,2,8), col="green") +#' legend("topright", c("c=2, kappa=3.0, mu=0", +#' "c=9, kappa=1.0, mu=0", +#' "c=2, kappa=8, mu=1"), +#' col=c("black","red","green"), lty = 1, cex = 0.8) #' #' # plot the cumulative distribution function of the SDM distribution #' p <- psdm(x, mu = 0, c = 3.1, kappa = 5) -#' plot(x,p,type="l") +#' plot(x, p, type = "l") #' #' # generate random deviates from the SDM distribution and overlay the density #' r <- rsdm(10000, mu = 0, c = 3.1, kappa = 5) #' d <- dsdm(x, mu = 0, c = 3.1, kappa = 5) -#' hist(r, breaks=60, freq=FALSE) -#' lines(x,d,type="l", col="red") +#' hist(r, breaks = 60, freq = FALSE) +#' lines(x, d, type = "l", col = "red") #' dsdm <- function(x, mu = 0, c = 3, kappa = 3.5, log = FALSE, parametrization = "sqrtexp") { @@ -116,7 +116,7 @@ psdm <- function(q, mu = 0, c = 3, kappa = 3.5, lower.tail = TRUE, log.p = FALSE pi <- base::pi pi2 <- 2 * pi - q <- (q + pi) %% pi2 + q <- ifelse(q == pi, pi2, (q + pi) %% pi2) mu <- (mu + pi) %% pi2 lower.bound <- (lower.bound + pi) %% pi2 @@ -147,7 +147,16 @@ psdm <- function(q, mu = 0, c = 3, kappa = 3.5, lower.tail = TRUE, log.p = FALSE #' @rdname SDMdist #' @export qsdm <- function(p, mu=0, c=3, kappa=3.5, parametrization = "sqrtexp") { - .NotYetImplemented() + p <- ifelse(near(p,1), p-1e-7, p) + + .qsdm <- function(p, mu, c, kappa, parametrization) { + uniroot(function(x) psdm(x, mu = mu, c = c, kappa = kappa, + lower.tail = TRUE, log.p = FALSE, + lower.bound = -pi, parametrization = parametrization) - p, + interval = c(-pi, pi))$root + } + .qsdm_v <- Vectorize(.qsdm) + .qsdm_v(p, mu, c, kappa, parametrization) } #' @rdname SDMdist diff --git a/R/helpers-parameters.R b/R/helpers-parameters.R index 86422b1..2b67ced 100644 --- a/R/helpers-parameters.R +++ b/R/helpers-parameters.R @@ -18,8 +18,8 @@ #' SDs_degress <- SDs * 180 / pi #' #' # plot the relationship between kappa and circular SD -#' plot(kappas,SDs) -#' plot(kappas,SDs_degress) +#' plot(kappas, SDs) +#' plot(kappas, SDs_degress) #' k2sd <- function(K) { S <- matrix(0, 1, length(K)) diff --git a/R/helpers-prior.R b/R/helpers-prior.R index c627ea6..eea1351 100644 --- a/R/helpers-prior.R +++ b/R/helpers-prior.R @@ -34,22 +34,24 @@ #' @export default_prior.bmmformula <- function(object, data, model, formula = object, ...) { withr::local_options(bmm.sort_data = FALSE) + dots <- list(...) formula <- object model <- check_model(model, data, formula) data <- check_data(model, data, formula) formula <- check_formula(model, data, formula) config_args <- configure_model(model, data, formula) - prior <- configure_prior(model, data, config_args$formula, user_prior = NULL) + prior <- configure_prior(model, data, config_args$formula, user_prior = dots$prior) + dots$prior <- NULL - dots <- list(...) prior_args <- combine_args(nlist(config_args, dots, prior)) prior_args$object <- prior_args$formula prior_args$formula <- NULL brms_priors <- brms::do_call(brms::default_prior, prior_args) - combine_prior(brms_priors, prior_args$prior) + prior <- combine_prior(brms_priors, prior_args$prior) + add_links(prior, model) } @@ -273,3 +275,39 @@ summarise_default_prior <- function(prior_list) { } prior_info } + + +# preprocess a brmsprior object for plotting +prep_brmsprior <- function(x, formula = NULL, links = NULL) { + stopif(!brms::is.brmsprior(x), "x must be a brmsprior object") + if (!is.null(formula)) { + stopif(!is.null(links), "Either formula or links should be provided") + stopif(!is_brmsformula(formula), "formula must be a brmsformula object") + stopif(is.null(formula$family), "formula must contain a brmsfamily object") + link_source <- formula + } else { + stopif(!is.null(links) && (!is.list(links) || is.null(names(links))), + "links must be a named list") + link_source <- links + } + + priors <- x[x$prior != "",] + priors <- add_mu(priors) + priors <- priors[!grepl("constant", priors$prior),] + + # rename priors to match distribution functions + priors$prior <- gsub("logistic", "logis", priors$prior) + priors$prior <- gsub("inv_gamma", "invgamma", priors$prior) + + # get the parameter name regardless of type + priors$par <- ifelse(priors$nlpar != "", priors$nlpar, priors$dpar) + + # remove corr parameters as it's too complicated to visualize them + priors <- priors[priors$class != "cor",] + + # add a column with the corresponding parameter link + if (!("link" %in% names(x))) { + priors <- add_links(priors, object = link_source) + } + priors +} diff --git a/R/math.R b/R/math.R new file mode 100644 index 0000000..5251c03 --- /dev/null +++ b/R/math.R @@ -0,0 +1,39 @@ +logit <- function(p) { + log(p) - log1p(-p) +} + +inv_logit <- function(x) { + 1 / (1 + exp(-1*x)) +} + +cloglog <- function (x) { + log(-log1p(-x)) +} + +inv_cloglog <- function(x) { + 1 - exp(-1*exp(x)) +} + +log_expm1 <- function (x) { + out <- log(expm1(x)) + ifelse(out < Inf, out, x) +} + +log1p_exp <- function (x) { + out <- log1p(exp(x)) + ifelse(out < Inf, out, x) +} + +log1p_exp_custom <- function(x, mu = 0, a = 0, b = 1) { + out <- log1p(a + exp(b * (x - mu)))/b + ifelse(out < Inf, out, x) +} + +softit <- function(x) { + log_expm1(x/(1-x)) +} + +inv_softit <- function(x) { + y <- log1p_exp(x) + y/(1 + y) +} diff --git a/R/plot.R b/R/plot.R new file mode 100644 index 0000000..6fb1ce5 --- /dev/null +++ b/R/plot.R @@ -0,0 +1,231 @@ +#' Plot a distribution object from the `distributional` package +#' +#' Uses `ggplot2` and `ggdist` to create a line plot of the probability density +#' function (pdf) or cumulative distribution function (cdf) of a distribution +#' object from the `distributional` package or a character vector of +#' distribution names of the type `distname(param1 = value1, param2 = value2)` +#' +#' @name plot-distribution +#' +#' @param x Several options. One of: +#' - An object of class `distribution` from the `distributional` package +#' - A character vector of distribution names that can be parsed by +#' `ggdist::parse_dist()` of the type `distname(param1 = value1, param2 = +#' value2)` +#' - A `brmsprior` object produced by [brms::set_prior], [brms::default_prior] or +#' [brms::prior_summary()] +#' @param ... additional distribution objects to plot. +#' @param formula A `brmsformula` object. Only needed for plotting `brmsprior` +#' objects if transformation of the priors is wanted via the specified links +#' in the `brmsfamily`. The `brmsformula` must contain the `brmsfamily` +#' @param links A named list of links for the parameters. This is an alternative +#' way to specify the links for the parameters. Only one of `links` or +#' `formula` can be provided. +#' @param transform Logical. If `TRUE`, the priors are transformed according to +#' the inverse of the link functions present either in the `brmsformula` or +#' the `links` argument. Default is `FALSE`. +#' @param type The type of plot. One of "pdf" or "cdf". +#' @param labels A character vector of labels for the distributions. If `NULL`, +#' the labels are automatically generated from the distribution objects. If a +#' character vector is provided, it must be the same length as the number of +#' distributions being plotted. Default is `NULL`. +#' @param facets Logical. If `TRUE`, the distributions are plotted in separate +#' facets. If `FALSE`, all distributions are plotted in the same plot. Default +#' is `FALSE`. +#' @param stat_slab_control A list of additional arguments passed to +#' `ggdist::stat_slab()` +#' @param packages A character vector of package names to search for +#' distributions +#' +#' @details These are generic methods for plotting distributions specified as +#' - `distribution` objects from the `distributional` package +#' - character vectors of distribution names that can be parsed by [ggdist::parse_dist()] +#' - `brmsprior` objects produced by [brms::set_prior], [brms::default_prior] or +#' [brms::prior_summary()] +#' +#' For `brmsprior` objects, you can also plot the transformations of the priors +#' to the native scale of the parameters by setting `transform = TRUE`. This +#' will transform the priors according to the inverse of the link functions +#' present either in the `brmsformula` or the `links` argument. For example, if +#' the prior is specified as `normal(0, 1)` and the link function is `log`, as +#' it is often the case for parameters that need to be strictly positive, the +#' prior will be transformed to `lognormal(0, 1)` by applying `exp()` to the +#' distribution object. +#' +#' @return A ggplot object +#' @examples +#' x <- distributional::dist_normal(mean = 0, sd = c(0.5,1,1.2)) +#' plot(x) +#' +#' x2 <- distributional::dist_wrap("sdm", mu = 0, c = c(1:10), kappa = 3) +#' plot(x2) +#' +#' plot('normal(0, 1)', 'normal(0, 2)', type = 'cdf') +#' plot(c('normal(0, 1)', 'student_t(3,0,2.5)'), facets = T) +#' +#' plot(distributional::dist_beta(1:10, 1)) + theme(legend.position = 'right') +#' @keywords plot +#' @export +plot.distribution <- function(x, ..., type = 'pdf', labels = NULL, facets = FALSE, + stat_slab_control = list()) { + if (!requireNamespace('extraDistr')) { + stop2("The 'extraDistr' package is required for plotting distributions") + } + dots <- list(...) + if (length(dots) > 0) { + stopif(any(!sapply(dots, inherits, what = "distribution")), + "All additional arguments passed via ... must be of class 'distribution'") + additional <- do.call(c, dots) + x <- c(x, additional) + } + + labels <- labels %||% format(x) + labels <- factor(labels, levels = unique(labels)) + + + df <- data.frame(x = x, labels = labels) + out <- ggplot(df) + + aes(xdist = x, thickness = after_stat(eval(parse(text = type)))) + + theme_dist() + + theme(legend.position = "bottom") + + stat_slab_args <- stat_slab_control + stat_slab_args$fill <- stat_slab_args$fill %||% NA + stat_slab_args$subguide <- stat_slab_args$subguide %||% + ggdist::subguide_outside(title = "density") + stat_slab_args$subscale <- stat_slab_args$subscale %||% + ggdist::subscale_thickness(expand = expansion(c(0, 0.1))) + stat_slab_args$scale <- stat_slab_args$scale %||% 1 + stat_slab_args$p_limits <- stat_slab_args$p_limits %||% c(0.001, 0.999) + if (facets) { + stat_slab_args$color <- stat_slab_args$color %||% "grey30" + stat_slab_args$normalize <- stat_slab_args$normalize %||% "groups" + } else { + stat_slab_args$normalize <- stat_slab_args$normalize %||% "all" + } + + out <- out + do.call(ggdist::stat_slab, stat_slab_args) + + if (facets) { + out <- out + facet_wrap(~labels, scales = "free") + } else { + out <- out + aes(color = labels) + scale_color_discrete("") + } + + out +} + + +#' @rdname plot-distribution +#' @export +plot.character <- function(x, ..., type = 'pdf', labels = NULL, facets = FALSE, + stat_slab_control = list(), + packages = c("brms","extraDistr","bmm","ggdist")) { + search_env <- pkg_search_env(packages) + dots <- list(...) + dists <- unlist(c(x, dots)) + stopifnot(all(is.character(dists))) + dists <- ggdist::parse_dist(dists, package = search_env) + plot(dists$.dist_obj, type = type, labels = labels, facets = facets, + stat_slab_control = stat_slab_control) +} + + +#' @import ggplot2 +#' @rdname plot-distribution +#' @export +plot.brmsprior <- function(x, formula = NULL, links = NULL, transform = FALSE, + type = 'pdf', facets = TRUE, stat_slab_control = list(), + packages = c("brms","extraDistr","bmm","ggdist"), ...) { + prior <- prep_brmsprior(x, formula = formula, links = links) + prior <- ggdist::parse_dist(prior, package = pkg_search_env(packages)) + + if (transform) { + stopif(is.null(formula) && is.null(links) && !("link" %in% names(prior)), + glue("If transform = TRUE, either formula or links must be provided, \\ + or there must be a 'link' column in the brmsprior object")) + for (i in 1:nrow(prior)) { + prior$.dist_obj[i] <- inv_link(prior$.dist_obj[[i]], prior$link[[i]]) + } + } + + + prior$labels <- glue("{prior$resp}_{prior$class}_{prior$par}_{prior$coef}_{prior$group}") + prior$labels <- gsub("_+", "_", prior$labels) + prior$labels <- gsub("(^_|_$)", "", prior$labels) + prior$labels <- paste0(prior$labels, " ~ ", format(prior$.dist_obj)) + prior <- prior[order(prior$par, prior$class),] + prior$labels <- factor(prior$labels, levels = unique(prior$labels)) + plot(prior$.dist_obj, type = type, facets = facets, labels = prior$labels, + stat_slab_control = stat_slab_control) +} + + +pkg_search_env <- function(packages) { + stopif(any(!is.character(packages))) + search_env <- globalenv() + search_env <- c(search_env, rlang::search_envs()) + search_env <- c(search_env, lapply(packages, function(x) asNamespace(x))) + search_env <- unname(search_env) + as.environment(do.call(c, lapply(search_env, as.list))) +} + + +# construct a string representation of a distribution object +dist2string <- function(x, par_names = FALSE) { + family <- stats::family(x) + parameters <- distributional::parameters(x) + if (family == "wrap") { + family <- parameters$dist + parameters$dist <- NULL + } + if (!par_names) { + parameters <- paste0(parameters, collapse = ", ") + } else { + parameters <- paste0(names(parameters), " = ", parameters, collapse = ", ") + } + out <- glue::glue("{family}({parameters})") + out <- gsub("list\\(", "", out) + out <- gsub("\\))", ")", out) + out +} + + +#' ggplot theme for plotting distributions +#' +#' Based on `theme_ggdist()` from the `ggdist` package, but with the y-axis +#' removed and margins adjusted to make space for a custom y axis +#' @return A ggplot theme object +#' @keywords plot +#' @export +theme_dist <- function() { + ggdist::theme_ggdist() + + theme(plot.margin = margin(5.5, 5.5, 5.5, 50), + axis.text.y = element_blank(), + axis.ticks.y = element_blank(), + axis.title.y = element_blank()) +} + + +# TODO: add this after ggdist next release +#' #' ggdist::stat_slab() with nice defaults +#' #' +#' #' A wrapper for ggdist::stat_slab() that sets default subguides and scales to +#' #' display the density on the y axis. +#' #' @param ... Additional arguments passed to ggdist::stat_slab() +#' #' @inheritParams ggdist::geom_slab +#' #' @return a stat_slab() object +#' #' @keywords plot +#' #' @export +#' stat_slab2 <- function(..., +#' subguide = ggdist::subguide_outside(title = "density"), +#' subscale = if (packageVersion('ggdist') > "3.3.2") { +#' ggdist::subscale_thickness(expand = expansion(c(0, 0.05))) +#' } else { +#' NULL +#' }, +#' normalize = 'groups') { +#' +#' suppressWarnings(ggdist::stat_slab(..., subguide = subguide, subscale = subscale, normalize = normalize)) +#' } + diff --git a/R/restructure.R b/R/restructure.R index 7f711c0..920a439 100644 --- a/R/restructure.R +++ b/R/restructure.R @@ -92,23 +92,6 @@ restructure_version.bmm <- function(x) { out } -add_links <- function(x) { - UseMethod("add_links") -} - -#' @export -add_links.bmmfit <- function(x) { - x$bmm$model <- add_links(x$bmm$model) - x -} - -#' @export -add_links.bmmodel <- function(x) { - model_name <- class(x)[length(class(x))] - new_model <- get_model(model_name)() - x$links <- new_model$links - x -} add_bmm_info <- function(x) { env <- x$family$env diff --git a/R/summary.R b/R/summary.R index 9d36628..94fcf39 100644 --- a/R/summary.R +++ b/R/summary.R @@ -17,7 +17,7 @@ summary.bmmfit <- function(object, priors = FALSE, prob = 0.95, robust = FALSE, return(out) } - out <- rename_mu_smry(out, get_mu_pars(object)) + out <- add_mu(out, get_mu_pars(object)) # get the bmm specific information bmmodel <- object$bmm$model @@ -128,15 +128,6 @@ print.bmmsummary <- function(x, digits = 2, color = getOption('bmm.color_summary invisible(x) } -rename_mu_smry <- function(x, mu_pars) { - for (i in seq_along(x)) { - if (is.data.frame(x[[i]])) { - rownames(x[[i]])[rownames(x[[i]]) %in% mu_pars] <- paste0("mu_", mu_pars) - } - } - x -} - select_pars <- function(x) { model_pars <- names(x$model$parameters) provided_dpars <- names(x$formula)[!is_nl(x$formula)] diff --git a/R/utils.R b/R/utils.R index 901488a..fd94d73 100644 --- a/R/utils.R +++ b/R/utils.R @@ -716,3 +716,131 @@ check_rds_file <- function(file) { `%||%` <- function(a, b) { if (!is.null(a)) a else b } + +#' Some default settings I liked +#' +#' @export +#' @return previous options +#' @keywords internal +vens_options <- function() { + op <- options(mc.cores = parallel::detectCores(), + bmm.sort_data = TRUE, + cmdstanr_write_stan_file_dir = "local/cmdstanr") + invisible(op) +} + + +# test vectors for near equality +near <- function(x, y, tol = sqrt(.Machine$double.eps)) { + abs(x - y) < tol +} + + +add_links <- function(x, ...) { + UseMethod("add_links") +} + +#' @export +add_links.bmmfit <- function(x, ...) { + x$bmm$model <- add_links(x$bmm$model) + x +} + +#' @export +add_links.bmmodel <- function(x, ...) { + model_name <- class(x)[length(class(x))] + new_model <- get_model(model_name)() + x$links <- new_model$links + x +} + +# object - either a brmsformula or a bmmodel +#' @export +add_links.brmsprior <- function(x, object, family, ...) { + # preprocess the links to an appropriate format + links <- get_links(object, family) + links <- unlist(links) + info <- strsplit(names(links), ".", fixed = TRUE) + if (length(info[[1]]) == 2) { + resp <- ulapply(info, function(x) x[1]) + dpar <- ulapply(info, function(x) x[2]) + } else { + resp <- rep("", length(info)) + dpar <- ulapply(info, function(x) x[1]) + } + links_df <- data.frame(resp = resp, dpar = dpar, link = unname(links)) + + x <- add_mu(x) + + # some parameters in brms are not in dpar, but in "class" if they are not predicted + in_class <- x$class %in% dpar & x$dpar == "" + dpar_old <- x$dpar + x$dpar <- ifelse(in_class, x$class, x$dpar) + # + + x <- suppressMessages(dplyr::left_join(x, links_df)) + x$dpar <- dpar_old + x$link[is.na(x$link)] <- "identity" + + x +} + +# extract links from an object +get_links <- function(x, ...) { + UseMethod("get_links") +} + +#' @export +get_links.default <- function(x, ...) { + x +} + +#' @export +get_links.brmsformula <- function(x, ...) { + dots <- list(...) + if (!is.null(dots$family)) { + x <- brms::brmsformula(x, family = dots$family) + } + x <- brms::brmsterms(x) + dpars <- x$family$dpars + links <- setNames(rep("identity", length(dpars)), dpars) + links <- as.list(links) + links_pred <- lapply(x$dpars, function(x) x$family$link) + links[names(links_pred)] <- links_pred + if (conv_cats_dpars(x)) { + links[grepl("^mu", names(links))] <- x$family$link + } + links +} + +#' @export +get_links.mvbrmsformula <- function(x, ...) { + lapply(x$forms, get_links, ...) +} + +#' @export +get_links.bmmodel <- function(x, ...) { + x$links +} + + +# add the missing mu paramereter from various brms objects +add_mu <- function(x, ...) { + UseMethod("add_mu") +} + +#' @export +add_mu.brmssummary <- function(x, mu_pars,...) { + for (i in seq_along(x)) { + if (is.data.frame(x[[i]])) { + rownames(x[[i]])[rownames(x[[i]]) %in% mu_pars] <- paste0("mu_", mu_pars) + } + } + x +} + +#' @export +add_mu.brmsprior <- function(x, ...) { + x$dpar <- ifelse(x$dpar == "" & x$nlpar == "" & (x$class %in% c("b", "Intercept")), "mu", x$dpar) + x +} diff --git a/_pkgdown.yml b/_pkgdown.yml index f06852f..5a23288 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -58,6 +58,10 @@ reference: desc: "Available datasets for fitting the models" - contents: - has_keyword("dataset") + - title: "Plotting" + desc: "Misc functions for plotting" + - contents: + - has_keyword("plot") - title: "Developers' corner" desc: "Functions to assist in developing new models" - contents: diff --git a/man/SDMdist.Rd b/man/SDMdist.Rd index 301f306..f6cd3d5 100644 --- a/man/SDMdist.Rd +++ b/man/SDMdist.Rd @@ -90,26 +90,26 @@ probability mass is below the mean of the response distribution. \examples{ # plot the density of the SDM distribution x <- seq(-pi,pi,length.out=10000) -plot(x,dsdm(x,0,2,3),type="l", xlim=c(-pi,pi),ylim=c(0,1), - xlab="Angle error (radians)", - ylab="density", - main="SDM density") -lines(x,dsdm(x,0,9,1),col="red") -lines(x,dsdm(x,0,2,8),col="green") -legend("topright",c("c=2, kappa=3.0, mu=0", - "c=9, kappa=1.0, mu=0", - "c=2, kappa=8, mu=1"), - col=c("black","red","green"),lty=1, cex=0.8) +plot(x, dsdm(x,0,2,3), type="l", xlim = c(-pi,pi), ylim = c(0,1), + xlab = "Angle error (radians)", + ylab = "density", + main = "SDM density") +lines(x, dsdm(x,0,9,1), col="red") +lines(x, dsdm(x,0,2,8), col="green") +legend("topright", c("c=2, kappa=3.0, mu=0", + "c=9, kappa=1.0, mu=0", + "c=2, kappa=8, mu=1"), + col=c("black","red","green"), lty = 1, cex = 0.8) # plot the cumulative distribution function of the SDM distribution p <- psdm(x, mu = 0, c = 3.1, kappa = 5) -plot(x,p,type="l") +plot(x, p, type = "l") # generate random deviates from the SDM distribution and overlay the density r <- rsdm(10000, mu = 0, c = 3.1, kappa = 5) d <- dsdm(x, mu = 0, c = 3.1, kappa = 5) -hist(r, breaks=60, freq=FALSE) -lines(x,d,type="l", col="red") +hist(r, breaks = 60, freq = FALSE) +lines(x, d, type = "l", col = "red") } \references{ diff --git a/man/k2sd.Rd b/man/k2sd.Rd index 0a7f964..a8b0e21 100644 --- a/man/k2sd.Rd +++ b/man/k2sd.Rd @@ -28,8 +28,8 @@ SDs <- k2sd(kappas) SDs_degress <- SDs * 180 / pi # plot the relationship between kappa and circular SD -plot(kappas,SDs) -plot(kappas,SDs_degress) +plot(kappas, SDs) +plot(kappas, SDs_degress) } \keyword{transform} diff --git a/man/plot-distribution.Rd b/man/plot-distribution.Rd new file mode 100644 index 0000000..38100ca --- /dev/null +++ b/man/plot-distribution.Rd @@ -0,0 +1,121 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plot.R +\name{plot-distribution} +\alias{plot-distribution} +\alias{plot.distribution} +\alias{plot.character} +\alias{plot.brmsprior} +\title{Plot a distribution object from the \code{distributional} package} +\usage{ +\method{plot}{distribution}( + x, + ..., + type = "pdf", + labels = NULL, + facets = FALSE, + stat_slab_control = list() +) + +\method{plot}{character}( + x, + ..., + type = "pdf", + labels = NULL, + facets = FALSE, + stat_slab_control = list(), + packages = c("brms", "extraDistr", "bmm", "ggdist") +) + +\method{plot}{brmsprior}( + x, + formula = NULL, + links = NULL, + transform = FALSE, + type = "pdf", + facets = TRUE, + stat_slab_control = list(), + packages = c("brms", "extraDistr", "bmm", "ggdist"), + ... +) +} +\arguments{ +\item{x}{Several options. One of: +\itemize{ +\item An object of class \code{distribution} from the \code{distributional} package +\item A character vector of distribution names that can be parsed by +\code{ggdist::parse_dist()} of the type \code{distname(param1 = value1, param2 = value2)} +\item A \code{brmsprior} object produced by \link[brms:set_prior]{brms::set_prior}, \link[brms:default_prior]{brms::default_prior} or +\code{\link[brms:prior_summary.brmsfit]{brms::prior_summary()}} +}} + +\item{...}{additional distribution objects to plot.} + +\item{type}{The type of plot. One of "pdf" or "cdf".} + +\item{labels}{A character vector of labels for the distributions. If \code{NULL}, +the labels are automatically generated from the distribution objects. If a +character vector is provided, it must be the same length as the number of +distributions being plotted. Default is \code{NULL}.} + +\item{facets}{Logical. If \code{TRUE}, the distributions are plotted in separate +facets. If \code{FALSE}, all distributions are plotted in the same plot. Default +is \code{FALSE}.} + +\item{stat_slab_control}{A list of additional arguments passed to +\code{ggdist::stat_slab()}} + +\item{packages}{A character vector of package names to search for +distributions} + +\item{formula}{A \code{brmsformula} object. Only needed for plotting \code{brmsprior} +objects if transformation of the priors is wanted via the specified links +in the \code{brmsfamily}. The \code{brmsformula} must contain the \code{brmsfamily}} + +\item{links}{A named list of links for the parameters. This is an alternative +way to specify the links for the parameters. Only one of \code{links} or +\code{formula} can be provided.} + +\item{transform}{Logical. If \code{TRUE}, the priors are transformed according to +the inverse of the link functions present either in the \code{brmsformula} or +the \code{links} argument. Default is \code{FALSE}.} +} +\value{ +A ggplot object +} +\description{ +Uses \code{ggplot2} and \code{ggdist} to create a line plot of the probability density +function (pdf) or cumulative distribution function (cdf) of a distribution +object from the \code{distributional} package or a character vector of +distribution names of the type \code{distname(param1 = value1, param2 = value2)} +} +\details{ +These are generic methods for plotting distributions specified as +\itemize{ +\item \code{distribution} objects from the \code{distributional} package +\item character vectors of distribution names that can be parsed by \code{\link[ggdist:parse_dist]{ggdist::parse_dist()}} +\item \code{brmsprior} objects produced by \link[brms:set_prior]{brms::set_prior}, \link[brms:default_prior]{brms::default_prior} or +\code{\link[brms:prior_summary.brmsfit]{brms::prior_summary()}} +} + +For \code{brmsprior} objects, you can also plot the transformations of the priors +to the native scale of the parameters by setting \code{transform = TRUE}. This +will transform the priors according to the inverse of the link functions +present either in the \code{brmsformula} or the \code{links} argument. For example, if +the prior is specified as \code{normal(0, 1)} and the link function is \code{log}, as +it is often the case for parameters that need to be strictly positive, the +prior will be transformed to \code{lognormal(0, 1)} by applying \code{exp()} to the +distribution object. +} +\examples{ +x <- distributional::dist_normal(mean = 0, sd = c(0.5,1,1.2)) +plot(x) + +x2 <- distributional::dist_wrap("sdm", mu = 0, c = c(1:10), kappa = 3) +plot(x2) + +plot('normal(0, 1)', 'normal(0, 2)', type = 'cdf') +plot(c('normal(0, 1)', 'student_t(3,0,2.5)'), facets = T) + +plot(distributional::dist_beta(1:10, 1)) + theme(legend.position = 'right') +} +\keyword{plot} diff --git a/man/stat_slab2.Rd b/man/stat_slab2.Rd new file mode 100644 index 0000000..c35dfeb --- /dev/null +++ b/man/stat_slab2.Rd @@ -0,0 +1,55 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plot.R +\name{stat_slab2} +\alias{stat_slab2} +\title{ggdist::stat_slab() with nice defaults} +\usage{ +stat_slab2( + ..., + subguide = ggdist::subguide_outside(title = "density"), + subscale = if (packageVersion("ggdist") > "3.3.2") { + + ggdist::subscale_thickness(expand = expansion(c(0, 0.05))) + } else { + NULL + }, + normalize = "groups" +) +} +\arguments{ +\item{...}{Additional arguments passed to ggdist::stat_slab()} + +\item{subguide}{Sub-guide used to annotate the \code{thickness} scale. One of: +\itemize{ +\item A function that takes a \code{scale} argument giving a \link[ggplot2:ggplot2-ggproto]{ggplot2::Scale} +object and an \code{orientation} argument giving the orientation of the +geometry and then returns a \link[grid:grid.grob]{grid::grob} that will draw the axis +annotation, such as \code{\link[ggdist:subguide_axis]{subguide_axis()}} (to draw a traditional axis) or +\code{\link[ggdist:subguide_none]{subguide_none()}} (to draw no annotation). See \code{\link[ggdist:subguide_axis]{subguide_axis()}} +for a list of possibilities and examples. +\item A string giving the name of such a function when prefixed +with \code{"subguide"}; e.g. \code{"axis"} or \code{"none"}. +}} + +\item{normalize}{How to normalize heights of functions input to the \code{thickness} aesthetic. One of: +\itemize{ +\item \code{"all"}: normalize so that the maximum height across all data is \code{1}. +\item \code{"panels"}: normalize within panels so that the maximum height in each panel is \code{1}. +\item \code{"xy"}: normalize within the x/y axis opposite the \code{orientation} of this geom so +that the maximum height at each value of the opposite axis is \code{1}. +\item \code{"groups"}: normalize within values of the opposite axis and within each +group so that the maximum height in each group is \code{1}. +\item \code{"none"}: values are taken as is with no normalization (this should probably +only be used with functions whose values are in [0,1], such as CDFs). +} +For a comprehensive discussion and examples of slab scaling and normalization, see the +\href{https://mjskay.github.io/ggdist/articles/thickness.html}{\code{thickness} scale article}.} +} +\value{ +a stat_slab() object +} +\description{ +A wrapper for ggdist::stat_slab() that sets default subguides and scales to +display the density on the y axis. +} +\keyword{plot} diff --git a/man/theme_dist.Rd b/man/theme_dist.Rd new file mode 100644 index 0000000..98d456f --- /dev/null +++ b/man/theme_dist.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plot.R +\name{theme_dist} +\alias{theme_dist} +\title{ggplot theme for plotting distributions} +\usage{ +theme_dist() +} +\value{ +A ggplot theme object +} +\description{ +Based on \code{theme_ggdist()} from the \code{ggdist} package, but with the y-axis +removed and margins adjusted to make space for a custom y axis +} +\keyword{plot} diff --git a/man/vens_options.Rd b/man/vens_options.Rd new file mode 100644 index 0000000..e2eff6b --- /dev/null +++ b/man/vens_options.Rd @@ -0,0 +1,15 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{vens_options} +\alias{vens_options} +\title{Some default settings I liked} +\usage{ +vens_options() +} +\value{ +previous options +} +\description{ +Some default settings I liked +} +\keyword{internal}