
Blackjax sampler suffers divergences with TruncatedNormal likelihood #775

Open
noahg2 opened this issue Jan 31, 2025 · 0 comments

noahg2 commented Jan 31, 2025

Describe the issue as clearly as possible:

It seems that PyMC's BlackJAX sampler (`nuts_sampler="blackjax"`) struggles with models that have a truncated likelihood: it often fails to converge and reports many divergences. Based on this initial discussion, the problem is likely in the sampler rather than the model, because other samplers complete successfully on the same model and there don't appear to be any geometry issues at play.

Screenshot of erroneous output: (image not available)

Steps/code to reproduce the bug:

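```python
import pymc as pm

N_OBSERVATIONS = 50

with pm.Model() as model:
    mu = pm.Normal("mu")                       # mu ~ Normal(0, 1)
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    # Likelihood truncated to [-10, 10]
    y = pm.TruncatedNormal("y", mu=mu, sigma=sigma, lower=-10, upper=10, size=(N_OBSERVATIONS,))
    # Simulate synthetic datasets from the prior predictive
    prior_trace = pm.sample_prior_predictive(random_seed=100)

# Use the first prior-predictive draw as the observed data
data = prior_trace.prior.y.isel(chain=0, draw=0)
with pm.observe(model, {y: data}):
    # Fit with BlackJAX NUTS; this is where the divergences appear
    idata = pm.sample(nuts_sampler="blackjax")
    idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True)
```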
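
To check the claim that the model's geometry is benign independently of any sampler, the log posterior can be written directly in JAX and its value and gradient inspected. This is a hypothetical re-implementation, not PyMC's or BlackJAX's code. It assumes the same priors and truncation bounds as the reproduction above, a log transform for `sigma` (PyMC's default for `HalfNormal`), and double precision to match PyTensor's default `floatX`. The names `truncnorm_logpdf`, `log_posterior` and `logp_and_grad` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr
from jax.scipy.stats import norm

# Double precision, matching PyTensor's default floatX (float64).
jax.config.update("jax_enable_x64", True)

LOWER, UPPER = -10.0, 10.0  # truncation bounds used in the reproduction


def truncnorm_logpdf(x, mu, sigma, lower=LOWER, upper=UPPER):
    """Log-density of Normal(mu, sigma) truncated to [lower, upper] (hypothetical helper)."""
    x = jnp.asarray(x)
    a = (lower - mu) / sigma
    b = (upper - mu) / sigma
    # log Z, where Z is the mass of the untruncated normal inside [lower, upper].
    log_z = jnp.log(ndtr(b) - ndtr(a))
    logp = norm.logpdf(x, loc=mu, scale=sigma) - log_z
    # Zero density (log-density -inf) outside the support.
    return jnp.where((x >= lower) & (x <= upper), logp, -jnp.inf)


def log_posterior(params, y):
    """Unnormalised log posterior on the unconstrained space (mu, log_sigma)."""
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    lp = norm.logpdf(mu, loc=0.0, scale=1.0)                     # mu ~ Normal(0, 1)
    lp += jnp.log(2.0) + norm.logpdf(sigma, loc=0.0, scale=0.5)  # sigma ~ HalfNormal(0.5)
    lp += log_sigma                                               # log|d sigma / d log_sigma|
    return lp + jnp.sum(truncnorm_logpdf(y, mu, sigma))


# Gradient with respect to params (argument 0), compiled with jit.
logp_and_grad = jax.jit(jax.value_and_grad(log_posterior))

# Illustrative data only; substitute the observed `data` from the reproduction.
y_demo = jnp.linspace(-1.0, 1.0, 50)
val, grad = logp_and_grad(jnp.array([0.0, jnp.log(0.5)]), y_demo)
print(val, grad)
```

Finite, smoothly varying values and gradients across plausible `(mu, log_sigma)` values would be consistent with the issue's reading that the density itself is well behaved.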

Expected result:

The sampler should complete with no divergences and produce posteriors similar to those from PyMC's default NUTS sampler.
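
One way to make the expected result checkable is to count divergences in the BlackJAX run and compare posterior means against a run with PyMC's default sampler. The helpers below are a hypothetical sketch on plain NumPy arrays. They assume the `diverging` field that PyMC writes to `idata.sample_stats`, and the 0.1-standard-deviation tolerance is an arbitrary illustrative choice, not an established criterion.

```python
import numpy as np


def divergence_summary(diverging):
    """Total and per-chain divergence counts from a (chain, draw) boolean array,
    such as np.asarray(idata.sample_stats["diverging"])."""
    flags = np.asarray(diverging, dtype=bool)
    per_chain = flags.sum(axis=1)
    return {"total": int(per_chain.sum()), "per_chain": per_chain.tolist()}


def posterior_means_agree(draws_a, draws_b, tol=0.1):
    """Heuristic: are the means of two (chain, draw) sample arrays for one scalar
    parameter within `tol` pooled posterior standard deviations of each other?"""
    a = np.asarray(draws_a, dtype=float).ravel()
    b = np.asarray(draws_b, dtype=float).ravel()
    pooled_sd = np.sqrt(0.5 * (a.var(ddof=1) + b.var(ddof=1)))
    return bool(abs(a.mean() - b.mean()) <= tol * pooled_sd)
```

With the reproduction above, one would pass the BlackJAX run's `idata.sample_stats["diverging"]` to `divergence_summary`, and compare `posterior["mu"]` (and likewise `posterior["sigma"]`) from that run against a `pm.sample()` run on the same observed model using `posterior_means_agree`.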

Error message:

Blackjax/JAX/jaxlib/Python version information:

BlackJax 1.2.4
Python 3.11.11 | packaged by conda-forge | (main, Dec  5 2024, 14:21:42) [Clang 18.1.8 ]
Jax 0.4.31
Jaxlib 0.4.31
PyMC 5.20.0

Context for the issue:

This issue appears to render the sampler unable to fit models with certain truncated likelihoods, which are a useful construct in a number of domains.
