fix: temporary wrappers to fix MADE (#1398)
* add temporary MADE wrappers

* test MADEMoG
gmoss13 authored Feb 20, 2025
1 parent 16436e6 commit 68ee1a7
Showing 4 changed files with 137 additions and 3 deletions.
4 changes: 2 additions & 2 deletions sbi/neural_nets/net_builders/flow.py
@@ -15,7 +15,7 @@
from torch import Tensor, nn, relu, tanh, tensor, uint8

from sbi.neural_nets.estimators import NFlowsFlow, ZukoFlow
-from sbi.utils.nn_utils import get_numel
+from sbi.utils.nn_utils import MADEMoGWrapper, get_numel
from sbi.utils.sbiutils import (
    standardizing_net,
    standardizing_transform,
@@ -77,7 +77,7 @@ def build_made(
            standardizing_net(batch_y, structured_y), embedding_net
        )

-    distribution = distributions_.MADEMoG(
+    distribution = MADEMoGWrapper(
        features=x_numel,
        hidden_features=hidden_features,
        context_features=y_numel,
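A minimal sketch (not part of this diff) of how the patched builder is exercised; the shapes and the hidden_features keyword below are assumptions:

import torch
from sbi.neural_nets.net_builders import build_made

batch_x = torch.randn(100, 2)  # variable whose density is modeled
batch_y = torch.randn(100, 3)  # conditioning variable
estimator = build_made(batch_x, batch_y, hidden_features=50)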
132 changes: 132 additions & 0 deletions sbi/utils/nn_utils.py
@@ -2,6 +2,11 @@
from typing import Optional
from warnings import warn

import nflows.nn.nde.made as made
import numpy as np
import torch
import torch.nn.functional as F
from pyknos.nflows import distributions as distributions_
from torch import Tensor, nn


@@ -62,3 +67,130 @@ def check_net_device(
        return net.to(device)
    else:
        return net


"""
Temporary Patches to fix nflows MADE bug. Remove once upstream bug is fixed.
"""


class MADEWrapper(made.MADE):
    """Wrapper around nflows' MADE that prepends a dummy input dimension.

    The extra constant input ensures that every real input dimension is
    conditioned on the context. It can use either feedforward blocks or
    residual blocks (default is residual). Optionally, it can use batch norm
    or dropout within blocks (default is no).
    """

    def __init__(
        self,
        features,
        hidden_features,
        context_features=None,
        num_blocks=2,
        output_multiplier=1,
        use_residual_blocks=True,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
    ):
        if use_residual_blocks and random_mask:
            raise ValueError("Residual blocks can't be used with random masks.")
        # Build the MADE with one extra feature reserved for the dummy input.
        super().__init__(
            features + 1,
            hidden_features,
            context_features,
            num_blocks,
            output_multiplier,
            use_residual_blocks,
            random_mask,
            activation,
            dropout_probability,
            use_batch_norm,
        )

    def forward(self, inputs, context=None):
        # Prepend a dummy input so that all dims are conditioned on context.
        dummy_input = torch.zeros(
            inputs.shape[:-1] + (1,), device=inputs.device, dtype=inputs.dtype
        )
        concat_input = torch.cat((dummy_input, inputs), dim=-1)
        outputs = super().forward(concat_input, context)
        # The final layer of MADE produces self.output_multiplier outputs for
        # each input dimension, in order. We only want the outputs
        # corresponding to the real inputs, so we discard the first
        # self.output_multiplier outputs (those of the dummy input).
        return outputs[..., self.output_multiplier :]
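# A minimal usage sketch (not part of this commit); shapes and
# hyperparameters are illustrative assumptions:
#
#     wrapper = MADEWrapper(
#         features=3, hidden_features=16, context_features=2,
#         output_multiplier=3,
#     )
#     outputs = wrapper(torch.randn(10, 3), context=torch.randn(10, 2))
#     # The dummy dimension's outputs are already stripped:
#     assert outputs.shape == (10, 3 * 3)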


"""
Temporary Patches to fix nflows MADE bug. Remove once upstream bug is fixed.
"""


class MADEMoGWrapper(distributions_.MADEMoG):
    """Wrapper around nflows' MADEMoG that prepends a dummy input dimension.

    As with MADEWrapper, the extra constant input ensures that every real
    input dimension is conditioned on the context.
    """

    def __init__(
        self,
        features,
        hidden_features,
        context_features,
        num_blocks=2,
        num_mixture_components=1,
        use_residual_blocks=True,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
        custom_initialization=False,
    ):
        # Build the MADEMoG with one extra feature for the dummy input.
        super().__init__(
            features + 1,
            hidden_features,
            context_features,
            num_blocks,
            num_mixture_components,
            use_residual_blocks,
            random_mask,
            activation,
            dropout_probability,
            use_batch_norm,
            custom_initialization,
        )

    def _log_prob(self, inputs, context=None):
        # Prepend a dummy input so that all dims are conditioned on context.
        dummy_input = torch.zeros(
            inputs.shape[:-1] + (1,), device=inputs.device, dtype=inputs.dtype
        )
        concat_inputs = torch.cat((dummy_input, inputs), dim=-1)

        outputs = self._made.forward(concat_inputs, context=context)
        # Per input dimension, MADE emits 3 values (logit, mean,
        # unconstrained std) for each of the mixture components.
        outputs = outputs.reshape(
            *concat_inputs.shape, self._made.num_mixture_components, 3
        )

        logits, means, unconstrained_stds = (
            outputs[..., 0],
            outputs[..., 1],
            outputs[..., 2],
        )
        # Drop the dummy input's slice from logits, means, unconstrained_stds.
        logits = logits[..., 1:, :]
        means = means[..., 1:, :]
        unconstrained_stds = unconstrained_stds[..., 1:, :]

        log_mixture_coefficients = torch.log_softmax(logits, dim=-1)
        stds = F.softplus(unconstrained_stds) + self._made.epsilon

        # Gaussian-mixture log-density per dimension, summed over dimensions:
        # log p(x) = sum_d logsumexp_k [log pi_dk + log N(x_d; mu_dk, std_dk)].
        log_prob = torch.sum(
            torch.logsumexp(
                log_mixture_coefficients
                - 0.5
                * (
                    np.log(2 * np.pi)
                    + 2 * torch.log(stds)
                    + ((inputs[..., None] - means) / stds) ** 2
                ),
                dim=-1,
            ),
            dim=-1,
        )
        return log_prob

    def _sample(self, num_samples, context=None):
        samples = self._made.sample(num_samples, context=context)
        # Drop the dummy first dimension from the samples.
        return samples[..., 1:]
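# A minimal usage sketch (not part of this commit); shapes are illustrative
# assumptions:
#
#     mog = MADEMoGWrapper(
#         features=3, hidden_features=16, context_features=2,
#         num_mixture_components=5,
#     )
#     logp = mog._log_prob(torch.randn(10, 3), context=torch.randn(10, 2))
#     assert logp.shape == (10,)  # one log-density per batch element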
2 changes: 2 additions & 0 deletions tests/density_estimator_test.py
@@ -16,6 +16,7 @@
)
from sbi.neural_nets.net_builders import (
    build_categoricalmassestimator,
+    build_made,
    build_maf,
    build_maf_rqs,
    build_mdn,
@@ -36,6 +37,7 @@

# List of all density estimator builders for testing.
model_builders = [
+    build_made,
    build_mdn,
    build_maf,
    build_maf_rqs,
2 changes: 1 addition & 1 deletion tests/linearGaussian_snpe_test.py
@@ -147,7 +147,7 @@ def simulator(theta):
@pytest.mark.slow
@pytest.mark.parametrize(
    "density_estimator",
-    ["mdn", "maf", "maf_rqs", "nsf", "zuko_maf", "zuko_nsf"],
+    ["made", "mdn", "maf", "maf_rqs", "nsf", "zuko_maf", "zuko_nsf"],
)
def test_density_estimators_on_linearGaussian(density_estimator):
    """Test NPE with different density estimators on linear Gaussian example."""
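With "made" in the parametrized list, the same estimator can be selected by name; a hedged sketch assuming sbi's standard NPE interface:

from sbi.inference import NPE

inference = NPE(density_estimator="made")
# followed by the usual inference.append_simulations(theta, x).train()
# and inference.build_posterior() flow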
