
Commit

add batch dim for x, add test.
janfb committed Dec 17, 2024
1 parent 074efa0 commit 8a95c8f
Showing 2 changed files with 147 additions and 46 deletions.
92 changes: 56 additions & 36 deletions sbi/inference/potentials/likelihood_based_potential.py
@@ -116,20 +116,28 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         )
         return log_likelihood_batches + self.prior.log_prob(theta)  # type: ignore
 
-    def condition_on(self, condition: Tensor, dims_to_sample: List[int]) -> Callable:
-        """Returns a potential conditioned on a subset of theta dimensions.
+    def condition_on_theta(
+        self, theta_condition: Tensor, dims_to_sample: List[int]
+    ) -> Callable:
+        """Returns a potential function conditioned on a subset of theta dimensions.
 
         The condition is a part of theta, but is assumed to correspond to a batch of iid
-        x_o.
+        x_o. For example, it can be a batch of experimental conditions that corresponds
+        to a batch of i.i.d. trials in x_o.
 
         Args:
-            condition: The condition to fix.
-            dims_to_sample: The indices of the parameters to sample.
+            theta_condition: The condition values to be conditioned.
+            dims_to_sample: The indices of the columns in theta that will be sampled,
+                i.e., that are *not* conditioned. For example, if the original theta
+                has shape `(batch_dim, 3)` and `dims_to_sample=[0, 1]`, then the
+                potential will set `theta[:, 2] = theta_condition` at inference time.
 
         Returns:
-            A potential function conditioned on the condition.
+            A potential function conditioned on theta_condition.
         """
 
+        assert self.x_is_iid, "Conditioning is only supported for iid data."
+
         def conditioned_potential(

Check warning on line 141 in sbi/inference/potentials/likelihood_based_potential.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/potentials/likelihood_based_potential.py#L141

Added line #L141 was not covered by tests
             theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
         ) -> Tensor:
@@ -138,10 +146,10 @@ def conditioned_potential(
             ), "dims_to_sample must match the number of parameters to sample."
             theta_without_condition = theta[:, dims_to_sample]
 

-            return _log_likelihood_with_iid_condition(
+            return _log_likelihood_over_iid_conditions(

Check warning on line 149 in sbi/inference/potentials/likelihood_based_potential.py

View check run for this annotation

Codecov / codecov/patch

sbi/inference/potentials/likelihood_based_potential.py#L149

Added line #L149 was not covered by tests
                 x=x_o if x_o is not None else self.x_o,
                 theta_without_condition=theta_without_condition,
-                condition=condition,
+                condition=theta_condition,
                 estimator=self.likelihood_estimator,
                 track_gradients=track_gradients,
             )
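For orientation, a minimal usage sketch of the renamed method (names and shapes are assumed, not part of this diff; `estimator`, `proposal`, `x_o`, and `condition_o` would be set up as in the tests further down):

    # Sketch: condition the likelihood-based potential on the experimental-condition
    # column of theta, then evaluate it for a batch of the remaining parameters.
    potential_fn, _ = likelihood_estimator_based_potential(estimator, proposal, x_o)
    conditioned_potential_fn = potential_fn.condition_on_theta(
        condition_o,            # assumed shape (num_trials, 1): one value per iid trial
        dims_to_sample=[0, 1],  # sample the first two theta columns
    )
    log_potential = conditioned_potential_fn(theta_batch)  # theta_batch: (batch, 2)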
@@ -205,63 +213,75 @@ def _log_likelihoods_over_trials(
     return log_likelihood_trial_sum
 
 
-def _log_likelihood_with_iid_condition(
+def _log_likelihood_over_iid_conditions(
     x: Tensor,
     theta_without_condition: Tensor,
    condition: Tensor,
     estimator: ConditionalDensityEstimator,
     track_gradients: bool = False,
 ) -> Tensor:
"""Return log likelihoods summed over iid trials of `x` with a matching batch of
conditions.
"""Returns $\\log(p(x_o|\theta, condition)$, where x_o is a batch of iid data, and
condition is a matching batch of conditions.
This function is different from `_log_likelihoods_over_trials` in that it moves the
iid batch dimension of `x` onto the batch dimension of `theta`. This is useful when
iid batch dimension of `x` onto the batch dimension of `theta`. This is needed when
the likelihood estimator is conditioned on a batch of conditions that are iid with
the batch of `x`. It avoid the evaluation of the likelihood for every combination of
`x` and `condition`. Instead, it manually constructs a batch covering all
combination of iid trial and theta batch and reshapes to sum over the iid
the batch of `x`. It avoids the evaluation of the likelihood for every combination
of `x` and `condition`. Instead, it manually constructs a batch covering all
combination of iid trials and theta batch and reshapes to sum over the iid
likelihoods.
 
     Args:
-        x: Batch of iid data of shape `(iid_dim, *event_shape)`.
-        theta_without_condition: Batch of parameters `(batch_dim, *event_shape)`
-        condition: Batch of conditions of shape `(iid_dim, *condition_shape)`.
+        x: data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where
+            sample_dim holds the i.i.d. trials and x_batch_dim holds a batch of xs,
+            e.g., non-iid observations.
+        theta_without_condition: Batch of parameters `(theta_batch_dim,
+            num_parameters)`.
+        condition: Batch of conditions of shape `(sample_dim, num_conditions)`, must
+            match x's `sample_dim`.
         estimator: DensityEstimator.
         track_gradients: Whether to track gradients.
 
     Returns:
-        log_likelihood_trial_sum: log likelihood for each parameter, summed over all
-            batch entries (iid trials) in `x`.
+        log_likelihood: log likelihood for each x in x_batch_dim, for each theta in
+            theta_batch_dim, summed over all i.i.d. trials. Shape
+            `(x_batch_dim, theta_batch_dim)`.
     """
+    assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
     assert (
-        condition.shape[0] == x.shape[0]
-    ), "Condition and iid x must have the same batch size."
-    num_trials = x.shape[0]
-    num_theta = theta_without_condition.shape[0]
-    x = reshape_to_sample_batch_event(
-        x, event_shape=x.shape[1:], leading_is_sample=True
-    )
+        condition.dim() == 2
+    ), "condition must have shape (sample_dim, num_conditions)."
+    assert (
+        theta_without_condition.dim() == 2
+    ), "theta must have shape (batch_dim, num_parameters)."
+    num_trials, num_xs = x.shape[:2]
+    num_thetas = theta_without_condition.shape[0]
+    assert (
+        condition.shape[0] == num_trials
+    ), "Condition batch size must match the number of iid trials in x."

     # move the iid batch dimension onto the batch dimension of theta and repeat it there
-    x_expanded = x.reshape(1, num_trials, -1).repeat_interleave(num_theta, dim=1)
-    # for this to work we construct theta and condition to cover all combinations in the
-    # trial batch and the theta batch.
-    theta = torch.cat(
+    x.transpose_(0, 1)
+    x_repeated = x.repeat_interleave(num_thetas, dim=1)
+
+    # construct theta and condition to cover all trial-theta combinations
+    theta_with_condition = torch.cat(
         [
             theta_without_condition.repeat(num_trials, 1),  # repeat ABAB
-            condition.repeat_interleave(num_theta, dim=0),  # repeat AABB
+            condition.repeat_interleave(num_thetas, dim=0),  # repeat AABB
         ],
         dim=-1,
     )

     with torch.set_grad_enabled(track_gradients):
-        # Calculate likelihood in one batch. Returns (1, num_trials * theta_batch_size)
-        log_likelihood_trial_batch = estimator.log_prob(x_expanded, condition=theta)
+        # Calculate likelihood in one batch. Returns (num_xs, num_trials * num_thetas).
+        log_likelihood_trial_batch = estimator.log_prob(
+            x_repeated, condition=theta_with_condition
+        )
         # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
         log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
-            num_trials, num_theta
-        ).sum(0)
+            num_xs, num_trials, num_thetas
+        ).sum(1)
 
     return log_likelihood_trial_sum
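The ABAB/AABB repeat pattern above can be checked with a toy example (invented values; a sketch rather than part of the diff):

    import torch

    theta = torch.tensor([[1.0], [2.0]])                # num_thetas = 2
    condition = torch.tensor([[10.0], [20.0], [30.0]])  # num_trials = 3

    pairs = torch.cat(
        [
            theta.repeat(3, 1),                     # ABAB: 1, 2, 1, 2, 1, 2
            condition.repeat_interleave(2, dim=0),  # AABB: 10, 10, 20, 20, 30, 30
        ],
        dim=-1,
    )
    # pairs covers all 6 trial-theta combinations, grouped per trial:
    # [[1, 10], [2, 10], [1, 20], [2, 20], [1, 30], [2, 30]]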

101 changes: 91 additions & 10 deletions tests/mnle_test.py
@@ -14,6 +14,7 @@
 from sbi.inference.posteriors.vi_posterior import VIPosterior
 from sbi.inference.potentials.base_potential import BasePotential
 from sbi.inference.potentials.likelihood_based_potential import (
+    _log_likelihood_over_iid_conditions,
     likelihood_estimator_based_potential,
 )
 from sbi.neural_nets import likelihood_nn
@@ -39,6 +40,15 @@ def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.
     return torch.cat((rts, choices), dim=1)
 
 
+def wrapped_simulator(
+    theta_and_condition: Tensor, last_idx_parameters: int = 2
+) -> Tensor:
+    # simulate with experiment conditions
+    theta = theta_and_condition[:, :last_idx_parameters]
+    condition = theta_and_condition[:, last_idx_parameters:]
+    return mixed_simulator(theta, condition)
+
+
 @pytest.mark.mcmc
 @pytest.mark.gpu
 @pytest.mark.parametrize("device", ("cpu", "gpu"))
@@ -256,14 +266,6 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     num_simulations = 10000
     num_samples = 1000
 
-    def sim_wrapper(
-        theta_and_condition: Tensor, last_idx_parameters: int = 2
-    ) -> Tensor:
-        # simulate with experiment conditions
-        theta = theta_and_condition[:, :last_idx_parameters]
-        condition = theta_and_condition[:, last_idx_parameters:]
-        return mixed_simulator(theta, condition)
-
     proposal = MultipleIndependent(
         [
             Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
@@ -274,7 +276,7 @@ def sim_wrapper(
     )
 
     theta = proposal.sample((num_simulations,))
-    x = sim_wrapper(theta)
+    x = wrapped_simulator(theta)
     assert x.shape == (num_simulations, 2)
 
     num_trials = 10
@@ -285,7 +287,7 @@ def sim_wrapper(
     condition_o = theta_and_condition[:, 2:]
     theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1)
 
-    x_o = sim_wrapper(theta_and_conditions_o)
+    x_o = wrapped_simulator(theta_and_conditions_o)
 
     mcmc_kwargs = dict(
         method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate
@@ -331,3 +333,82 @@ def sim_wrapper(
         true_posterior_samples,
         alg=f"MNLE trained with {num_simulations} simulations",
     )


@pytest.mark.parametrize("num_thetas", [1, 10])
@pytest.mark.parametrize("num_trials", [1, 5])
@pytest.mark.parametrize("num_xs", [1, 3])
@pytest.mark.parametrize(
"num_conditions",
[
1,
pytest.param(
2,
marks=pytest.mark.xfail(
reason="Batched theta_condition is not " "supported"
),
),
],
)
def test_log_likelihood_over_iid_conditions(
num_thetas, num_trials, num_xs, num_conditions
):
"""Test log likelihood over iid conditions using MNLE.
Args:
num_thetas: batch of theta to condition on.
num_trials: number of i.i.d. trials in x
num_xs: batch of x, e.g., different subjects in a study.
num_conditions: number of batches of conditions, e.g., different conditions
for each x (not implemented yet).
"""

# train mnle on mixed data
trainer = MNLE(
density_estimator=likelihood_nn(model="mnle", z_score_x=None),
)
proposal = MultipleIndependent(
[
Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
Beta(torch.tensor([2.0]), torch.tensor([2.0])),
BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])),
],
validate_args=False,
)

num_simulations = 100
theta = proposal.sample((num_simulations,))
x = wrapped_simulator(theta)
estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1)

# condition on multiple conditions
theta_o = proposal.sample((num_xs,))[:, :2]

x_o = torch.zeros(num_trials, num_xs, 2)
condition_o = proposal.sample((
num_conditions,
num_trials,
))[:, 2:].reshape(num_trials, 1)
for i in range(num_xs):
# simulate with same iid theta but different conditions
x_o[:, i, :] = mixed_simulator(theta_o[i].repeat(num_trials, 1), condition_o)

# batched conditioning
theta = proposal.sample((num_thetas,))[:, :2]
# x_o has shape (batch, iid, *event)
# condition_o has shape (batch, iid, num_conditions)
ll_batched = _log_likelihood_over_iid_conditions(x_o, theta, condition_o, estimator)

# looped conditioning
ll_single = []
for i in range(num_trials):
theta_and_condition = torch.cat(
(theta, condition_o[i].repeat(num_thetas, 1)), dim=1
)
x_i = x_o[:, i].reshape(num_xs, 1, -1).repeat(1, num_thetas, 1)
ll_single.append(estimator.log_prob(input=x_i, condition=theta_and_condition))
ll_single = torch.stack(ll_single).sum(0) # sum over trials

assert ll_batched.shape == torch.Size([num_xs, num_thetas])
assert ll_batched.shape == ll_single.shape
assert torch.allclose(ll_batched, ll_single, atol=1e-5)
