
Commit

jac method (still needs feasible Lambda projection to work)

manuelgloeckler committed Feb 14, 2025
1 parent 5195417 commit 5bcf427
Showing 3 changed files with 158 additions and 74 deletions.
sbi/inference/potentials/score_based_potential.py (23 changes: 6 additions & 17 deletions)
@@ -10,7 +10,7 @@
 from zuko.transforms import FreeFormJacobianTransform

 from sbi.inference.potentials.base_potential import BasePotential
-from sbi.inference.potentials.score_fn_iid import FNPEScoreFn, GaussCorrectedScoreFn
+from sbi.inference.potentials.score_fn_iid import get_iid_method
 from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
 from sbi.neural_nets.estimators.shape_handling import (
     reshape_to_batch_event,
@@ -79,7 +79,7 @@ def set_x(
         self,
         x_o: Optional[Tensor],
         x_is_iid: Optional[bool] = False,
-        iid_method: Optional[str] = "fnpe",
+        iid_method: str = "fnpe",
         rebuild_flow: Optional[bool] = True,
     ):
         super().set_x(x_o, x_is_iid)
@@ -180,21 +180,10 @@ def gradient(
             )
         else:
             assert self.prior is not None, "Prior is required for iid methods."
-            # NOTE: Add here different methods for accumulating the score.
-            # TODO: Warn for FNPE -> Kinda needs a "corrector"
-            if self.iid_method == "fnpe":
-                score_fn_iid = FNPEScoreFn(
-                    self.score_estimator, self.prior, device=self.device
-                )
-            elif self.iid_method == "gauss":
-                score_fn_iid = GaussCorrectedScoreFn(
-                    self.score_estimator, self.prior, 2 * torch.ones_like(theta[-1])
-                )
-            else:
-                raise NotImplementedError(
-                    f"Method {self.iid_method} for iid score accumulation not \
-                        implemented."
-                )
+
+            method_iid = get_iid_method(self.iid_method)
+            score_fn_iid = method_iid(self.score_estimator, self.prior)

             score = score_fn_iid(theta, self.x_o, time)

         return score
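For orientation before the second file: a minimal sketch of what the new dispatch in `gradient` amounts to. `score_estimator`, `prior`, `theta`, `x_o`, and `time` are assumed placeholder objects here; the method name is the `iid_method` string passed to `set_x`:

from sbi.inference.potentials.score_fn_iid import get_iid_method

method_iid = get_iid_method("gauss")               # class registered under "gauss"
score_fn_iid = method_iid(score_estimator, prior)  # e.g. GaussCorrectedScoreFn
score = score_fn_iid(theta, x_o, time)             # accumulated IID score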
sbi/inference/potentials/score_fn_iid.py (142 changes: 117 additions & 25 deletions)
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Callable, Optional
+from typing import Callable, Optional, Type

 import torch
 from torch import Tensor
@@ -16,13 +16,32 @@
 from sbi.utils.torchutils import ensure_theta_batched


+IID_METHODS = {}
+
+
+def get_iid_method(name: str) -> Type["ScoreFnIID"]:
+    if name not in IID_METHODS:
+        raise NotImplementedError(
+            f"Method {name} for iid score accumulation not implemented."
+        )
+    return IID_METHODS[name]
+
+
+def register_iid_method(name: str) -> Callable:
+    def decorator(cls):
+        IID_METHODS[name] = cls
+        return cls
+
+    return decorator
+
+
 class ScoreFnIID:
     def __init__(
         self,
-        score_estimator: ConditionalScoreEstimator,
+        score_estimator: "ConditionalScoreEstimator",
         prior: Distribution,
         device: str = "cpu",
-    ):
+    ) -> None:
         r"""
         Initializes the ScoreFnIID class.
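A hypothetical sketch of how this registry is meant to be extended; the name "my_method" and class `MyScoreFn` do not exist in the codebase:

# Hypothetical: register a custom IID accumulation method under a string name.
@register_iid_method("my_method")
class MyScoreFn(ScoreFnIID):
    def __call__(self, inputs, conditions, time=None):
        ...  # accumulate and return the score for iid conditions

assert get_iid_method("my_method") is MyScoreFn
# get_iid_method("unknown") would raise NotImplementedError.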
@@ -72,15 +91,15 @@ def prior_score_fn(self, theta: Tensor) -> Tensor:
         )[0]
         return prior_score


+@register_iid_method("fnpe")
 class FNPEScoreFn(ScoreFnIID):
     def __init__(
         self,
-        score_estimator: ConditionalScoreEstimator,
+        score_estimator: "ConditionalScoreEstimator",
         prior: Distribution,
         device: str = "cpu",
         prior_score_weight: Optional[Callable[[Tensor], Tensor]] = None,
-    ):
+    ) -> None:
         r"""
         Initializes the FNPEScoreFn class.
@@ -123,7 +142,7 @@ def __call__(
         if time is None:
             time = torch.tensor([self.score_estimator.t_min])

-        # NOTE: If this always works needs to be testd
+        # NOTE: Whether this always works still needs to be tested.

         N = inputs.shape[-2]
@@ -144,7 +163,7 @@
 class AbstractGaussCorrectedScoreFn(ScoreFnIID):
     def __init__(
         self,
-        score_estimator: ConditionalScoreEstimator,
+        score_estimator: "ConditionalScoreEstimator",
         prior: Distribution,
     ) -> None:
         r"""Initializes the AbstractGaussCorrectedScoreFn class.
@@ -263,6 +282,7 @@ def marginal_denoising_prior_precision_fn(
         m = self.score_estimator.mean_t_fn(time)
         std = self.score_estimator.std_fn(time)
         p_denoise = self.denoising_prior(m, std, inputs)
+        # TODO: If the prior is multivariate, this will break.
         return 1 / p_denoise.variance

     def __call__(
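For intuition on the precision computed here: assuming a scalar Gaussian prior θ ∼ N(μ₀, σ₀²) and the diffusion marginal θ_t = m(t) θ + σ(t) ε (my reading of `mean_t_fn` and `std_fn`; the factorized case is exactly what the new TODO is worried about), the denoising prior has

\Lambda_{\mathrm{prior}}(t) = \frac{1}{\operatorname{Var}[\theta_t]} = \frac{1}{m(t)^2 \sigma_0^2 + \sigma(t)^2},

which is what `1 / p_denoise.variance` returns elementwise. Once the prior has correlated dimensions, the elementwise reciprocal of a variance vector is no longer the precision matrix, hence the TODO.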
@@ -278,12 +298,14 @@ def __call__(
         Returns:
             Corrected score function.
         """
-        # TODO We can assume here a fixed 3dim shape of inputs [b,N,d]
+        # TODO We can assume here a fixed format of conditions [N,...]
         N = conditions.shape[0]
         base_score = self.score_estimator(inputs, conditions, time, **kwargs)
         prior_score = self.marginal_prior_score_fn(time, inputs)

         print(f"base_score: {base_score.shape}")
-        print(f"prior_score: {prior_score.shape}")
+        # print(f"prior_score: {prior_score.shape}")

         # Marginal prior precision
         prior_precision = self.marginal_denoising_prior_precision_fn(time, inputs)
@@ -297,7 +319,7 @@

         # Total precision
         term1 = (1 - N) * prior_precision
-        term2 = torch.sum(posterior_precisions, dim=-2, keepdim=True).unsqueeze(0)
+        term2 = torch.sum(posterior_precisions, dim=1)
         print(f"term1: {term1.shape}")
         print(f"term2: {term2.shape}")
         Lam = add_diag_or_dense(term1, term2, batch_dims=1)
@@ -308,36 +330,48 @@
         weighted_prior_score = mv_diag_or_dense(
             prior_precision, prior_score, batch_dims=2
         )
-        print(f"weighted_prior_score: {weighted_prior_score.shape}")
+        # print(f"weighted_prior_score: {weighted_prior_score.shape}")
         weighted_posterior_scores = mv_diag_or_dense(
-            posterior_precisions.unsqueeze(0), base_score, batch_dims=2
+            posterior_precisions, base_score, batch_dims=2
         )

         # Accumulate the scores
-        score = (1 - N) * weighted_prior_score + torch.sum(
-            weighted_posterior_scores, dim=-2, keepdim=True
+        score = (1 - N) * weighted_prior_score.sum(dim=1) + torch.sum(
+            weighted_posterior_scores, dim=1
         )
-        print(f"score: {score.shape}")
+        # print(f"score: {score.shape}")
         # Solve the linear system
-        score = solve_diag_or_dense(Lam, score, batch_dims=2)
-
-        return score
+        Lam = Lam + 1e-1 * torch.eye(Lam.shape[-1])[None]
+        score = solve_diag_or_dense(Lam, score, batch_dims=1)
+
+        return score.reshape(inputs.shape)

+@register_iid_method("gauss")
 class GaussCorrectedScoreFn(AbstractGaussCorrectedScoreFn):
     def __init__(
         self,
-        score_estimator: ConditionalScoreEstimator,
+        score_estimator: "ConditionalScoreEstimator",
         prior: Distribution,
-        posterior_precision: Tensor,
+        posterior_precision: Optional[Tensor] = None,
+        scale_from_prior_precision: float = 1.5,
     ) -> None:
-        r"""Initializes the GaussCorrectedScoreFn class.
+        r"""
+        Initializes the GaussCorrectedScoreFn class.

         Args:
-            score_estimator: The neural network modelling the score.
+            score_estimator: The neural network modeling the score.
             prior: The prior distribution.
+            posterior_precision: Optional preset posterior precision.
+            scale_from_prior_precision: Scaling factor for the posterior
+                precision if it is not provided.
         """
         super().__init__(score_estimator, prior)

+        if posterior_precision is None:
+            prior_samples = self.prior.sample((1000,))
+            prior_precision_estimate = 1 / torch.var(prior_samples, dim=0)
+            posterior_precision = (
+                scale_from_prior_precision * prior_precision_estimate
+            )
+
         self.posterior_precision = posterior_precision

     def posterior_precision_est_fn(self, x_o: Tensor) -> Tensor:
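Reading `__call__` above as math (my paraphrase of the code, not a formula stated anywhere in the commit): with per-observation denoising posterior precisions Λ_i and scores s_i, and the denoising prior precision Λ_prior with prior score s_prior, the accumulated score solves

\Lambda = (1 - N)\,\Lambda_{\mathrm{prior}} + \sum_{i=1}^{N} \Lambda_i,
\qquad
\Lambda\, s_{\mathrm{corrected}} = (1 - N)\,\Lambda_{\mathrm{prior}}\, s_{\mathrm{prior}} + \sum_{i=1}^{N} \Lambda_i\, s_i .

The `Lam + 1e-1 * torch.eye(...)` line is a diagonal jitter that keeps Λ invertible when the (1 − N) prior term makes it indefinite; per the commit title, this is a stopgap until a feasible projection of Lambda replaces it.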
@@ -351,12 +385,70 @@ def posterior_precision_est_fn(self, x_o: Tensor) -> Tensor:
         """
         return self.posterior_precision


+@register_iid_method("auto_gauss")
 class AutoGaussCorrectedScoreFn(AbstractGaussCorrectedScoreFn):
     # TODO: Move over..
     pass


+@register_iid_method("jac_gauss")
 class JacCorrectedScoreFn(AbstractGaussCorrectedScoreFn):
-    pass
-    # TODO: Move over...
+    def posterior_precision_est_fn(self, conditions: Tensor) -> Tensor:
+        r"""
+        Estimates the posterior precision for the Jacobian-based correction.
+
+        Args:
+            conditions: Observed data.
+
+        Returns:
+            Estimated posterior precision.
+        """
+        raise ValueError("This method is not used for JacCorrectedScoreFn.")
+
+    def marginal_posterior_precision_est_fn(
+        self, time: Tensor, inputs: Tensor, conditions: Tensor, N: int
+    ) -> Tensor:
+        r"""
+        Estimates the marginal posterior precision using the Jacobian of the
+        score function.
+
+        Args:
+            time: Time tensor.
+            inputs: Parameter tensor.
+            conditions: Observed data.
+            N: Number of samples.
+
+        Returns:
+            Estimated marginal posterior precision.
+        """
+        d = inputs.shape[-1]
+        with torch.enable_grad():
+            # NOTE: torch.func can be relatively unstable...
+            jac_fn = torch.func.jacrev(
+                lambda x: self.score_estimator(x, conditions, time)
+            )
+            jac_fn = torch.func.vmap(torch.func.vmap(jac_fn))
+            jac = jac_fn(inputs).squeeze(1)
+            print("jac", jac.shape)
+            # jac = torch.func.vmap(
+            #     torch.func.vmap(
+            #         torch.func.jacrev(
+            #             lambda x: self.score_estimator(x, conditions, time)
+            #         )
+            #     )
+            # )(inputs)
+
+        # Must be symmetrical
+        jac = 0.5 * (jac + jac.transpose(-1, -2))
+        print("jac", jac.shape)
+        m = self.score_estimator.mean_t_fn(time)
+        std = self.score_estimator.std_fn(time)
+        cov0 = std**2 * jac + torch.eye(d)[None, None, :, :]
+
+        denoising_posterior_precision = m**2 / std**2 + torch.inverse(cov0)
+        # Project to PSD
+        eigvals, eigvecs = torch.linalg.eigh(denoising_posterior_precision)
+        eigvals = torch.clamp(eigvals, min=0.1)
+        denoising_posterior_precision = (
+            eigvecs @ torch.diag_embed(eigvals) @ eigvecs.transpose(-1, -2)
+        )
+
+        return denoising_posterior_precision
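The eigenvalue clamp above is the other half of that stopgap. A self-contained sketch of the projection step, with `project_to_pd` as my hypothetical name and a random symmetric matrix standing in for the denoising precision:

import torch

# Sketch: force a symmetric matrix to be positive definite by flooring its
# spectrum, mirroring the clamp applied to denoising_posterior_precision.
def project_to_pd(mat: torch.Tensor, min_eig: float = 0.1) -> torch.Tensor:
    mat = 0.5 * (mat + mat.transpose(-1, -2))    # symmetrize first
    eigvals, eigvecs = torch.linalg.eigh(mat)    # real eigendecomposition
    eigvals = torch.clamp(eigvals, min=min_eig)  # floor the eigenvalues
    return eigvecs @ torch.diag_embed(eigvals) @ eigvecs.transpose(-1, -2)

A = torch.randn(4, 4)  # stand-in matrix
A_pd = project_to_pd(A)
assert bool((torch.linalg.eigvalsh(A_pd) >= 0.1 - 1e-6).all())

Note that clamping alters the matrix rather than finding the nearest feasible one, which is presumably why the commit title still calls for a proper feasible Lambda projection.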
