diff --git a/docs/source/utilities.rst b/docs/source/utilities.rst index 5100376ca..3cec9aa0b 100644 --- a/docs/source/utilities.rst +++ b/docs/source/utilities.rst @@ -40,6 +40,10 @@ log_density ----------- .. autofunction:: numpyro.infer.util.log_density +compute_log_probs +----------------- +.. autofunction:: numpyro.infer.util.compute_log_probs + get_transforms -------------- .. autofunction:: numpyro.infer.util.get_transforms diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 1ff0b5c5a..a3b7425d0 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -54,21 +54,28 @@ def process_message(self, msg): msg["value"] = random.PRNGKey(0) -def log_density(model, model_args, model_kwargs, params): +def compute_log_probs( + model, + model_args: tuple, + model_kwargs: dict, + params: dict, + sum_log_prob: bool = True, +): """ - (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given + (EXPERIMENTAL INTERFACE) Computes log of density for each site of the model given latent values ``params``. :param model: Python callable containing NumPyro primitives. - :param tuple model_args: args provided to the model. - :param dict model_kwargs: kwargs provided to the model. - :param dict params: dictionary of current parameter values keyed by site - name. - :return: log of joint density and a corresponding model trace + :param model_args: args provided to the model. + :param model_kwargs: kwargs provided to the model. + :param params: Dictionary of current parameter values keyed by site name. + :param sum_log_prob: sum log probability over batch dimensions. + :return: Dictionary mapping site names to log of density and a corresponding model + trace. """ model = substitute(model, data=params) model_trace = trace(model).get_trace(*model_args, **model_kwargs) - log_joint = jnp.zeros(()) + log_joint = {} for site in model_trace.values(): if site["type"] == "sample": value = site["value"] @@ -94,11 +101,28 @@ def log_density(model, model_args, model_kwargs, params): if (scale is not None) and (not is_identically_one(scale)): log_prob = scale * log_prob - log_prob = jnp.sum(log_prob) - log_joint = log_joint + log_prob + log_joint[site["name"]] = jnp.sum(log_prob) if sum_log_prob else log_prob return log_joint, model_trace +def log_density(model, model_args: tuple, model_kwargs: dict, params: dict): + """ + (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given latent + values ``params``. + + :param model: Python callable containing NumPyro primitives. + :param model_args: args provided to the model. + :param model_kwargs: kwargs provided to the model. + :param params: Dictionary of current parameter values keyed by site name. + :return: Log of joint density and a corresponding model trace. + """ + log_joint, model_trace = compute_log_probs(model, model_args, model_kwargs, params) + # We need to start with 0.0 instead of 0 because log_joint may be empty or only + # contain integers, but log_density must be a floating point value to be + # differentiable by jax. + return sum(log_joint.values(), start=0.0), model_trace + + class _without_rsample_stop_gradient(numpyro.primitives.Messenger): """ Stop gradient for samples at latent sample sites for which has_rsample=False. diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index a1342507f..584cd457b 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -224,6 +224,7 @@ def model(data, labels): ) +@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36") @pytest.mark.parametrize("dropout", [True, False]) @pytest.mark.parametrize("batchnorm", [True, False]) def test_haiku_state_dropout_smoke(dropout, batchnorm): @@ -263,6 +264,7 @@ def model(): svi.run(random.PRNGKey(100), 10) +@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36") @pytest.mark.parametrize("dropout", [True, False]) @pytest.mark.parametrize("batchnorm", [True, False]) def test_flax_state_dropout_smoke(dropout, batchnorm): diff --git a/test/infer/test_infer_util.py b/test/infer/test_infer_util.py index 42b1a5d62..ab0133700 100644 --- a/test/infer/test_infer_util.py +++ b/test/infer/test_infer_util.py @@ -27,8 +27,10 @@ from numpyro.infer.reparam import TransformReparam from numpyro.infer.util import ( Predictive, + compute_log_probs, constrain_fn, initialize_model, + log_density, log_likelihood, potential_energy, transform_fn, @@ -266,6 +268,23 @@ def test_log_likelihood(batch_shape): ) +def test_compute_log_probs(): + model, data, _ = beta_bernoulli() + samples = Predictive(model, return_sites=["beta"], num_samples=1)(random.key(7)) + samples = {key: value[0] for key, value in samples.items()} + + logden, _ = log_density(model, (data,), {}, samples) + assert logden.shape == () + + logdens, _ = compute_log_probs(model, (data,), {}, samples) + assert set(logdens) == {"beta", "obs"} + assert all(x.shape == () for x in logdens.values()) + + logdens, _ = compute_log_probs(model, (data,), {}, samples, False) + assert logdens["beta"].shape == (2,) + assert logdens["obs"].shape == (800, 2) + + def test_model_with_transformed_distribution(): x_prior = dist.HalfNormal(2) y_prior = dist.LogNormal(scale=3.0) # transformed distribution