From ac8988467fd5f8bf0aec02f1f781179d1b3fd567 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Thu, 5 Dec 2024 15:14:17 -0500 Subject: [PATCH 1/3] Add utility function to evaluate log density for individual sites. --- docs/source/utilities.rst | 4 ++++ numpyro/infer/util.py | 44 +++++++++++++++++++++++++++-------- test/infer/test_infer_util.py | 19 +++++++++++++++ 3 files changed, 57 insertions(+), 10 deletions(-) diff --git a/docs/source/utilities.rst b/docs/source/utilities.rst index 5100376ca..1fc7ad7c7 100644 --- a/docs/source/utilities.rst +++ b/docs/source/utilities.rst @@ -40,6 +40,10 @@ log_density ----------- .. autofunction:: numpyro.infer.util.log_density +log_densities +------------- +.. autofunction:: numpyro.infer.util.log_densities + get_transforms -------------- .. autofunction:: numpyro.infer.util.get_transforms diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 1ff0b5c5a..478971228 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -54,21 +54,28 @@ def process_message(self, msg): msg["value"] = random.PRNGKey(0) -def log_density(model, model_args, model_kwargs, params): +def log_densities( + model, + model_args: tuple, + model_kwargs: dict, + params: dict, + sum_log_prob: bool = True, +): """ - (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given + (EXPERIMENTAL INTERFACE) Computes log of density for each site of the model given latent values ``params``. :param model: Python callable containing NumPyro primitives. - :param tuple model_args: args provided to the model. - :param dict model_kwargs: kwargs provided to the model. - :param dict params: dictionary of current parameter values keyed by site - name. - :return: log of joint density and a corresponding model trace + :param model_args: args provided to the model. + :param model_kwargs: kwargs provided to the model. + :param params: Dictionary of current parameter values keyed by site name. + :param sum_log_prob: sum log probability over batch dimensions. + :return: Dictionary mapping site names to log of density and a corresponding model + trace. """ model = substitute(model, data=params) model_trace = trace(model).get_trace(*model_args, **model_kwargs) - log_joint = jnp.zeros(()) + log_joint = {} for site in model_trace.values(): if site["type"] == "sample": value = site["value"] @@ -94,11 +101,28 @@ def log_density(model, model_args, model_kwargs, params): if (scale is not None) and (not is_identically_one(scale)): log_prob = scale * log_prob - log_prob = jnp.sum(log_prob) - log_joint = log_joint + log_prob + log_joint[site["name"]] = jnp.sum(log_prob) if sum_log_prob else log_prob return log_joint, model_trace +def log_density(model, model_args: tuple, model_kwargs: dict, params: dict): + """ + (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given latent + values ``params``. + + :param model: Python callable containing NumPyro primitives. + :param model_args: args provided to the model. + :param model_kwargs: kwargs provided to the model. + :param params: Dictionary of current parameter values keyed by site name. + :return: Log of joint density and a corresponding model trace. + """ + log_joint, model_trace = log_densities(model, model_args, model_kwargs, params) + # We need to start with 0.0 instead of 0 because log_joint may be empty or only + # contain integers, but log_density must be a floating point value to be + # differentiable by jax. + return sum(log_joint.values(), start=0.0), model_trace + + class _without_rsample_stop_gradient(numpyro.primitives.Messenger): """ Stop gradient for samples at latent sample sites for which has_rsample=False. diff --git a/test/infer/test_infer_util.py b/test/infer/test_infer_util.py index 42b1a5d62..0505f16fa 100644 --- a/test/infer/test_infer_util.py +++ b/test/infer/test_infer_util.py @@ -29,6 +29,8 @@ Predictive, constrain_fn, initialize_model, + log_densities, + log_density, log_likelihood, potential_energy, transform_fn, @@ -266,6 +268,23 @@ def test_log_likelihood(batch_shape): ) +def test_log_densities(): + model, data, _ = beta_bernoulli() + samples = Predictive(model, return_sites=["beta"], num_samples=1)(random.key(7)) + samples = {key: value[0] for key, value in samples.items()} + + logden, _ = log_density(model, (data,), {}, samples) + assert logden.shape == () + + logdens, _ = log_densities(model, (data,), {}, samples) + assert set(logdens) == {"beta", "obs"} + assert all(x.shape == () for x in logdens.values()) + + logdens, _ = log_densities(model, (data,), {}, samples, False) + assert logdens["beta"].shape == (2,) + assert logdens["obs"].shape == (800, 2) + + def test_model_with_transformed_distribution(): x_prior = dist.HalfNormal(2) y_prior = dist.LogNormal(scale=3.0) # transformed distribution From f2877519c52112dc2382c40f60eaddff9956792c Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Thu, 5 Dec 2024 18:24:33 -0500 Subject: [PATCH 2/3] Rename `log_densities` to `compute_log_probs`. --- docs/source/utilities.rst | 6 +++--- numpyro/infer/util.py | 4 ++-- test/infer/test_infer_util.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/source/utilities.rst b/docs/source/utilities.rst index 1fc7ad7c7..3cec9aa0b 100644 --- a/docs/source/utilities.rst +++ b/docs/source/utilities.rst @@ -40,9 +40,9 @@ log_density ----------- .. autofunction:: numpyro.infer.util.log_density -log_densities -------------- -.. autofunction:: numpyro.infer.util.log_densities +compute_log_probs +----------------- +.. autofunction:: numpyro.infer.util.compute_log_probs get_transforms -------------- diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 478971228..a3b7425d0 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -54,7 +54,7 @@ def process_message(self, msg): msg["value"] = random.PRNGKey(0) -def log_densities( +def compute_log_probs( model, model_args: tuple, model_kwargs: dict, @@ -116,7 +116,7 @@ def log_density(model, model_args: tuple, model_kwargs: dict, params: dict): :param params: Dictionary of current parameter values keyed by site name. :return: Log of joint density and a corresponding model trace. """ - log_joint, model_trace = log_densities(model, model_args, model_kwargs, params) + log_joint, model_trace = compute_log_probs(model, model_args, model_kwargs, params) # We need to start with 0.0 instead of 0 because log_joint may be empty or only # contain integers, but log_density must be a floating point value to be # differentiable by jax. diff --git a/test/infer/test_infer_util.py b/test/infer/test_infer_util.py index 0505f16fa..ab0133700 100644 --- a/test/infer/test_infer_util.py +++ b/test/infer/test_infer_util.py @@ -27,9 +27,9 @@ from numpyro.infer.reparam import TransformReparam from numpyro.infer.util import ( Predictive, + compute_log_probs, constrain_fn, initialize_model, - log_densities, log_density, log_likelihood, potential_energy, @@ -268,7 +268,7 @@ def test_log_likelihood(batch_shape): ) -def test_log_densities(): +def test_compute_log_probs(): model, data, _ = beta_bernoulli() samples = Predictive(model, return_sites=["beta"], num_samples=1)(random.key(7)) samples = {key: value[0] for key, value in samples.items()} @@ -276,11 +276,11 @@ def test_log_densities(): logden, _ = log_density(model, (data,), {}, samples) assert logden.shape == () - logdens, _ = log_densities(model, (data,), {}, samples) + logdens, _ = compute_log_probs(model, (data,), {}, samples) assert set(logdens) == {"beta", "obs"} assert all(x.shape == () for x in logdens.values()) - logdens, _ = log_densities(model, (data,), {}, samples, False) + logdens, _ = compute_log_probs(model, (data,), {}, samples, False) assert logdens["beta"].shape == (2,) assert logdens["obs"].shape == (800, 2) From 33fda42649957e5513f3abe595824b111f1e605a Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Fri, 6 Dec 2024 09:59:52 -0500 Subject: [PATCH 3/3] Mark haiku and flax dropout as `xfail`. --- test/contrib/test_module.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index a1342507f..584cd457b 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -224,6 +224,7 @@ def model(data, labels): ) +@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36") @pytest.mark.parametrize("dropout", [True, False]) @pytest.mark.parametrize("batchnorm", [True, False]) def test_haiku_state_dropout_smoke(dropout, batchnorm): @@ -263,6 +264,7 @@ def model(): svi.run(random.PRNGKey(100), 10) +@pytest.mark.xfail(reason="fails due to upgrade from jax 0.4.35 to 0.4.36") @pytest.mark.parametrize("dropout", [True, False]) @pytest.mark.parametrize("batchnorm", [True, False]) def test_flax_state_dropout_smoke(dropout, batchnorm):