Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plot method for default priors #205

Draft
wants to merge 13 commits into
base: develop
Choose a base branch
from
1 change: 1 addition & 0 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ jobs:
remotes
waldo
fansi
extraDistr

- name: Build Cmdstan
run: |
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/pkgdown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ jobs:
remotes
waldo
fansi
extraDistr
any::pkgdown
local::.

Expand Down
10 changes: 7 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ Suggests:
rmarkdown,
bookdown,
tidybayes,
ggplot2,
mixtur,
ggthemes,
cowplot,
stringr,
remotes,
waldo,
fansi,
cmdstanr (>= 0.7.0)
cmdstanr (>= 0.7.0),
extraDistr
Config/testthat/edition: 3
Imports:
magrittr,
Expand All @@ -45,7 +45,11 @@ Imports:
matrixStats,
crayon,
methods,
fs
fs,
ggplot2,
distributional,
ggdist,
rlang
URL: https://github.com/venpopov/bmm, https://venpopov.github.io/bmm/
BugReports: https://github.com/venpopov/bmm/issues
Additional_repositories:
Expand Down
14 changes: 14 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ S3method("[",bmmformula)
S3method("[<-",bmmformula)
S3method(add_links,bmmfit)
S3method(add_links,bmmodel)
S3method(add_links,brmsprior)
S3method(add_mu,brmsprior)
S3method(add_mu,brmssummary)
S3method(bmf2bf,bmmodel)
S3method(check_data,bmmodel)
S3method(check_data,default)
Expand Down Expand Up @@ -33,6 +36,10 @@ S3method(configure_prior,mixture3p)
S3method(default_prior,bmmformula)
S3method(fit_info,brmsfit)
S3method(fit_info,brmsfit_list)
S3method(get_links,bmmodel)
S3method(get_links,brmsformula)
S3method(get_links,default)
S3method(get_links,mvbrmsformula)
S3method(identical,bmmformula)
S3method(identical,brmsformula)
S3method(identical,default)
Expand All @@ -42,6 +49,9 @@ S3method(is_constant,default)
S3method(is_nl,bmmformula)
S3method(is_nl,default)
S3method(model_info,bmmodel)
S3method(plot,brmsprior)
S3method(plot,character)
S3method(plot,distribution)
S3method(postprocess_brm,bmmodel)
S3method(postprocess_brm,default)
S3method(postprocess_brm,sdm)
Expand Down Expand Up @@ -110,9 +120,13 @@ export(sdm)
export(sdmSimple)
export(softmax)
export(softmaxinv)
export(stat_slab2)
export(supported_models)
export(theme_dist)
export(use_model_template)
export(vens_options)
export(wrap)
import(ggplot2)
import(stats)
importFrom(brms,default_prior)
importFrom(brms,restructure)
Expand Down
75 changes: 75 additions & 0 deletions R/brms-misc.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Functions in this file are copies of `brms` internal functions that we are
# not allowed to import. All copyrights belong to the original author

# check if x is a try-error resulting from try()
is_try_error <- function(x) {
inherits(x, "try-error")
Expand Down Expand Up @@ -282,3 +285,75 @@ rename <- function(x, pattern = NULL, replacement = NULL,
}
out
}


# apply a link function
# @param x an array of arbitrary dimension
# @param link character string defining the link
link <- function(x, link) {
switch(link,
identity = x,
log = log(x),
logm1 = brms::logm1(x),
log1p = log1p(x),
inverse = 1 / x,
sqrt = sqrt(x),
"1/mu^2" = 1 / x^2,
tan_half = tan(x / 2),
logit = logit(x),
probit = qnorm(x),
cauchit = qcauchy(x),
cloglog = cloglog(x),
probit_approx = qnorm(x),
softplus = log_expm1(x),
squareplus = (x^2 - 1) / x,
softit = softit(x),
stop2("Link '", link, "' is not supported.")
)
}

# apply an inverse link function
# @param x an array of arbitrary dimension
# @param link a character string defining the link
inv_link <- function(x, link) {
switch(link,
identity = x,
log = exp(x),
logm1 = brms::expp1(x),
log1p = expm1(x),
inverse = 1 / x,
sqrt = x^2,
"1/mu^2" = 1 / sqrt(x),
tan_half = 2 * atan(x),
logit = inv_logit(x),
probit = pnorm(x),
cauchit = pcauchy(x),
cloglog = inv_cloglog(x),
probit_approx = pnorm(x),
softplus = log1p_exp(x),
squareplus = (x + sqrt(x^2 + 4)) / 2,
softit = inv_softit(x),
stop2("Link '", link, "' is not supported.")
)
}


conv_cats_dpars <- function (family) {
is_categorical(family) || is_multinomial(family) || is_simplex(family)
}

is_categorical <- function(family) {
'categorical' %in% family$specials
}

is_multinomial <- function(family) {
'multinomial' %in% family$specials
}

is_simplex <- function(family) {
'simplex' %in% family$specials
}

ulapply <- function (X, FUN, ..., recursive = TRUE, use.names = TRUE) {
unlist(lapply(X, FUN, ...), recursive, use.names)
}
39 changes: 24 additions & 15 deletions R/distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,26 +58,26 @@
#' @examples
#' # plot the density of the SDM distribution
#' x <- seq(-pi,pi,length.out=10000)
#' plot(x,dsdm(x,0,2,3),type="l", xlim=c(-pi,pi),ylim=c(0,1),
#' xlab="Angle error (radians)",
#' ylab="density",
#' main="SDM density")
#' lines(x,dsdm(x,0,9,1),col="red")
#' lines(x,dsdm(x,0,2,8),col="green")
#' legend("topright",c("c=2, kappa=3.0, mu=0",
#' "c=9, kappa=1.0, mu=0",
#' "c=2, kappa=8, mu=1"),
#' col=c("black","red","green"),lty=1, cex=0.8)
#' plot(x, dsdm(x,0,2,3), type="l", xlim = c(-pi,pi), ylim = c(0,1),
#' xlab = "Angle error (radians)",
#' ylab = "density",
#' main = "SDM density")
#' lines(x, dsdm(x,0,9,1), col="red")
#' lines(x, dsdm(x,0,2,8), col="green")
#' legend("topright", c("c=2, kappa=3.0, mu=0",
#' "c=9, kappa=1.0, mu=0",
#' "c=2, kappa=8, mu=1"),
#' col=c("black","red","green"), lty = 1, cex = 0.8)
#'
#' # plot the cumulative distribution function of the SDM distribution
#' p <- psdm(x, mu = 0, c = 3.1, kappa = 5)
#' plot(x,p,type="l")
#' plot(x, p, type = "l")
#'
#' # generate random deviates from the SDM distribution and overlay the density
#' r <- rsdm(10000, mu = 0, c = 3.1, kappa = 5)
#' d <- dsdm(x, mu = 0, c = 3.1, kappa = 5)
#' hist(r, breaks=60, freq=FALSE)
#' lines(x,d,type="l", col="red")
#' hist(r, breaks = 60, freq = FALSE)
#' lines(x, d, type = "l", col = "red")
#'
dsdm <- function(x, mu = 0, c = 3, kappa = 3.5, log = FALSE,
parametrization = "sqrtexp") {
Expand Down Expand Up @@ -116,7 +116,7 @@ psdm <- function(q, mu = 0, c = 3, kappa = 3.5, lower.tail = TRUE, log.p = FALSE

pi <- base::pi
pi2 <- 2 * pi
q <- (q + pi) %% pi2
q <- ifelse(q == pi, pi2, (q + pi) %% pi2)
mu <- (mu + pi) %% pi2
lower.bound <- (lower.bound + pi) %% pi2

Expand Down Expand Up @@ -147,7 +147,16 @@ psdm <- function(q, mu = 0, c = 3, kappa = 3.5, lower.tail = TRUE, log.p = FALSE
#' @rdname SDMdist
#' @export
qsdm <- function(p, mu=0, c=3, kappa=3.5, parametrization = "sqrtexp") {
.NotYetImplemented()
p <- ifelse(near(p,1), p-1e-7, p)

.qsdm <- function(p, mu, c, kappa, parametrization) {
uniroot(function(x) psdm(x, mu = mu, c = c, kappa = kappa,
lower.tail = TRUE, log.p = FALSE,
lower.bound = -pi, parametrization = parametrization) - p,
interval = c(-pi, pi))$root
}
.qsdm_v <- Vectorize(.qsdm)
.qsdm_v(p, mu, c, kappa, parametrization)
}

#' @rdname SDMdist
Expand Down
4 changes: 2 additions & 2 deletions R/helpers-parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#' SDs_degress <- SDs * 180 / pi
#'
#' # plot the relationship between kappa and circular SD
#' plot(kappas,SDs)
#' plot(kappas,SDs_degress)
#' plot(kappas, SDs)
#' plot(kappas, SDs_degress)
#'
k2sd <- function(K) {
S <- matrix(0, 1, length(K))
Expand Down
44 changes: 41 additions & 3 deletions R/helpers-prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,24 @@
#' @export
default_prior.bmmformula <- function(object, data, model, formula = object, ...) {
withr::local_options(bmm.sort_data = FALSE)
dots <- list(...)

formula <- object
model <- check_model(model, data, formula)
data <- check_data(model, data, formula)
formula <- check_formula(model, data, formula)
config_args <- configure_model(model, data, formula)
prior <- configure_prior(model, data, config_args$formula, user_prior = NULL)
prior <- configure_prior(model, data, config_args$formula, user_prior = dots$prior)
dots$prior <- NULL

dots <- list(...)
prior_args <- combine_args(nlist(config_args, dots, prior))
prior_args$object <- prior_args$formula
prior_args$formula <- NULL

brms_priors <- brms::do_call(brms::default_prior, prior_args)

combine_prior(brms_priors, prior_args$prior)
prior <- combine_prior(brms_priors, prior_args$prior)
add_links(prior, model)
}


Expand Down Expand Up @@ -273,3 +275,39 @@ summarise_default_prior <- function(prior_list) {
}
prior_info
}


# preprocess a brmsprior object for plotting
prep_brmsprior <- function(x, formula = NULL, links = NULL) {
stopif(!brms::is.brmsprior(x), "x must be a brmsprior object")
if (!is.null(formula)) {
stopif(!is.null(links), "Either formula or links should be provided")
stopif(!is_brmsformula(formula), "formula must be a brmsformula object")
stopif(is.null(formula$family), "formula must contain a brmsfamily object")
link_source <- formula
} else {
stopif(!is.null(links) && (!is.list(links) || is.null(names(links))),
"links must be a named list")
link_source <- links
}

priors <- x[x$prior != "",]
priors <- add_mu(priors)
priors <- priors[!grepl("constant", priors$prior),]

# rename priors to match distribution functions
priors$prior <- gsub("logistic", "logis", priors$prior)
priors$prior <- gsub("inv_gamma", "invgamma", priors$prior)

# get the parameter name regardless of type
priors$par <- ifelse(priors$nlpar != "", priors$nlpar, priors$dpar)

# remove corr parameters as it's too complicated to visualize them
priors <- priors[priors$class != "cor",]

# add a column with the corresponding parameter link
if (!("link" %in% names(x))) {
priors <- add_links(priors, object = link_source)
}
priors
}
39 changes: 39 additions & 0 deletions R/math.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
logit <- function(p) {
log(p) - log1p(-p)
}

inv_logit <- function(x) {
1 / (1 + exp(-1*x))
}

cloglog <- function (x) {
log(-log1p(-x))
}

inv_cloglog <- function(x) {
1 - exp(-1*exp(x))
}

log_expm1 <- function (x) {
out <- log(expm1(x))
ifelse(out < Inf, out, x)
}

log1p_exp <- function (x) {
out <- log1p(exp(x))
ifelse(out < Inf, out, x)
}

log1p_exp_custom <- function(x, mu = 0, a = 0, b = 1) {
out <- log1p(a + exp(b * (x - mu)))/b
ifelse(out < Inf, out, x)
}

softit <- function(x) {
log_expm1(x/(1-x))
}

inv_softit <- function(x) {
y <- log1p_exp(x)
y/(1 + y)
}
Loading
Loading