Skip to content

Commit

Permalink
Merge pull request #206 from venpopov/bugfix/default-prior-interactio…
Browse files Browse the repository at this point in the history
…n-re

Bugfix for default priors when only an interaction is specified
  • Loading branch information
GidonFrischkorn authored Apr 9, 2024
2 parents 02cba98 + 18cae89 commit 174511f
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 28 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
### New features
* you can now specify to save the **bmmfit** object generated by **bmm()** to a file with the **file** argument, similarly to **brms::brm()** (#190)

### Bug fixes
* fix incorrect specification of default priors when only an interaction is specified (#201)

# bmm 0.5.1

### Bug fixes
Expand Down
14 changes: 7 additions & 7 deletions R/helpers-prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,30 +123,30 @@ set_default_prior <- function(model, data, formula) {

prior <- brms::empty_prior()
bterms <- brms::brmsterms(formula)
dpars <- names(bterms$dpars)
bterms$allpars <- c(bterms$dpars, bterms$nlpars)
nlpars <- names(bterms$nlpars)
pars <- c(dpars, nlpars)
pars <- names(bterms$allpars)

pars_key <- names(default_priors)
pars <- pars[pars %in% pars_key]

for (par in pars) {
bform <- formula$pforms[[par]]
bform <- bterms$allpars[[par]]$fe
terms <- stats::terms(bform)
prior_desc <- default_priors[[par]]
has_effects_prior <- !is.null(prior_desc$effects)

all_rhs_names <- rhs_vars(bform)

all_rhs_terms <- attr(terms, "term.labels")
fixef <- all_rhs_terms[all_rhs_terms %in% all_rhs_names]
inter <- all_rhs_terms[attr(bterms, "order") > 1]
fixef <- all_rhs_terms[attr(terms, "order") == 1]
inter <- all_rhs_terms[attr(terms, "order") > 1]
nfixef <- length(fixef)
ninter <- length(inter)
interaction_only <- nfixef == 0 && ninter > 0

# if the user has specified a non-linear predictor on a model parameter, do
# not set prior
if (any(all_rhs_names %in% pars)) {
if (any(all_rhs_terms %in% pars)) {
next
}

Expand Down
199 changes: 199 additions & 0 deletions tests/testthat/test-default-priors.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@


test_that('default priors are set correctly with fixed effects only', {
data <- oberauer_lin_2017
model <- mixture2p('dev_rad')

# Intercept only
formula <- bmf(kappa ~ 1, thetat ~ 1)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "Intercept", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("",""))

# 1 fixed effect + intercept
formula <- bmf(kappa ~ set_size, thetat ~ set_size)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("","normal(0, 1)"))
expect_equal(pr[pr$coef == "Intercept", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_true(all(grepl("constant", pr[pr$dpar %in% c('mu1','mu2','kappa2'),]$prior)))

# 1 fixed effect intercept suppressed
formula <- bmf(kappa ~ 0 + set_size, thetat ~ 0 + set_size)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(pr[pr$coef == "Intercept", ]$prior, character(0))
expect_true(all(grepl("constant", pr[pr$dpar %in% c('mu1','mu2','kappa2'),]$prior)))

# 2 fixed effects + intercept
formula <- bmf(kappa ~ set_size + session, thetat ~ set_size + session)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "Intercept", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("","normal(0, 1)"))

# 2 fixed effects + intercept suppressed
formula <- bmf(kappa ~ 0 + set_size + session, thetat ~ 0 + set_size + session)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "Intercept", ]$prior, character(0))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("", "normal(0, 1)"))
expect_equal(pr[pr$coef == "set_size1" & pr$class == "b", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(pr[pr$coef == "session2" & pr$class == "b", ]$prior, c("",""))

# 2 fixed effects + interaction + intercept
formula <- bmf(kappa ~ set_size * session, thetat ~ set_size * session)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "Intercept", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("","normal(0, 1)"))

# 2 fixed effects + interaction + intercept suppressed
formula <- bmf(kappa ~ 0 + set_size * session, thetat ~ 0 + set_size * session)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "Intercept", ]$prior, character(0))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("", "normal(0, 1)"))
expect_equal(pr[pr$coef == "set_size1" & pr$class == "b", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(pr[pr$coef == "session2" & pr$class == "b", ]$prior, c("",""))
expect_equal(pr[pr$coef == "set_size2:session2" & pr$class == "b", ]$prior, c("",""))

# interaction only between 2 fixed effects
formula <- bmf(kappa ~ 0 + set_size:session, thetat ~ 0 + set_size:session)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "Intercept", ]$prior, character(0))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
})


test_that('default priors are set correctly with random effects', {
data <- oberauer_lin_2017
model <- mixture2p('dev_rad')

# Intercept only
formula <- bmf(kappa ~ 1 + (1|ID), thetat ~ 1 + (1|ID))
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "Intercept" & pr$class == "b", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("",""))
expect_equal(unique(pr[pr$class == "sd", ]$prior), c("student_t(3, 0, 2.5)", ""))

# 1 fixed effect + intercept
formula <- bmf(kappa ~ set_size + (1|ID), thetat ~ set_size + (1|ID))
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("","normal(0, 1)"))
expect_equal(pr[pr$coef == "Intercept" & pr$class == "b", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_true(all(grepl("constant", pr[pr$dpar %in% c('mu1','mu2','kappa2'),]$prior)))
expect_equal(unique(pr[pr$class == "sd", ]$prior), c("student_t(3, 0, 2.5)", ""))

# 1 fixed effect intercept suppressed
formula <- bmf(kappa ~ 0 + set_size + (1|ID), thetat ~ 0 + set_size + (1|ID))
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(pr[pr$coef == "Intercept" & pr$class == "b", ]$prior, character(0))
expect_true(all(grepl("constant", pr[pr$dpar %in% c('mu1','mu2','kappa2'),]$prior)))
expect_equal(unique(pr[pr$class == "sd", ]$prior), c("student_t(3, 0, 2.5)", ""))

# 2 fixed effects + intercept
formula <- bmf(kappa ~ set_size + session + (1|ID), thetat ~ set_size + session + (1|ID))
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "Intercept" & pr$class == "b", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("","normal(0, 1)"))
expect_equal(unique(pr[pr$class == "sd", ]$prior), c("student_t(3, 0, 2.5)", ""))

# 2 fixed effects + intercept suppressed
formula <- bmf(kappa ~ 0 + set_size + session + (1|ID), thetat ~ 0 + set_size + session + (1|ID))
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "Intercept" & pr$class == "b", ]$prior, character(0))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("", "normal(0, 1)"))
expect_equal(pr[pr$coef == "set_size1" & pr$class == "b", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(pr[pr$coef == "session2" & pr$class == "b", ]$prior, c("",""))
expect_equal(unique(pr[pr$class == "sd", ]$prior), c("student_t(3, 0, 2.5)", ""))

# 2 fixed effects + interaction + intercept
formula <- bmf(kappa ~ set_size * session + (1|ID), thetat ~ set_size * session + (1|ID))
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "Intercept" & pr$class == "b", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("","normal(0, 1)"))
expect_equal(unique(pr[pr$class == "sd", ]$prior), c("student_t(3, 0, 2.5)", ""))

# 2 fixed effects + interaction + intercept suppressed
formula <- bmf(kappa ~ 0 + set_size * session + (1|ID), thetat ~ 0 + set_size * session + (1|ID))
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "Intercept" & pr$class == "b", ]$prior, character(0))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("", "normal(0, 1)"))
expect_equal(pr[pr$coef == "set_size1" & pr$class == "b", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(pr[pr$coef == "session2" & pr$class == "b", ]$prior, c("",""))
expect_equal(pr[pr$coef == "set_size2:session2" & pr$class == "b", ]$prior, c("",""))
expect_equal(unique(pr[pr$class == "sd", ]$prior), c("student_t(3, 0, 2.5)", ""))

# interaction only between 2 fixed effects
formula <- bmf(kappa ~ 0 + set_size:session + (1|ID), thetat ~ 0 + set_size:session + (1|ID))
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "Intercept" & pr$class == "b", ]$prior, character(0))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
expect_equal(unique(pr[pr$class == "sd", ]$prior), c("student_t(3, 0, 2.5)", ""))
})


test_that('default priors are set correctly with fixed effects only and sdm model', {
data <- oberauer_lin_2017
model <- sdm('dev_rad')

# Intercept only
formula <- bmf(kappa ~ 1, c ~ 1)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$class == "Intercept", ]$prior, c("student_t(5, 2, 0.75)", "student_t(5, 1.75, 0.75)", "constant(0)"))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, character(0))

# 1 fixed effect + intercept
formula <- bmf(kappa ~ set_size, c ~ set_size)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("normal(0, 1)","normal(0, 1)"))
expect_equal(pr[pr$class == "Intercept", ]$prior, c("student_t(5, 2, 0.75)", "student_t(5, 1.75, 0.75)", "constant(0)"))

# 1 fixed effect intercept suppressed
formula <- bmf(kappa ~ 0 + set_size, c ~ 0 + set_size)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("student_t(5, 2, 0.75)", "student_t(5, 1.75, 0.75)"))
expect_equal(pr[pr$class == "Intercept", ]$prior, c("constant(0)"))

# 2 fixed effects + intercept
formula <- bmf(kappa ~ set_size + session, c ~ set_size + session)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$class == "Intercept", ]$prior, c("student_t(5, 2, 0.75)", "student_t(5, 1.75, 0.75)", "constant(0)"))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("normal(0, 1)","normal(0, 1)"))

# 2 fixed effects + intercept suppressed
formula <- bmf(kappa ~ 0 + set_size + session, c ~ 0 + set_size + session)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$class == "Intercept", ]$prior, "constant(0)")
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("normal(0, 1)", "normal(0, 1)"))
expect_equal(pr[pr$coef == "set_size1" & pr$class == "b", ]$prior, c("student_t(5, 2, 0.75)", "student_t(5, 1.75, 0.75)"))
expect_equal(pr[pr$coef == "session2" & pr$class == "b", ]$prior, c("",""))

# 2 fixed effects + interaction + intercept
formula <- bmf(kappa ~ set_size * session, c ~ set_size * session)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$class == "Intercept", ]$prior, c("student_t(5, 2, 0.75)", "student_t(5, 1.75, 0.75)", "constant(0)"))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("normal(0, 1)","normal(0, 1)"))

# 2 fixed effects + interaction + intercept suppressed
formula <- bmf(kappa ~ 0 + set_size * session, c ~ 0 + set_size * session)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$class == "Intercept", ]$prior, c("constant(0)"))
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("normal(0, 1)", "normal(0, 1)"))
expect_equal(pr[pr$coef == "set_size1" & pr$class == "b", ]$prior, c("student_t(5, 2, 0.75)", "student_t(5, 1.75, 0.75)"))
expect_equal(pr[pr$coef == "session2" & pr$class == "b", ]$prior, c("",""))
expect_equal(pr[pr$coef == "set_size2:session2" & pr$class == "b", ]$prior, c("",""))

# interaction only between 2 fixed effects
formula <- bmf(kappa ~ 0 + set_size:session, c ~ 0 + set_size:session)
pr <- default_prior(formula, data, model)
expect_equal(pr[pr$class == "Intercept", ]$prior, "constant(0)")
expect_equal(pr[pr$coef == "" & pr$class == "b", ]$prior, c("student_t(5, 2, 0.75)", "student_t(5, 1.75, 0.75)"))
})


test_that("default priors work when there are no fixed parameters", {
formula <- bmf(mu ~ 1,
c ~ 1,
kappa ~ 1)

pr <- default_prior(formula, oberauer_lin_2017, sdm('dev_rad'))
expect_s3_class(pr, 'brmsprior')
})
21 changes: 0 additions & 21 deletions tests/testthat/test-helpers-prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,6 @@ test_that("in combine prior, prior2 overwrites only shared components with prior
expect_equal(dplyr::filter(prior, dpar == "kappa"), dplyr::filter(prior2, dpar == "kappa"))
})


test_that("default priors are returned correctly", {
dp <- default_prior(bmf(kappa ~ set_size, thetat ~ set_size),
oberauer_lin_2017,
mixture2p('dev_rad'))
expect_equal(dp[dp$coef == "" & dp$class == "b", ]$prior, c("","normal(0, 1)"))
expect_equal(dp[dp$coef == "Intercept", ]$prior, c("normal(2, 1)", "logistic(0, 1)"))
})

test_that("no check for sort_data with default_priors function", {
withr::local_options('bmm.sort_data' = 'check')
res <- capture_messages(default_prior(bmf(kappa ~ set_size, c ~ set_size),
Expand All @@ -62,16 +53,4 @@ test_that("no check for sort_data with default_priors function", {
})


test_that("default priors work when there are no fixed parameters", {
formula <- bmf(mu ~ 1,
c ~ 1,
kappa ~ 1)
if (utils::packageVersion("brms") >= "2.20.14") {
prior_fn <- default_prior
} else {
prior_fn <- get_model_prior
}

pr <- prior_fn(formula, oberauer_lin_2017, sdm('dev_rad'))
expect_s3_class(pr, 'brmsprior')
})

0 comments on commit 174511f

Please sign in to comment.