feat: NLE with multiple iid conditions (sbi-dev#1331)
* add method for iid-batched conditioning.

- deprecate MNLE-based potential (can be NLE-based)
- adapt tests for conditioned MNLE.

* update notebook, bugfixes

* add batch dim for x, add test.

* fix shape handling, adapt tutorial.
janfb authored Dec 21, 2024
1 parent 390a518 commit e7940dc
Showing 7 changed files with 575 additions and 190 deletions.
137 changes: 135 additions & 2 deletions sbi/inference/potentials/likelihood_based_potential.py
@@ -1,7 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Callable, Optional, Tuple
import warnings
from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor
@@ -115,6 +116,54 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
)
return log_likelihood_batches + self.prior.log_prob(theta) # type: ignore

def condition_on_theta(
self, local_theta: Tensor, dims_global_theta: List[int]
) -> Callable:
r"""Returns a potential function conditioned on a subset of theta dimensions.
The goal of this function is to divide the original `theta` into a
`global_theta` we do inference over, and a `local_theta` we condition on (in
addition to conditioning on `x_o`). Thus, the returned potential function will
calculate $\prod_{i=1}^{N}p(x_i | local_theta_i, \global_theta)$, where `x_i`
and `local_theta_i` are fixed and `global_theta` varies at inference time.
Args:
local_theta: The condition values to be conditioned.
dims_global_theta: The indices of the columns in `theta` that will be
sampled, i.e., that *not* conditioned. For example, if original theta
has shape `(batch_dim, 3)`, and `dims_global_theta=[0, 1]`, then the
potential will set `theta[:, 3] = local_theta` at inference time.
Returns:
A potential function conditioned on the `local_theta`.
"""

assert self.x_is_iid, "Conditioning is only supported for iid data."

def conditioned_potential(
theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
) -> Tensor:
assert (
len(dims_global_theta) == theta.shape[1]
), "dims_global_theta must match the number of parameters to sample."
global_theta = theta[:, dims_global_theta]
x_o = x_o if x_o is not None else self.x_o
# x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
if x_o.dim() < 3:
x_o = reshape_to_sample_batch_event(
x_o, event_shape=x_o.shape[1:], leading_is_sample=self.x_is_iid
)

return _log_likelihood_over_iid_trials_and_local_theta(
x=x_o,
global_theta=global_theta,
local_theta=local_theta,
estimator=self.likelihood_estimator,
track_gradients=track_gradients,
)

return conditioned_potential
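
A rough usage sketch of the new `condition_on_theta` API: the snippet below assumes a trained NLE `likelihood_estimator`, a `prior`, iid observations `x_iid` of shape `(num_trials, *event_shape)`, per-trial values `theta_o`, and candidate parameters `global_theta_samples`; the `set_x(..., x_is_iid=True)` call is an assumption about the surrounding potential API rather than something shown in this diff.

# Hedged usage sketch; `likelihood_estimator`, `prior`, `x_iid`, `theta_o`, and
# `global_theta_samples` are assumed inputs, and `set_x(..., x_is_iid=True)` is
# an assumed part of the potential API.
potential_fn, _ = likelihood_estimator_based_potential(
    likelihood_estimator, prior, x_o=None
)
potential_fn.set_x(x_iid, x_is_iid=True)  # x_iid: (num_trials, *event_shape)

conditioned_potential_fn = potential_fn.condition_on_theta(
    local_theta=theta_o[:, [2]],  # one condition value per iid trial
    dims_global_theta=[0, 1],     # columns of theta that are still inferred
)
# Evaluate the conditioned potential for a batch of global parameters; for a
# single set of iid observations this returns shape (1, num_global_theta_samples).
log_potential = conditioned_potential_fn(global_theta_samples)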


def _log_likelihoods_over_trials(
x: Tensor,
@@ -172,6 +221,77 @@ def _log_likelihoods_over_trials(
return log_likelihood_trial_sum


def _log_likelihood_over_iid_trials_and_local_theta(
x: Tensor,
global_theta: Tensor,
local_theta: Tensor,
estimator: ConditionalDensityEstimator,
track_gradients: bool = False,
) -> Tensor:
"""Returns $\\prod_{i=1}^N \\log(p(x_i|\theta, local_theta_i)$.
`x` is a batch of iid data, and `local_theta` is a matching batch of condition
values that were part of `theta` but are treated as local iid variables at inference
time.
This function is different from `_log_likelihoods_over_trials` in that it moves the
iid batch dimension of `x` onto the batch dimension of `theta`. This is needed when
the likelihood estimator is conditioned on a batch of conditions that are iid with
the batch of `x`. It avoids the evaluation of the likelihood for every combination
of `x` and `local_theta`.
Args:
x: data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where sample_dim
holds the i.i.d. trials and batch_dim holds a batch of xs, e.g., non-iid
observations.
global_theta: Batch of parameters `(theta_batch_dim,
num_parameters)`.
local_theta: Batch of conditions of shape `(sample_dim, num_local_thetas)`, must
match x's `sample_dim`.
estimator: DensityEstimator.
track_gradients: Whether to track gradients.
Returns:
log_likelihood: log likelihood for each x in x_batch_dim, for each theta in
theta_batch_dim, summed over all iid trials. Shape `(x_batch_dim,
theta_batch_dim)`.
"""
assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
assert (
local_theta.dim() == 2
), "condition must have shape (sample_dim, num_conditions)."
assert global_theta.dim() == 2, "theta must have shape (batch_dim, num_parameters)."
num_trials, num_xs = x.shape[:2]
num_thetas = global_theta.shape[0]
assert (
local_theta.shape[0] == num_trials
), "Condition batch size must match the number of iid trials in x."

# move the iid batch dimension onto the batch dimension of theta and repeat it there
x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)

# construct theta and condition to cover all trial-theta combinations
theta_with_condition = torch.cat(
[
global_theta.repeat(num_trials, 1), # repeat ABAB
local_theta.repeat_interleave(num_thetas, dim=0), # repeat AABB
],
dim=-1,
)

with torch.set_grad_enabled(track_gradients):
# Calculate likelihood in one batch. Returns (num_xs, num_trials * num_thetas).
log_likelihood_trial_batch = estimator.log_prob(
x_repeated, condition=theta_with_condition
)
# Reshape to (num_xs, num_trials, num_thetas) and sum over the trial dimension.
log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
num_xs, num_trials, num_thetas
).sum(1)

return log_likelihood_trial_sum
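
The repeat pattern above can be checked with a small toy example (dimensions chosen purely for illustration): `global_theta.repeat` tiles the whole theta batch once per trial (ABAB), while `local_theta.repeat_interleave` repeats each trial's condition once per theta (AABB), so row `i` of the condition pairs trial `i // num_thetas` with theta `i % num_thetas`, matching the interleaved x batch.

import torch

# Toy sizes: 5 iid trials, 2 observations in the x-batch, 3 candidate global thetas.
num_trials, num_xs, num_thetas = 5, 2, 3
x = torch.randn(num_trials, num_xs, 1)     # (sample_dim, x_batch_dim, *event_shape)
global_theta = torch.randn(num_thetas, 2)  # (theta_batch_dim, num_parameters)
local_theta = torch.randn(num_trials, 1)   # (sample_dim, num_local_thetas)

x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)
condition = torch.cat(
    [
        global_theta.repeat(num_trials, 1),                # ABAB: theta index varies fastest
        local_theta.repeat_interleave(num_thetas, dim=0),  # AABB: trial index varies slowest
    ],
    dim=-1,
)
print(x_repeated.shape)  # torch.Size([2, 15, 1])
print(condition.shape)   # torch.Size([15, 3])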


def mixed_likelihood_estimator_based_potential(
likelihood_estimator: MixedDensityEstimator,
prior: Distribution,
@@ -192,6 +312,13 @@ def mixed_likelihood_estimator_based_potential(
to unconstrained space.
"""

warnings.warn(
"This function is deprecated and will be removed in a future release. Use "
"`likelihood_estimator_based_potential` instead.",
DeprecationWarning,
stacklevel=2,
)

device = str(next(likelihood_estimator.discrete_net.parameters()).device)

potential_fn = MixedLikelihoodBasedPotential(
@@ -212,6 +339,13 @@ def __init__(
):
super().__init__(likelihood_estimator, prior, x_o, device)

warnings.warn(
"This function is deprecated and will be removed in a future release. Use "
"`LikelihoodBasedPotential` instead.",
DeprecationWarning,
stacklevel=2,
)

def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
prior_log_prob = self.prior.log_prob(theta) # type: ignore

@@ -231,7 +365,6 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
with torch.set_grad_enabled(track_gradients):
# Call the specific log prob method of the mixed likelihood estimator as
# this optimizes the evaluation of the discrete data part.
# TODO log_prob_iid
log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
input=x,
condition=theta.to(self.device),
6 changes: 2 additions & 4 deletions sbi/inference/trainers/nle/mnle.py
@@ -7,7 +7,7 @@
from torch.distributions import Distribution

from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimator
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
@@ -155,9 +155,7 @@ def build_posterior(
(
potential_fn,
theta_transform,
) = mixed_likelihood_estimator_based_potential(
likelihood_estimator=likelihood_estimator, prior=prior, x_o=None
)
) = likelihood_estimator_based_potential(likelihood_estimator, prior, x_o=None)

if sample_with == "mcmc":
self._posterior = MCMCPosterior(
2 changes: 1 addition & 1 deletion sbi/utils/conditional_density_utils.py
@@ -293,7 +293,7 @@ def __init__(
masked outside of prior.
"""
condition = torch.atleast_2d(condition)
if condition.shape[0] != 1:
if condition.shape[0] > 1:
raise ValueError("Condition with batch size > 1 not supported.")

self.potential_fn = potential_fn
4 changes: 2 additions & 2 deletions sbi/utils/sbiutils.py
@@ -60,8 +60,8 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -

if num_unique_z < num_unique * (1 - duplicate_tolerance):
warnings.warn(
"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
"datapoints. Before z-scoring, it had been {num_unique}. This can "
f"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
f"datapoints. Before z-scoring, it had been {num_unique}. This can "
"occur due to numerical inaccuracies when the data covers a large "
"range of values. Consider either setting `z_score_x=False` (but "
"beware that this can be problematic for training the NN) or exclude "