Skip to content

Commit

Permalink
plot prior methods
Browse files Browse the repository at this point in the history
  • Loading branch information
venpopov committed Apr 6, 2024
1 parent bdbc6eb commit 21e67ad
Show file tree
Hide file tree
Showing 11 changed files with 198 additions and 113 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ Suggests:
remotes,
waldo,
fansi,
cmdstanr (>= 0.7.0)
cmdstanr (>= 0.7.0),
extraDistr
Config/testthat/edition: 3
Imports:
magrittr,
Expand All @@ -48,7 +49,7 @@ Imports:
ggplot2,
distributional,
ggdist,
extraDistr
rlang
URL: https://github.com/venpopov/bmm, https://venpopov.github.io/bmm/
BugReports: https://github.com/venpopov/bmm/issues
Additional_repositories:
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ S3method(fit_info,brmsfit)
S3method(fit_info,brmsfit_list)
S3method(get_links,bmmodel)
S3method(get_links,brmsformula)
S3method(get_links,default)
S3method(get_links,mvbrmsformula)
S3method(identical,bmmformula)
S3method(identical,brmsformula)
Expand Down
26 changes: 13 additions & 13 deletions R/distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,26 +58,26 @@
#' @examples
#' # plot the density of the SDM distribution
#' x <- seq(-pi,pi,length.out=10000)
#' plot(x,dsdm(x,0,2,3),type="l", xlim=c(-pi,pi),ylim=c(0,1),
#' xlab="Angle error (radians)",
#' ylab="density",
#' main="SDM density")
#' lines(x,dsdm(x,0,9,1),col="red")
#' lines(x,dsdm(x,0,2,8),col="green")
#' legend("topright",c("c=2, kappa=3.0, mu=0",
#' "c=9, kappa=1.0, mu=0",
#' "c=2, kappa=8, mu=1"),
#' col=c("black","red","green"),lty=1, cex=0.8)
#' plot(x, dsdm(x,0,2,3), type="l", xlim = c(-pi,pi), ylim = c(0,1),
#' xlab = "Angle error (radians)",
#' ylab = "density",
#' main = "SDM density")
#' lines(x, dsdm(x,0,9,1), col="red")
#' lines(x, dsdm(x,0,2,8), col="green")
#' legend("topright", c("c=2, kappa=3.0, mu=0",
#' "c=9, kappa=1.0, mu=0",
#' "c=2, kappa=8, mu=1"),
#' col=c("black","red","green"), lty = 1, cex = 0.8)
#'
#' # plot the cumulative distribution function of the SDM distribution
#' p <- psdm(x, mu = 0, c = 3.1, kappa = 5)
#' plot(x,p,type="l")
#' plot(x, p, type = "l")
#'
#' # generate random deviates from the SDM distribution and overlay the density
#' r <- rsdm(10000, mu = 0, c = 3.1, kappa = 5)
#' d <- dsdm(x, mu = 0, c = 3.1, kappa = 5)
#' hist(r, breaks=60, freq=FALSE)
#' lines(x,d,type="l", col="red")
#' hist(r, breaks = 60, freq = FALSE)
#' lines(x, d, type = "l", col = "red")
#'
dsdm <- function(x, mu = 0, c = 3, kappa = 3.5, log = FALSE,
parametrization = "sqrtexp") {
Expand Down
4 changes: 2 additions & 2 deletions R/helpers-parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#' SDs_degress <- SDs * 180 / pi
#'
#' # plot the relationship between kappa and circular SD
#' plot(kappas,SDs)
#' plot(kappas,SDs_degress)
#' plot(kappas, SDs)
#' plot(kappas, SDs_degress)
#'
k2sd <- function(K) {
S <- matrix(0, 1, length(K))
Expand Down
26 changes: 19 additions & 7 deletions R/helpers-prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,21 @@ summarise_default_prior <- function(prior_list) {


# preprocess a brmsprior object for plotting
prep_brmsprior <- function(x, formula = NULL) {
prep_brmsprior <- function(x, formula = NULL, links = NULL) {
stopif(!brms::is.brmsprior(x), "x must be a brmsprior object")
# remove empty priors
priors <- x[x$prior != "",]
if (!is.null(formula)) {
stopif(!is.null(links), "Either formula or links should be provided")
stopif(!is_brmsformula(formula), "formula must be a brmsformula object")
stopif(is.null(formula$family), "formula must contain a brmsfamily object")
link_source <- formula
} else {
stopif(!is.null(links) && (!is.list(links) || is.null(names(links))),
"links must be a named list")
link_source <- links
}

# add mu to dpar if dpar is empty
priors <- x[x$prior != "",]
priors <- add_mu(priors)

# remove constant priors
priors <- priors[!grepl("constant", priors$prior),]

# rename priors to match distribution functions
Expand All @@ -297,5 +303,11 @@ prep_brmsprior <- function(x, formula = NULL) {
priors$par <- ifelse(priors$nlpar != "", priors$nlpar, priors$dpar)

# remove corr parameters as it's too complicated to visualize them
priors[priors$class != "cor",]
priors <- priors[priors$class != "cor",]

# add a column with the corresponding parameter link
if (!("link" %in% names(x))) {
priors <- add_links(priors, object = link_source)
}
priors
}
8 changes: 4 additions & 4 deletions R/math.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ logit <- function(p) {
}

inv_logit <- function(x) {
exp(x) / (1 + exp(x))
1 / (1 + exp(-1*x))
}

cloglog <- function (x) {
log(-log1p(-x))
}

inv_cloglog <- function(x) {
1 - exp(-exp(x))
1 - exp(-1*exp(x))
}

log_expm1 <- function (x) {
Expand All @@ -29,11 +29,11 @@ log1p_exp_custom <- function(x, mu = 0, a = 0, b = 1) {
ifelse(out < Inf, out, x)
}

softfit <- function(x) {
softit <- function(x) {
log_expm1(x/(1-x))
}

inv_softfit <- function(x) {
inv_softit <- function(x) {
y <- log1p_exp(x)
y/(1 + y)
}
118 changes: 78 additions & 40 deletions R/plot.R
Original file line number Diff line number Diff line change
@@ -1,24 +1,3 @@


#' @import ggplot2
#' @rdname plot-distribution
#' @export
plot.brmsprior <- function(x, formula = NULL, type = 'pdf', facets = TRUE,
stat_slab_control = list(),
packages = c("brms","extraDistr","bmm"), ...) {
prior <- prep_brmsprior(x, formula = formula)
prior <- ggdist::parse_dist(prior, package = pkg_search_env(packages))
prior$labels <- glue("{prior$resp}_{prior$class}_{prior$par}_{prior$coef}_{prior$group}")
prior$labels <- gsub("_+", "_", prior$labels)
prior$labels <- gsub("(^_|_$)", "", prior$labels)
prior$labels <- paste0(prior$labels, " ~ ", format(prior$.dist_obj))
prior <- prior[order(prior$par, prior$class),]
prior$labels <- factor(prior$labels, levels = unique(prior$labels))
plot(prior$.dist_obj, type = type, facets = facets, labels = prior$labels,
stat_slab_control = stat_slab_control)
}


#' Plot a distribution object from the `distributional` package
#'
#' Uses `ggplot2` and `ggdist` to create a line plot of the probability density
Expand All @@ -31,26 +10,47 @@ plot.brmsprior <- function(x, formula = NULL, type = 'pdf', facets = TRUE,
#' @param x Several options. One of:
#' - An object of class `distribution` from the `distributional` package
#' - A character vector of distribution names that can be parsed by
#' `ggdist::parse_dist()` of the type `distname(param1 = value1, param2 = value2)`
#' `ggdist::parse_dist()` of the type `distname(param1 = value1, param2 =
#' value2)`
#' - A `brmsprior` object produced by [brms::set_prior], [brms::default_prior] or
#' [brms::prior_summary()]
#' [brms::prior_summary()]
#' @param ... additional distribution objects to plot.
#' @param formula A `brmsformula` object. Only needed for plotting `brmsprior`
#' objects if transformation of the priors is wanted via the specified links in
#' the `brmsfamily`. The `brmsformula` must contain the `brmsfamily`
#' objects if transformation of the priors is wanted via the specified links
#' in the `brmsfamily`. The `brmsformula` must contain the `brmsfamily`
#' @param links A named list of links for the parameters. This is an alternative
#' way to specify the links for the parameters. Only one of `links` or
#' `formula` can be provided.
#' @param transform Logical. If `TRUE`, the priors are transformed according to
#' the inverse of the link functions present either in the `brmsformula` or
#' the `links` argument. Default is `FALSE`.
#' @param type The type of plot. One of "pdf" or "cdf".
#' @param labels A character vector of labels for the distributions. If `NULL`,
#' the labels are automatically generated from the distribution objects. If a
#' character vector is provided, it must be the same length as the number of
#' distributions being plotted. Default is `NULL`.
#' the labels are automatically generated from the distribution objects. If a
#' character vector is provided, it must be the same length as the number of
#' distributions being plotted. Default is `NULL`.
#' @param facets Logical. If `TRUE`, the distributions are plotted in separate
#' facets. If `FALSE`, all distributions are plotted in the same plot. Default is
#' `FALSE`.
#' @param stat_slab_control A list of additional arguments passed to `ggdist::stat_slab()`
#' @param packages A character vector of package names to search for distributions
#' facets. If `FALSE`, all distributions are plotted in the same plot. Default
#' is `FALSE`.
#' @param stat_slab_control A list of additional arguments passed to
#' `ggdist::stat_slab()`
#' @param packages A character vector of package names to search for
#' distributions
#'
#' @details These are generic methods for plotting distributions specified as
#' - `distribution` objects from the `distributional` package
#' - character vectors of distribution names that can be parsed by [ggdist::parse_dist()]
#' - `brmsprior` objects produced by [brms::set_prior], [brms::default_prior] or
#' [brms::prior_summary()]
#'
#' @details By default all distributions are plotted as different lines in the
#' same plot. If you want to plot them in separate facets, set `facets = TRUE`.
#' For `brmsprior` objects, you can also plot the transformations of the priors
#' to the native scale of the parameters by setting `transform = TRUE`. This
#' will transform the priors according to the inverse of the link functions
#' present either in the `brmsformula` or the `links` argument. For example, if
#' the prior is specified as `normal(0, 1)` and the link function is `log`, as
#' it is often the case for parameters that need to be strictly positive, the
#' prior will be transformed to `lognormal(0, 1)` by applying `exp()` to the
#' distribution object.
#'
#' @return A ggplot object
#' @examples
Expand All @@ -61,13 +61,16 @@ plot.brmsprior <- function(x, formula = NULL, type = 'pdf', facets = TRUE,
#' plot(x2)
#'
#' plot('normal(0, 1)', 'normal(0, 2)', type = 'cdf')
#' plot(c('normal(0, 1)', 'student_t(3,0,2.5))', facets = T)
#' plot(c('normal(0, 1)', 'student_t(3,0,2.5)'), facets = T)
#'
#' plot(dist_beta(1:10, 1)) + theme(legend.position = 'right')
#' plot(distributional::dist_beta(1:10, 1)) + theme(legend.position = 'right')
#' @keywords plot
#' @export
plot.distribution <- function(x, ..., type = 'pdf', labels = NULL, facets = FALSE,
stat_slab_control = list()) {
if (!requireNamespace('extraDistr')) {
stop2("The 'extraDistr' package is required for plotting distributions")
}
dots <- list(...)
if (length(dots) > 0) {
stopif(any(!sapply(dots, inherits, what = "distribution")),
Expand All @@ -90,7 +93,10 @@ plot.distribution <- function(x, ..., type = 'pdf', labels = NULL, facets = FALS
stat_slab_args$fill <- stat_slab_args$fill %||% NA
stat_slab_args$subguide <- stat_slab_args$subguide %||%
ggdist::subguide_outside(title = "density")
stat_slab_args$scale <- stat_slab_args$scales %||% 1
stat_slab_args$subscale <- stat_slab_args$subscale %||%
ggdist::subscale_thickness(expand = expansion(c(0, 0.1)))
stat_slab_args$scale <- stat_slab_args$scale %||% 1
stat_slab_args$p_limits <- stat_slab_args$p_limits %||% c(0.001, 0.999)
if (facets) {
stat_slab_args$color <- stat_slab_args$color %||% "grey30"
stat_slab_args$normalize <- stat_slab_args$normalize %||% "groups"
Expand All @@ -114,7 +120,7 @@ plot.distribution <- function(x, ..., type = 'pdf', labels = NULL, facets = FALS
#' @export
plot.character <- function(x, ..., type = 'pdf', labels = NULL, facets = FALSE,
stat_slab_control = list(),
packages = c("ggdist","brms","extraDistr","bmm")) {
packages = c("brms","extraDistr","bmm","ggdist")) {
search_env <- pkg_search_env(packages)
dots <- list(...)
dists <- unlist(c(x, dots))
Expand All @@ -124,11 +130,42 @@ plot.character <- function(x, ..., type = 'pdf', labels = NULL, facets = FALSE,
stat_slab_control = stat_slab_control)
}


#' @import ggplot2
#' @rdname plot-distribution
#' @export
plot.brmsprior <- function(x, formula = NULL, links = NULL, transform = FALSE,
type = 'pdf', facets = TRUE, stat_slab_control = list(),
packages = c("brms","extraDistr","bmm","ggdist"), ...) {
prior <- prep_brmsprior(x, formula = formula, links = links)
prior <- ggdist::parse_dist(prior, package = pkg_search_env(packages))

if (transform) {
stopif(is.null(formula) && is.null(links) && !("link" %in% names(prior)),
glue("If transform = TRUE, either formula or links must be provided, \\
or there must be a 'link' column in the brmsprior object"))
for (i in 1:nrow(prior)) {
prior$.dist_obj[i] <- inv_link(prior$.dist_obj[[i]], prior$link[[i]])
}
}


prior$labels <- glue("{prior$resp}_{prior$class}_{prior$par}_{prior$coef}_{prior$group}")
prior$labels <- gsub("_+", "_", prior$labels)
prior$labels <- gsub("(^_|_$)", "", prior$labels)
prior$labels <- paste0(prior$labels, " ~ ", format(prior$.dist_obj))
prior <- prior[order(prior$par, prior$class),]
prior$labels <- factor(prior$labels, levels = unique(prior$labels))
plot(prior$.dist_obj, type = type, facets = facets, labels = prior$labels,
stat_slab_control = stat_slab_control)
}


pkg_search_env <- function(packages) {
stopif(any(!is.character(packages)))
search_env <- lapply(packages, function(x) asNamespace(x))
search_env <- c(search_env, globalenv())
search_env <- globalenv()
search_env <- c(search_env, rlang::search_envs())
search_env <- c(search_env, lapply(packages, function(x) asNamespace(x)))
search_env <- unname(search_env)
as.environment(do.call(c, lapply(search_env, as.list)))
}
Expand Down Expand Up @@ -168,3 +205,4 @@ theme_dist <- function() {
axis.ticks.y = element_blank(),
axis.title.y = element_blank())
}

11 changes: 9 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ add_links.brmsprior <- function(x, object, family, ...) {
links <- get_links(object, family)
links <- unlist(links)
info <- strsplit(names(links), ".", fixed = TRUE)
if (is(formula, 'mvbrmsformula')) {
if (length(info[[1]]) == 2) {
resp <- ulapply(info, function(x) x[1])
dpar <- ulapply(info, function(x) x[2])
} else {
Expand All @@ -776,6 +776,8 @@ add_links.brmsprior <- function(x, object, family, ...) {
in_class <- x$class %in% dpar & x$dpar == ""
dpar_old <- x$dpar
x$dpar <- ifelse(in_class, x$class, x$dpar)
#

x <- suppressMessages(dplyr::left_join(x, links_df))
x$dpar <- dpar_old
x$link[is.na(x$link)] <- "identity"
Expand All @@ -788,11 +790,16 @@ get_links <- function(x, ...) {
UseMethod("get_links")
}

#' @export
get_links.default <- function(x, ...) {
x
}

#' @export
get_links.brmsformula <- function(x, ...) {
dots <- list(...)
if (!is.null(dots$family)) {
x <- brmsformula(x, family = dots$family)
x <- brms::brmsformula(x, family = dots$family)
}
x <- brms::brmsterms(x)
dpars <- x$family$dpars
Expand Down
Loading

0 comments on commit 21e67ad

Please sign in to comment.