Skip to content

Commit

Permalink
feat: refactoring and new features for NPSE (#1370)
Browse files Browse the repository at this point in the history
* npse MAP

* set default enable_Transform to True

* sampling via diffusion twice

* batched sampling for score-based posteriors

* add test for score batched sampling

* better convergence checks
  • Loading branch information
gmoss13 authored Feb 20, 2025
1 parent aa05585 commit 16436e6
Show file tree
Hide file tree
Showing 9 changed files with 387 additions and 133 deletions.
6 changes: 3 additions & 3 deletions sbi/inference/posteriors/direct_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def sample(
)

samples = rejection.accept_reject_sample(
proposal=self.posterior_estimator,
proposal=self.posterior_estimator.sample,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
Expand Down Expand Up @@ -176,7 +176,7 @@ def sample_batched(
)

samples = rejection.accept_reject_sample(
proposal=self.posterior_estimator,
proposal=self.posterior_estimator.sample,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
Expand Down Expand Up @@ -373,7 +373,7 @@ def leakage_correction(
def acceptance_at(x: Tensor) -> Tensor:
# [1:] to remove batch-dimension for `reshape_to_batch_event`.
return rejection.accept_reject_sample(
proposal=self.posterior_estimator,
proposal=self.posterior_estimator.sample,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_rejection_samples,
show_progress_bars=show_progress_bars,
Expand Down
170 changes: 133 additions & 37 deletions sbi/inference/posteriors/score_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,21 @@

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.potentials.score_based_potential import (
CallableDifferentiablePotentialFunction,
PosteriorScoreBasedPotential,
score_estimator_based_potential,
)
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
)
from sbi.samplers.rejection import rejection
from sbi.samplers.score.correctors import Corrector
from sbi.samplers.score.diffuser import Diffuser
from sbi.samplers.score.predictors import Predictor
from sbi.sbi_types import Shape
from sbi.utils import check_prior
from sbi.utils.sbiutils import gradient_ascent, within_support
from sbi.utils.torchutils import ensure_theta_batched


Expand All @@ -46,7 +49,7 @@ def __init__(
prior: Distribution,
max_sampling_batch_size: int = 10_000,
device: Optional[str] = None,
enable_transform: bool = False,
enable_transform: bool = True,
sample_with: str = "sde",
):
"""
Expand Down Expand Up @@ -110,7 +113,6 @@ def sample(
Args:
sample_shape: Shape of the samples to be drawn.
x: Deprecated - use `.set_default_x()` prior to `.sample()`.
predictor: The predictor for the diffusion-based sampler. Can be a string or
a custom predictor following the API in `sbi.samplers.score.predictors`.
Currently, only `euler_maruyama` is implemented.
Expand All @@ -136,23 +138,39 @@ def sample(

x = self._x_else_default_x(x)
x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
self.potential_fn.set_x(x)
self.potential_fn.set_x(x, x_is_iid=True)

num_samples = torch.Size(sample_shape).numel()

if self.sample_with == "ode":
samples = self.sample_via_zuko(sample_shape=sample_shape, x=x)
elif self.sample_with == "sde":
samples = self._sample_via_diffusion(
sample_shape=sample_shape,
predictor=predictor,
corrector=corrector,
predictor_params=predictor_params,
corrector_params=corrector_params,
steps=steps,
ts=ts,
samples = rejection.accept_reject_sample(
proposal=self.sample_via_ode,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
)[0]
elif self.sample_with == "sde":
proposal_sampling_kwargs = {
"predictor": predictor,
"corrector": corrector,
"predictor_params": predictor_params,
"corrector_params": corrector_params,
"steps": steps,
"ts": ts,
"max_sampling_batch_size": max_sampling_batch_size,
"show_progress_bars": show_progress_bars,
}
samples = rejection.accept_reject_sample(
proposal=self._sample_via_diffusion,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
show_progress_bars=show_progress_bars,
)
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs=proposal_sampling_kwargs,
)[0]

samples = samples.reshape(sample_shape + self.score_estimator.input_shape)
return samples

def _sample_via_diffusion(
Expand All @@ -171,7 +189,6 @@ def _sample_via_diffusion(
Args:
sample_shape: Shape of the samples to be drawn.
x: Deprecated - use `.set_default_x()` prior to `.sample()`.
predictor: The predictor for the diffusion-based sampler. Can be a string or
a custom predictor following the API in `sbi.samplers.score.predictors`.
Currently, only `euler_maruyama` is implemented.
Expand Down Expand Up @@ -222,11 +239,10 @@ def _sample_via_diffusion(
)
samples = torch.cat(samples, dim=0)[:num_samples]

return samples.reshape(sample_shape + self.score_estimator.input_shape)
return samples

def sample_via_zuko(
def sample_via_ode(
self,
x: Tensor,
sample_shape: Shape = torch.Size(),
) -> Tensor:
r"""Return samples from posterior distribution with probability flow ODE.
Expand All @@ -243,10 +259,12 @@ def sample_via_zuko(
"""
num_samples = torch.Size(sample_shape).numel()

flow = self.potential_fn.get_continuous_normalizing_flow(condition=x)
flow = self.potential_fn.get_continuous_normalizing_flow(
condition=self.potential_fn.x_o
)
samples = flow.sample(torch.Size((num_samples,)))

return samples.reshape(sample_shape + self.score_estimator.input_shape)
return samples

def log_prob(
self,
Expand Down Expand Up @@ -291,19 +309,73 @@ def sample_batched(
self,
sample_shape: torch.Size,
x: Tensor,
predictor: Union[str, Predictor] = "euler_maruyama",
corrector: Optional[Union[str, Corrector]] = None,
predictor_params: Optional[Dict] = None,
corrector_params: Optional[Dict] = None,
steps: int = 500,
ts: Optional[Tensor] = None,
max_sampling_batch_size: int = 10000,
show_progress_bars: bool = True,
) -> Tensor:
raise NotImplementedError(
"Batched sampling is not implemented for ScorePosterior."
num_samples = torch.Size(sample_shape).numel()
x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
condition_dim = len(self.score_estimator.condition_shape)
batch_shape = x.shape[:-condition_dim]
batch_size = batch_shape.numel()
self.potential_fn.set_x(x)

max_sampling_batch_size = (
self.max_sampling_batch_size
if max_sampling_batch_size is None
else max_sampling_batch_size
)

if self.sample_with == "ode":
samples = rejection.accept_reject_sample(
proposal=self.sample_via_ode,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
num_xos=batch_size,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"x": x},
)[0]
samples = samples.reshape(
sample_shape + batch_shape + self.score_estimator.input_shape
)
elif self.sample_with == "sde":
proposal_sampling_kwargs = {
"predictor": predictor,
"corrector": corrector,
"predictor_params": predictor_params,
"corrector_params": corrector_params,
"steps": steps,
"ts": ts,
"max_sampling_batch_size": max_sampling_batch_size,
"show_progress_bars": show_progress_bars,
}
samples = rejection.accept_reject_sample(
proposal=self._sample_via_diffusion,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
num_xos=batch_size,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs=proposal_sampling_kwargs,
)[0]
samples = samples.reshape(
sample_shape + batch_shape + self.score_estimator.input_shape
)

return samples

def map(
self,
x: Optional[Tensor] = None,
num_iter: int = 1000,
num_to_optimize: int = 1000,
learning_rate: float = 1e-5,
learning_rate: float = 0.01,
init_method: Union[str, Tensor] = "posterior",
num_init_samples: int = 1000,
save_best_every: int = 1000,
Expand Down Expand Up @@ -351,17 +423,41 @@ def map(
Returns:
The MAP estimate.
"""
raise NotImplementedError(
"MAP estimation is currently not working accurately for ScorePosterior."
)
return super().map(
x=x,
num_iter=num_iter,
num_to_optimize=num_to_optimize,
learning_rate=learning_rate,
init_method=init_method,
num_init_samples=num_init_samples,
save_best_every=save_best_every,
show_progress_bars=show_progress_bars,
force_update=force_update,
)
if x is not None:
raise ValueError(
"Passing `x` directly to `.map()` has been deprecated."
"Use `.self_default_x()` to set `x`, and then run `.map()` "
)

if self.default_x is None:
raise ValueError(
"Default `x` has not been set."
"To set the default, use the `.set_default_x()` method."
)

if self._map is None or force_update:
self.potential_fn.set_x(self.default_x)
callable_potential_fn = CallableDifferentiablePotentialFunction(
self.potential_fn
)
if init_method == "posterior":
inits = self.sample((num_init_samples,))
elif init_method == "proposal":
inits = self.proposal.sample((num_init_samples,)) # type: ignore
elif isinstance(init_method, Tensor):
inits = init_method
else:
raise ValueError

self._map = gradient_ascent(
potential_fn=callable_potential_fn,
inits=inits,
theta_transform=self.theta_transform,
num_iter=num_iter,
num_to_optimize=num_to_optimize,
learning_rate=learning_rate,
save_best_every=save_best_every,
show_progress_bars=show_progress_bars,
)[0]

return self._map
Loading

0 comments on commit 16436e6

Please sign in to comment.