Skip to content

Commit

Permalink
Make Delta.log_prob jit-able on metal and raise explicit error for …
Browse files Browse the repository at this point in the history
…`Delta` sampling site during initialization. (#1950)

* Use `jnp.where` for `Delta.log_prob` (cf. jax-ml/jax#25935).

* Raise explicit error for unobserved `Delta` sample sites during intialization.
  • Loading branch information
tillahoffmann authored Jan 17, 2025
1 parent 986da95 commit 8a67269
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,7 @@ def sample(self, key, sample_shape=()):

@validate_sample
def log_prob(self, value):
log_prob = jnp.log(value == self.v)
log_prob = jnp.where(value == self.v, 0, -jnp.inf)
log_prob = sum_rightmost(log_prob, len(self.event_shape))
return log_prob + self.log_density

Expand Down
13 changes: 13 additions & 0 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import jax.numpy as jnp

import numpyro
from numpyro import distributions as dist
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to
from numpyro.distributions.util import is_identically_one, sum_rightmost
Expand Down Expand Up @@ -685,6 +686,18 @@ def initialize_model(
has_enumerate_support,
model_trace,
) = _get_model_transforms(substituted_model, model_args, model_kwargs)

for name, site in model_trace.items():
if (
site["type"] == "sample"
and isinstance(site["fn"], dist.Delta)
and not site["is_observed"]
):
raise ValueError(
f"Sample site '{name}' has a delta distribution; use "
"`numpyro.deterministic` to add this value to the trace instead."
)

# substitute param sites from model_trace to model so
# we don't need to generate again parameters of `numpyro.module`
model = substitute(
Expand Down
15 changes: 15 additions & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,3 +1350,18 @@ def model():
assert jnp.allclose(
jnp.stack(list(median.values())).ravel(), params["auto_loc"].ravel()
)


def test_autoguide_with_delta_site() -> None:
def model(x):
numpyro.sample("x", dist.Delta(3.0), obs=x)
# Need to sample a latent variable so the guide is not empty.
numpyro.sample("y", dist.Normal())

guide = AutoDiagonalNormal(lambda: model(None))
with pytest.raises(ValueError, match="has a delta distribution"):
numpyro.handlers.seed(guide, 9)()

# Check delta distributions are fine if observed.
guide = AutoDiagonalNormal(lambda: model(3.0))
numpyro.handlers.seed(guide, 9)()

0 comments on commit 8a67269

Please sign in to comment.