From 8a95c8f4253ae8adc6e4b70a0e275a3bac7162d0 Mon Sep 17 00:00:00 2001
From: Jan
Date: Tue, 17 Dec 2024 13:49:45 +0100
Subject: [PATCH] add batch dim for x, add test.

---
 .../potentials/likelihood_based_potential.py  |  92 +++++++++-------
 tests/mnle_test.py                            | 101 ++++++++++++++++--
 2 files changed, 147 insertions(+), 46 deletions(-)

diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py
index f82d8443c..b7a732d66 100644
--- a/sbi/inference/potentials/likelihood_based_potential.py
+++ b/sbi/inference/potentials/likelihood_based_potential.py
@@ -116,20 +116,28 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         )
         return log_likelihood_batches + self.prior.log_prob(theta)  # type: ignore

-    def condition_on(self, condition: Tensor, dims_to_sample: List[int]) -> Callable:
-        """Returns a potential conditioned on a subset of theta dimensions.
+    def condition_on_theta(
+        self, theta_condition: Tensor, dims_to_sample: List[int]
+    ) -> Callable:
+        """Returns a potential function conditioned on a subset of theta dimensions.

         The condition is a part of theta, but is assumed to correspond to a batch of iid
-        x_o.
+        x_o. For example, it can be a batch of experimental conditions that corresponds
+        to a batch of i.i.d. trials in x_o.

         Args:
-            condition: The condition to fix.
-            dims_to_sample: The indices of the parameters to sample.
+            theta_condition: The fixed values for the conditioned dimensions of theta.
+            dims_to_sample: The indices of the columns in theta that will be sampled,
+                i.e., that are *not* conditioned. For example, if the original theta
+                has shape `(batch_dim, 3)` and `dims_to_sample=[0, 1]`, then the
+                potential will set `theta[:, 2] = theta_condition` at inference time.

         Returns:
-            A potential function conditioned on the condition.
+            A potential function conditioned on the theta_condition.
         """
+        assert self.x_is_iid, "Conditioning is only supported for iid data."
+
         def conditioned_potential(
             theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
         ) -> Tensor:
@@ -138,10 +146,10 @@ def conditioned_potential(
             ), "dims_to_sample must match the number of parameters to sample."
             theta_without_condition = theta[:, dims_to_sample]

-            return _log_likelihood_with_iid_condition(
+            return _log_likelihood_over_iid_conditions(
                 x=x_o if x_o is not None else self.x_o,
                 theta_without_condition=theta_without_condition,
-                condition=condition,
+                condition=theta_condition,
                 estimator=self.likelihood_estimator,
                 track_gradients=track_gradients,
             )
@@ -205,63 +213,75 @@ def _log_likelihoods_over_trials(
     return log_likelihood_trial_sum


-def _log_likelihood_with_iid_condition(
+def _log_likelihood_over_iid_conditions(
     x: Tensor,
     theta_without_condition: Tensor,
     condition: Tensor,
     estimator: ConditionalDensityEstimator,
     track_gradients: bool = False,
 ) -> Tensor:
-    """Return log likelihoods summed over iid trials of `x` with a matching batch of
-    conditions.
+    """Returns $\\log(p(x_o|\\theta, condition))$, where x_o is a batch of iid data,
+    and condition is a matching batch of conditions.

     This function is different from `_log_likelihoods_over_trials` in that it moves the
-    iid batch dimension of `x` onto the batch dimension of `theta`. This is useful when
+    iid batch dimension of `x` onto the batch dimension of `theta`. This is needed when
     the likelihood estimator is conditioned on a batch of conditions that are iid with
-    the batch of `x`. It avoid the evaluation of the likelihood for every combination of
-    `x` and `condition`. Instead, it manually constructs a batch covering all
-    combination of iid trial and theta batch and reshapes to sum over the iid
+    the batch of `x`. It avoids the evaluation of the likelihood for every combination
+    of `x` and `condition`. Instead, it manually constructs a batch covering all
+    combinations of iid trials and theta batch and reshapes to sum over the iid
     likelihoods.

     Args:
-        x: Batch of iid data of shape `(iid_dim, *event_shape)`.
-        theta_without_condition: Batch of parameters `(batch_dim, *event_shape)`
-        condition: Batch of conditions of shape `(iid_dim, *condition_shape)`.
+        x: Data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where
+            sample_dim holds the i.i.d. trials and x_batch_dim holds a batch of xs,
+            e.g., non-iid observations.
+        theta_without_condition: Batch of parameters `(theta_batch_dim,
+            num_parameters)`.
+        condition: Batch of conditions of shape `(sample_dim, num_conditions)`, must
+            match x's `sample_dim`.
         estimator: DensityEstimator.
         track_gradients: Whether to track gradients.

     Returns:
-        log_likelihood_trial_sum: log likelihood for each parameter, summed over all
-            batch entries (iid trials) in `x`.
+        log_likelihood: log likelihood for each x in x_batch_dim, for each theta in
+            theta_batch_dim, summed over all i.i.d. trials. Shape
+            `(x_batch_dim, theta_batch_dim)`.
     """
+    assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
     assert (
-        condition.shape[0] == x.shape[0]
-    ), "Condition and iid x must have the same batch size."
-    num_trials = x.shape[0]
-    num_theta = theta_without_condition.shape[0]
-    x = reshape_to_sample_batch_event(
-        x, event_shape=x.shape[1:], leading_is_sample=True
-    )
+        condition.dim() == 2
+    ), "condition must have shape (sample_dim, num_conditions)."
+    assert (
+        theta_without_condition.dim() == 2
+    ), "theta must have shape (batch_dim, num_parameters)."
+    num_trials, num_xs = x.shape[:2]
+    num_thetas = theta_without_condition.shape[0]
+    assert (
+        condition.shape[0] == num_trials
+    ), "Condition batch size must match the number of iid trials in x."
     # move the iid batch dimension onto the batch dimension of theta and repeat it there
-    x_expanded = x.reshape(1, num_trials, -1).repeat_interleave(num_theta, dim=1)
-    # for this to work we construct theta and condition to cover all combinations in the
-    # trial batch and the theta batch.
-    theta = torch.cat(
+    x = x.transpose(0, 1)
+    x_repeated = x.repeat_interleave(num_thetas, dim=1)
+
+    # construct theta and condition to cover all trial-theta combinations
+    theta_with_condition = torch.cat(
         [
             theta_without_condition.repeat(num_trials, 1),  # repeat ABAB
-            condition.repeat_interleave(num_theta, dim=0),  # repeat AABB
+            condition.repeat_interleave(num_thetas, dim=0),  # repeat AABB
         ],
         dim=-1,
     )

     with torch.set_grad_enabled(track_gradients):
-        # Calculate likelihood in one batch. Returns (1, num_trials * theta_batch_size)
-        log_likelihood_trial_batch = estimator.log_prob(x_expanded, condition=theta)
+        # Calculate likelihood in one batch. Returns (num_xs, num_trials * num_thetas).
+        log_likelihood_trial_batch = estimator.log_prob(
+            x_repeated, condition=theta_with_condition
+        )
         # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
         log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
-            num_trials, num_theta
-        ).sum(0)
+            num_xs, num_trials, num_thetas
+        ).sum(1)

     return log_likelihood_trial_sum

diff --git a/tests/mnle_test.py b/tests/mnle_test.py
index b242f477f..50c1452a3 100644
--- a/tests/mnle_test.py
+++ b/tests/mnle_test.py
@@ -14,6 +14,7 @@
 from sbi.inference.posteriors.vi_posterior import VIPosterior
 from sbi.inference.potentials.base_potential import BasePotential
 from sbi.inference.potentials.likelihood_based_potential import (
+    _log_likelihood_over_iid_conditions,
     likelihood_estimator_based_potential,
 )
 from sbi.neural_nets import likelihood_nn
@@ -39,6 +40,15 @@ def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.
     return torch.cat((rts, choices), dim=1)


+def wrapped_simulator(
+    theta_and_condition: Tensor, last_idx_parameters: int = 2
+) -> Tensor:
+    # simulate with experiment conditions
+    theta = theta_and_condition[:, :last_idx_parameters]
+    condition = theta_and_condition[:, last_idx_parameters:]
+    return mixed_simulator(theta, condition)
+
+
 @pytest.mark.mcmc
 @pytest.mark.gpu
 @pytest.mark.parametrize("device", ("cpu", "gpu"))
@@ -256,14 +266,6 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     num_simulations = 10000
     num_samples = 1000

-    def sim_wrapper(
-        theta_and_condition: Tensor, last_idx_parameters: int = 2
-    ) -> Tensor:
-        # simulate with experiment conditions
-        theta = theta_and_condition[:, :last_idx_parameters]
-        condition = theta_and_condition[:, last_idx_parameters:]
-        return mixed_simulator(theta, condition)
-
     proposal = MultipleIndependent(
         [
             Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
@@ -274,7 +276,7 @@ def sim_wrapper(
     )

     theta = proposal.sample((num_simulations,))
-    x = sim_wrapper(theta)
+    x = wrapped_simulator(theta)
     assert x.shape == (num_simulations, 2)

     num_trials = 10
@@ -285,7 +287,7 @@ def sim_wrapper(
     condition_o = theta_and_condition[:, 2:]
     theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1)

-    x_o = sim_wrapper(theta_and_conditions_o)
+    x_o = wrapped_simulator(theta_and_conditions_o)

     mcmc_kwargs = dict(
         method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate
@@ -331,3 +333,82 @@ def sim_wrapper(
         true_posterior_samples,
         alg=f"MNLE trained with {num_simulations} simulations",
     )
+
+
+@pytest.mark.parametrize("num_thetas", [1, 10])
+@pytest.mark.parametrize("num_trials", [1, 5])
+@pytest.mark.parametrize("num_xs", [1, 3])
+@pytest.mark.parametrize(
+    "num_conditions",
+    [
+        1,
+        pytest.param(
+            2,
+            marks=pytest.mark.xfail(
+                reason="Batched theta_condition is not supported"
+            ),
+        ),
+    ],
+)
+def test_log_likelihood_over_iid_conditions(
+    num_thetas, num_trials, num_xs, num_conditions
+):
+    """Test log likelihood over iid conditions using MNLE.
+
+    Args:
+        num_thetas: batch size of theta for which the likelihood is evaluated.
+        num_trials: number of i.i.d. trials in x.
+        num_xs: batch size of x, e.g., different subjects in a study.
+        num_conditions: number of batches of conditions, e.g., different conditions
+            for each x (not implemented yet).
+    """
+
+    # train MNLE on mixed data
+    trainer = MNLE(
+        density_estimator=likelihood_nn(model="mnle", z_score_x=None),
+    )
+    proposal = MultipleIndependent(
+        [
+            Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
+            Beta(torch.tensor([2.0]), torch.tensor([2.0])),
+            BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])),
+        ],
+        validate_args=False,
+    )
+
+    num_simulations = 100
+    theta = proposal.sample((num_simulations,))
+    x = wrapped_simulator(theta)
+    estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1)
+
+    # simulate iid trials for each x: one theta per x, one condition per trial
+    theta_o = proposal.sample((num_xs,))[:, :2]
+
+    x_o = torch.zeros(num_trials, num_xs, 2)
+    condition_o = proposal.sample((
+        num_conditions,
+        num_trials,
+    ))[:, 2:].reshape(num_trials, 1)
+    for i in range(num_xs):
+        # simulate with same iid theta but different conditions
+        x_o[:, i, :] = mixed_simulator(theta_o[i].repeat(num_trials, 1), condition_o)
+
+    # batched conditioning
+    theta = proposal.sample((num_thetas,))[:, :2]
+    # x_o has shape (num_trials, num_xs, 2), i.e., (sample_dim, x_batch_dim, *event)
+    # condition_o has shape (num_trials, 1), i.e., one condition per iid trial
+    ll_batched = _log_likelihood_over_iid_conditions(x_o, theta, condition_o, estimator)
+
+    # looped conditioning
+    ll_single = []
+    for i in range(num_trials):
+        theta_and_condition = torch.cat(
+            (theta, condition_o[i].repeat(num_thetas, 1)), dim=1
+        )
+        x_i = x_o[i].reshape(num_xs, 1, -1).repeat(1, num_thetas, 1)
+        ll_single.append(estimator.log_prob(input=x_i, condition=theta_and_condition))
+    ll_single = torch.stack(ll_single).sum(0)  # sum over trials
+
+    assert ll_batched.shape == torch.Size([num_xs, num_thetas])
+    assert ll_batched.shape == ll_single.shape
+    assert torch.allclose(ll_batched, ll_single, atol=1e-5)
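
Usage sketch (not part of the diff): how the new `condition_on_theta` can be called,
assuming the `estimator`, `proposal`, `x_o` (shape `(num_trials, num_xs, 2)`), and
`condition_o` (shape `(num_trials, 1)`) built in `test_log_likelihood_over_iid_conditions`
above, plus that test module's imports; names such as `theta_full` and `log_potential`
are illustrative.

    from sbi.inference.potentials.likelihood_based_potential import (
        likelihood_estimator_based_potential,
    )

    # Potential over the full theta (two model parameters + one experimental condition).
    # x_o carries a leading iid-trial dimension and is assumed to be registered as iid.
    potential_fn, _ = likelihood_estimator_based_potential(estimator, proposal, x_o)

    # Fix one experimental condition per iid trial; only theta columns [0, 1] are sampled.
    conditioned_potential_fn = potential_fn.condition_on_theta(
        condition_o, dims_to_sample=[0, 1]
    )

    # Evaluate for a batch of full-width thetas; the conditioned column is ignored and
    # replaced per trial by condition_o. x_o can also be passed explicitly at call time.
    theta_full = proposal.sample((10,))
    log_potential = conditioned_potential_fn(theta_full, x_o=x_o)
    # log_potential has shape (num_xs, 10): trial-summed log likelihood per (x, theta).

The returned callable can then serve as the potential for a sampler over the
non-conditioned dimensions.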
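
The trial-theta batch construction inside `_log_likelihood_over_iid_conditions` can be
illustrated with plain PyTorch and toy values (independent of sbi):

    import torch

    num_trials, num_thetas = 3, 2
    theta = torch.tensor([[10.0], [20.0]])  # (num_thetas, 1)
    condition = torch.tensor([[1.0], [2.0], [3.0]])  # (num_trials, 1)

    # repeat ABAB over trials, repeat AABB over thetas, then concatenate:
    pairs = torch.cat(
        [theta.repeat(num_trials, 1), condition.repeat_interleave(num_thetas, dim=0)],
        dim=-1,
    )
    # pairs lists every (theta, condition) combination, trial-major:
    # [[10., 1.], [20., 1.], [10., 2.], [20., 2.], [10., 3.], [20., 3.]]
    assert pairs.shape == (num_trials * num_thetas, 2)

    # A per-combination log likelihood of shape (num_xs, num_trials * num_thetas) can
    # then be reshaped to (num_xs, num_trials, num_thetas) and summed over dim 1.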