Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Delta.log_prob jit-able on metal and raise explicit error for Delta sampling site during initialization. #1950

Merged
merged 2 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,7 @@ def sample(self, key, sample_shape=()):

@validate_sample
def log_prob(self, value):
log_prob = jnp.log(value == self.v)
log_prob = jnp.where(value == self.v, 0, -jnp.inf)
log_prob = sum_rightmost(log_prob, len(self.event_shape))
return log_prob + self.log_density

Expand Down
13 changes: 13 additions & 0 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import jax.numpy as jnp

import numpyro
from numpyro import distributions as dist
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to
from numpyro.distributions.util import is_identically_one, sum_rightmost
Expand Down Expand Up @@ -685,6 +686,18 @@ def initialize_model(
has_enumerate_support,
model_trace,
) = _get_model_transforms(substituted_model, model_args, model_kwargs)

for name, site in model_trace.items():
if (
site["type"] == "sample"
and isinstance(site["fn"], dist.Delta)
and not site["is_observed"]
):
raise ValueError(
f"Sample site '{name}' has a delta distribution; use "
"`numpyro.deterministic` to add this value to the trace instead."
)

# substitute param sites from model_trace to model so
# we don't need to generate again parameters of `numpyro.module`
model = substitute(
Expand Down
15 changes: 15 additions & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,3 +1350,18 @@ def model():
assert jnp.allclose(
jnp.stack(list(median.values())).ravel(), params["auto_loc"].ravel()
)


def test_autoguide_with_delta_site() -> None:
def model(x):
numpyro.sample("x", dist.Delta(3.0), obs=x)
# Need to sample a latent variable so the guide is not empty.
numpyro.sample("y", dist.Normal())

guide = AutoDiagonalNormal(lambda: model(None))
with pytest.raises(ValueError, match="has a delta distribution"):
numpyro.handlers.seed(guide, 9)()

# Check delta distributions are fine if observed.
guide = AutoDiagonalNormal(lambda: model(3.0))
numpyro.handlers.seed(guide, 9)()
Loading