
Commit

Fixing suggestions and adding dependency
manuelgloeckler committed Jan 20, 2025
1 parent e41a771 commit 45b8fcf
Showing 3 changed files with 60 additions and 44 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -73,6 +73,7 @@ dev = [
"pytest",
"pytest-cov",
"pytest-xdist",
"pytest-harvest",
"torchtestcase",
]

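The newly added pytest-harvest dependency provides the results_bag fixture that tests/bm_test.py (below) stores its metrics in. A minimal, hedged sketch of that pattern, not taken from this commit: each test writes values into its bag via attribute assignment, and pytest-harvest can aggregate all bags afterwards (for example through its session_results_df fixture).

```python
# Illustrative sketch of the pytest-harvest pattern used by bm_test.py;
# the test body and values here are placeholders, not the benchmark code.
import pytest


@pytest.mark.parametrize("task_name", ["two_moons", "slcp"])
def test_dummy_metric(task_name, results_bag):
    # results_bag is provided by pytest-harvest; attribute assignment
    # records one entry per test invocation.
    results_bag.task_name = task_name
    results_bag.metric = 0.5  # placeholder metric value
```

After the run, session-level fixtures or hooks can turn the collected bags into a single results table.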
91 changes: 52 additions & 39 deletions tests/bm_test.py
@@ -1,28 +1,33 @@
import pytest
import torch
from pytest_harvest import ResultsBag

from sbi.inference import FMPE, NLE, NPE, NPSE, NRE
from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.trainers.npe import NPE_C
from sbi.inference.trainers.nre import BNRE, NRE_A, NRE_B, NRE_C
from sbi.utils.metrics import c2st

from .mini_sbibm import get_task
from .mini_sbibm.base_task import Task

# NOTE: This could be improved...
# Global settings
SEED = 0
TASKS = ["two_moons", "linear_mvg_2d", "gaussian_linear", "slcp"]
NUM_SIMULATIONS = 2000
EVALUATION_POINTS = 4 # Currently only 3 observations are tested, for speed
NUM_ROUNDS_SEQUENTIAL = 2
EVALUATION_POINT_SEQUENTIAL = 1
TRAIN_KWARGS = {}

# Density estimators to test
DENSITY_estimators = ["mdn", "made", "maf", "nsf", "maf_rqs"] # "Kinda exhaustive"
DENSITY_ESTIMATORS = ["mdn", "made", "maf", "nsf", "maf_rqs"] # "Kinda exhaustive"
CLASSIFIERS = ["mlp", "resnet"]
NNS = ["mlp", "resnet"]
SCORE_ESTIMATORS = ["mlp", "ada_mlp"]

# Benchmarking method groups
# Benchmarking method groups, i.e., what to run for each --bm-mode
METHOD_GROUPS = {
"none": [NPE, NRE, NLE, FMPE, NPSE],
"npe": [NPE],
@@ -36,7 +41,7 @@
}
METHOD_PARAMS = {
"none": [{}],
"npe": [{"density_estimator": de} for de in DENSITY_estimators],
"npe": [{"density_estimator": de} for de in DENSITY_ESTIMATORS],
"nle": [{"density_estimator": de} for de in ["maf", "nsf"]],
"nre": [{"classifier": cl} for cl in CLASSIFIERS],
"fmpe": [{"density_estimator": nn} for nn in NNS],
@@ -106,7 +111,7 @@ def pytest_generate_tests(metafunc):
metafunc.parametrize("extra_kwargs", kwargs_group)
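
The METHOD_GROUPS and METHOD_PARAMS tables above feed pytest_generate_tests, of which only the last line is visible in this hunk. A hedged sketch of how such a hook could map the selected --bm-mode to the parametrized inference_method and extra_kwargs arguments; the actual hook in bm_test.py may differ in its details.

```python
# Sketch only; the real pytest_generate_tests in bm_test.py may differ.
def pytest_generate_tests(metafunc):
    # Fall back to the "none" group when no --bm-mode was given.
    mode = metafunc.config.getoption("--bm-mode", default=None) or "none"

    if "inference_method" in metafunc.fixturenames:
        # One test per inference class in the selected group.
        metafunc.parametrize("inference_method", METHOD_GROUPS[mode])
    if "extra_kwargs" in metafunc.fixturenames:
        # One test per keyword configuration for that group.
        metafunc.parametrize("extra_kwargs", METHOD_PARAMS[mode])
```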


def standard_eval_c2st_loop(posterior, task) -> float:
def standard_eval_c2st_loop(posterior: NeuralPosterior, task: Task) -> float:
"""
Evaluates the C2ST metric for the given posterior and task.
@@ -117,18 +122,23 @@ def standard_eval_c2st_loop(posterior, task) -> float:
Returns:
float: The mean C2ST value.
"""
metrics = []
c2st_scores = []
for i in range(1, EVALUATION_POINTS):
c2st_val = eval_c2st(posterior, task, i)
metrics.append(c2st_val)
c2st_scores.append(c2st_val)

mean_c2st = sum(metrics) / len(metrics)
mean_c2st = sum(c2st_scores) / len(c2st_scores)
# Convert to float rounded to 3 decimal places
mean_c2st = float(f"{mean_c2st:.3f}")
return mean_c2st


def eval_c2st(posterior, task, i: int) -> float:
def eval_c2st(
posterior: NeuralPosterior,
task: Task,
idx_observation: int,
num_samples: int = 1000,
) -> float:
"""
Evaluates the C2ST metric for a specific observation.
@@ -140,33 +150,35 @@ def eval_c2st(posterior, task, i: int) -> float:
Returns:
float: The C2ST value.
"""
x_o = task.get_observation(i)
posterior_samples = task.get_reference_posterior_samples(i)
approx_posterior_samples = posterior.sample((1000,), x=x_o)
x_o = task.get_observation(idx_observation)
posterior_samples = task.get_reference_posterior_samples(idx_observation)
approx_posterior_samples = posterior.sample((num_samples,), x=x_o)
if isinstance(approx_posterior_samples, tuple):
approx_posterior_samples = approx_posterior_samples[0]
c2st_val = c2st(posterior_samples[:1000], approx_posterior_samples)
return c2st_val
assert posterior_samples.shape[0] >= num_samples, "Not enough reference samples"
c2st_val = c2st(posterior_samples[:num_samples], approx_posterior_samples)
return float(c2st_val)
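
For reference, and not as part of the test suite: the C2ST score from sbi.utils.metrics is roughly 0.5 when the two sample sets are indistinguishable and approaches 1.0 when a classifier can separate them easily. A small self-contained check of that behaviour:

```python
# Standalone sanity check of the C2ST metric (not part of this test suite).
import torch

from sbi.utils.metrics import c2st

samples_a = torch.randn(1000, 2)
samples_b = torch.randn(1000, 2)        # same distribution -> score near 0.5
samples_c = torch.randn(1000, 2) + 3.0  # shifted distribution -> score near 1.0

print(float(c2st(samples_a, samples_b)))  # expected ~0.5
print(float(c2st(samples_a, samples_c)))  # expected ~1.0
```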


def amortized_inference_eval(
method, task_name: str, extra_kwargs: dict, results_bag
inference_method, task_name: str, extra_kwargs: dict, results_bag: ResultsBag
) -> None:
"""
Performs amortized inference evaluation.
Args:
method: The inference method.
task_name (str): The name of the task.
extra_kwargs (dict): Additional keyword arguments for the method.
results_bag: The results bag to store evaluation results.
task_name: The name of the task.
extra_kwargs: Additional keyword arguments for the method.
results_bag: The results bag to store evaluation results. A subclass of dict that
additionally allows item assignment with dot notation.
"""
torch.manual_seed(SEED)
task = get_task(task_name)
thetas, xs = task.get_data(NUM_SIMULATIONS)
prior = task.get_prior()

inference = method(prior, **extra_kwargs)
inference = inference_method(prior, **extra_kwargs)
_ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

posterior = inference.build_posterior()
@@ -177,11 +189,11 @@ def amortized_inference_eval(
results_bag.metric = mean_c2st
results_bag.num_simulations = NUM_SIMULATIONS
results_bag.task_name = task_name
results_bag.method = method.__name__ + str(extra_kwargs)
results_bag.method = inference_method.__name__ + str(extra_kwargs)


def sequential_inference_eval(
method, task_name: str, extra_kwargs: dict, results_bag
method, task_name: str, extra_kwargs: dict, results_bag: ResultsBag
) -> None:
"""
Performs sequential inference evaluation.
@@ -194,28 +206,29 @@
"""
torch.manual_seed(SEED)
task = get_task(task_name)
num_simulations1 = NUM_SIMULATIONS // 2
thetas, xs = task.get_data(num_simulations1)
num_simulations = NUM_SIMULATIONS // NUM_ROUNDS_SEQUENTIAL
thetas, xs = task.get_data(num_simulations)
prior = task.get_prior()
idx_eval = 1
idx_eval = EVALUATION_POINT_SEQUENTIAL
x_o = task.get_observation(idx_eval)
simulator = task.get_simulator()

# Round 1
inference = method(prior, **extra_kwargs)
_ = inference.append_simulations(thetas, xs).train(**TRAIN_KWARGS)

proposal = inference.build_posterior().set_default_x(task.get_observation(idx_eval))
num_simulations2 = NUM_SIMULATIONS - num_simulations1
thetas2 = proposal.sample((num_simulations2,))
xs2 = task.get_simulator()(thetas2)
for _ in range(NUM_ROUNDS_SEQUENTIAL - 1):
proposal = inference.build_posterior().set_default_x(x_o)
thetas_i = proposal.sample((num_simulations,))
xs_i = simulator(thetas_i)
if "npe" in method.__name__.lower():
# NPE_C requires a Gaussian prior
_ = inference.append_simulations(thetas_i, xs_i, proposal=proposal).train(
**TRAIN_KWARGS
)
else:
inference.append_simulations(thetas_i, xs_i).train(**TRAIN_KWARGS)

# Round 2
if "npe" in method.__name__.lower():
# NPE_C requires a Gaussian prior
_ = inference.append_simulations(thetas2, xs2, proposal=proposal).train(
**TRAIN_KWARGS
)
else:
_ = inference.append_simulations(thetas2, xs2).train(**TRAIN_KWARGS)
posterior = inference.build_posterior()

c2st_val = eval_c2st(posterior, task, idx_eval)
@@ -229,7 +242,7 @@

@pytest.mark.benchmark
@pytest.mark.parametrize("task_name", TASKS, ids=str)
def test_benchmark_standard(
def test_benchmark(
inference_method,
task_name: str,
results_bag,
@@ -241,10 +254,10 @@ def test_benchmark_standard(
Args:
inference_method: The inference method to test.
task_name (str): The name of the task.
task_name: The name of the task.
results_bag: The results bag to store evaluation results.
extra_kwargs (dict): Additional keyword arguments for the method.
benchmark_mode (str): The benchmark mode.
extra_kwargs: Additional keyword arguments for the method.
benchmark_mode: The benchmark mode.
"""
if benchmark_mode in ["snpe", "snle", "snre"]:
sequential_inference_eval(
12 changes: 7 additions & 5 deletions tests/conftest.py
@@ -36,16 +36,18 @@ def pytest_collection_modifyitems(config, items):
)
if not gpu_device_available:
skip_gpu = pytest.mark.skip(reason="No devices available")
skip_bm = pytest.mark.skip(reason="Benchmarking disabled")

for item in items:
if "gpu" in item.keywords:
item.add_marker(skip_gpu)

if not config.getoption("--bm"):
if not config.getoption("--bm"):
# Skip marked benchmarking tests
skip_bm = pytest.mark.skip(reason="Benchmarking disabled")
for item in items:
if "benchmark" in item.keywords:
item.add_marker(skip_bm)

# Filter tests to only those with the 'benchmark' marker
if config.getoption("--bm"):
else:
# Filter tests to only those with the 'benchmark' marker
filtered_items = []
for item in items:
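The hunk above shows only the collection-time handling of --bm; the option registration and the end of the filtering branch are collapsed. A hedged sketch of the wiring this implies, assuming --bm and --bm-mode are registered in pytest_addoption; the actual conftest.py may differ.

```python
# Sketch of the implied conftest.py wiring; the actual code may differ.
import pytest


def pytest_addoption(parser):
    # Opt-in flag for the mini-sbibm benchmark suite.
    parser.addoption("--bm", action="store_true", default=False,
                     help="Run benchmarking tests.")
    # Method group to benchmark, e.g. "npe", "nle", "snpe" (see bm_test.py).
    parser.addoption("--bm-mode", action="store", default="none",
                     help="Which method group to benchmark.")


def pytest_collection_modifyitems(config, items):
    # (GPU-skip handling from the real conftest.py omitted here.)
    if not config.getoption("--bm"):
        # Default runs: skip everything marked as a benchmark.
        skip_bm = pytest.mark.skip(reason="Benchmarking disabled")
        for item in items:
            if "benchmark" in item.keywords:
                item.add_marker(skip_bm)
    else:
        # Benchmark runs: keep only the benchmark-marked tests.
        items[:] = [item for item in items if "benchmark" in item.keywords]
```

Under these assumptions, the suite would be invoked with something like pytest tests/bm_test.py --bm --bm-mode npe.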
