Skip to content

Commit

Permalink
new ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelgloeckler committed Jan 30, 2025
1 parent 843ce7d commit 384f36f
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 16 deletions.
4 changes: 1 addition & 3 deletions sbi/inference/potentials/score_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from zuko.transforms import FreeFormJacobianTransform

from sbi.inference.potentials.base_potential import BasePotential
from sbi.inference.potentials.score_fn_iid import FNPEScoreFn
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
Expand All @@ -19,8 +20,6 @@
from sbi.utils.sbiutils import mcmc_transform, within_support
from sbi.utils.torchutils import ensure_theta_batched

from sbi.inference.potentials.score_fn_iid import FNPEScoreFn


def score_estimator_based_potential(
score_estimator: ConditionalScoreEstimator,
Expand Down Expand Up @@ -250,4 +249,3 @@ def f(t, x):
exact=exact,
)
return transform

10 changes: 5 additions & 5 deletions sbi/inference/potentials/score_fn_iid.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Callable, Optional, Union
from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from torch.distributions import Distribution

from abc import abstractmethod
from typing import Callable, Optional

import torch
from torch import Tensor
from torch.distributions import Distribution

from sbi.neural_nets.estimators.score_estimator import ConditionalScoreEstimator
from sbi.utils.torchutils import ensure_theta_batched


Expand Down Expand Up @@ -131,4 +131,4 @@ def __call__(
# Accumulate
score = (1 - N) * prior_score + base_score.sum(-2, keepdim=True)

return score
return score
4 changes: 2 additions & 2 deletions sbi/inference/potentials/score_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from torch.distributions import Distribution, Independent, Normal, Uniform
from torch import Tensor
from torch.distributions import Distribution, Independent, Normal

# Automatic denoising -----------------------------------------------------

Expand Down Expand Up @@ -129,4 +129,4 @@ def marginalize_gaussian(p: Normal, m: Tensor, s: Tensor) -> Normal:
var = (m * std0) ** 2 + s**2
std = var**0.5

return Normal(mu, std)
return Normal(mu, std)
12 changes: 6 additions & 6 deletions tests/linearGaussian_npse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,12 @@ def simulator(theta):
check_c2st(samples, target_samples, alg="npse_different_dims_and_resume_training")


@pytest.mark.xfail(
reason="iid_bridge not working.",
raises=AssertionError,
strict=True,
match="Score accumulation*",
)
# @pytest.mark.xfail(
# reason="iid_bridge not working.",
# raises=AssertionError,
# strict=True,
# match="Score accumulation*",
# )
@pytest.mark.parametrize("num_trials", [2, 10])
def test_npse_iid_inference(num_trials):
"""Test whether NPSE infers well a simple example with available ground truth."""
Expand Down

0 comments on commit 384f36f

Please sign in to comment.