Add Beta distribution and Sigmoid transform #184

Merged · 3 commits · Oct 7, 2024
10 changes: 5 additions & 5 deletions docs/examples/constrained.ipynb

Large diffs are not rendered by default.
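The notebook diff is not rendered above, but the changes presumably demonstrate the new components on constrained supports. As a hedged sketch (not the notebook's actual contents), one way to obtain a distribution on (0, 1) is to push a real-valued base distribution through the new Sigmoid bijection using flowjax.distributions.Transformed:

import jax.numpy as jnp
import jax.random as jr
from flowjax.bijections import Sigmoid
from flowjax.distributions import Normal, Transformed

# Pushing a standard normal through Sigmoid gives a logit-normal
# distribution supported on (0, 1).
base = Normal(jnp.zeros(2))
dist = Transformed(base, Sigmoid(shape=(2,)))

samples = dist.sample(jr.PRNGKey(0), (5,))  # all values lie in (0, 1)
log_probs = dist.log_prob(samples)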

2 changes: 2 additions & 0 deletions flowjax/bijections/__init__.py
@@ -12,6 +12,7 @@
from .planar import Planar
from .power import Power
from .rational_quadratic_spline import RationalQuadraticSpline
from .sigmoid import Sigmoid
from .softplus import SoftPlus
from .tanh import LeakyTanh, Tanh
from .utils import EmbedCondition, Flip, Identity, Invert, Partial, Permute, Reshape
@@ -40,6 +41,7 @@
"Reshape",
"Scale",
"Scan",
"Sigmoid",
"SoftPlus",
"Stack",
"Tanh",
36 changes: 36 additions & 0 deletions flowjax/bijections/sigmoid.py
@@ -0,0 +1,36 @@
"""Sigmoid bijection."""

from typing import ClassVar

import jax.numpy as jnp
from jax import nn
from jax.scipy.special import logit

from flowjax.bijections.bijection import AbstractBijection


class Sigmoid(AbstractBijection):
r"""Sigmoid bijection :math:`y = \sigma(x) = \frac{1}{1 + \exp(-x)}`.

Args:
shape: The shape of the transform.
"""

shape: tuple[int, ...] = ()
cond_shape: ClassVar[None] = None

def transform(self, x, condition=None):
return nn.sigmoid(x)

def transform_and_log_det(self, x, condition=None):
y = nn.sigmoid(x)
log_det = jnp.sum(nn.log_sigmoid(x) + nn.log_sigmoid(-x))
return y, log_det

def inverse(self, y, condition=None):
return logit(y)

def inverse_and_log_det(self, y, condition=None):
x = logit(y)
log_det = -jnp.sum(nn.log_sigmoid(x) + nn.log_sigmoid(-x))
return x, log_det
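The log-determinant terms use the identity sigma'(x) = sigma(x) * (1 - sigma(x)), so log sigma'(x) = log_sigmoid(x) + log_sigmoid(-x); since the map is elementwise, the Jacobian is diagonal and the terms sum. A minimal standalone check of this against autodiff (not part of the PR):

import jax.numpy as jnp
from jax import jacfwd
from flowjax.bijections import Sigmoid

bijection = Sigmoid((3,))
x = jnp.array([-1.0, 0.0, 2.0])
y, log_det = bijection.transform_and_log_det(x)

# Elementwise map: the Jacobian is diagonal, so log|det J| is the sum of
# the elementwise log-derivatives.
auto_log_det = jnp.log(jnp.diag(jacfwd(bijection.transform)(x))).sum()
assert jnp.allclose(log_det, auto_log_det, atol=1e-5)

# Round-trip, and the inverse log-determinant is the negated forward one.
x_recon, inv_log_det = bijection.inverse_and_log_det(y)
assert jnp.allclose(x, x_recon, atol=1e-5)
assert jnp.allclose(log_det, -inv_log_det, atol=1e-5)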
26 changes: 26 additions & 0 deletions flowjax/distributions.py
@@ -780,3 +780,29 @@ def __init__(self, concentration: ArrayLike, scale: ArrayLike):
concentration, scale = jnp.broadcast_arrays(concentration, scale)
self.base_dist = _StandardGamma(concentration)
self.bijection = Scale(scale)


class Beta(AbstractDistribution):
"""Beta distribution.

Args:
alpha: The alpha shape parameter.
beta: The beta shape parameter.
"""

alpha: Array | AbstractUnwrappable[Array]
beta: Array | AbstractUnwrappable[Array]
shape: tuple[int, ...]
cond_shape: ClassVar[None] = None

def __init__(self, alpha: ArrayLike, beta: ArrayLike):
alpha, beta = jnp.broadcast_arrays(alpha, beta)
self.alpha = Parameterize(softplus, inv_softplus(alpha))
self.beta = Parameterize(softplus, inv_softplus(beta))
self.shape = alpha.shape

def _sample(self, key, condition=None):
return jr.beta(key, self.alpha, self.beta)

def _log_prob(self, x, condition=None):
return jstats.beta.logpdf(x, self.alpha, self.beta).sum()
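The softplus parameterization keeps both shape parameters positive when the distribution is optimized in unconstrained space. A hypothetical usage sketch of the class as defined above:

import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Beta

# Beta(2, 5) on (0, 1); scalar parameters give a scalar-shaped distribution.
dist = Beta(jnp.array(2.0), jnp.array(5.0))
x = dist.sample(jr.PRNGKey(0))       # a single draw in (0, 1)
lp = dist.log_prob(jnp.array(0.3))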
2 changes: 2 additions & 0 deletions tests/test_bijections/test_bijections.py
@@ -30,6 +30,7 @@
Reshape,
Scale,
Scan,
Sigmoid,
SoftPlus,
Stack,
Tanh,
@@ -67,6 +68,7 @@
"LeakyTanh (broadcast max_val)": lambda: LeakyTanh(1, (2, 3)),
"Loc": lambda: Loc(jnp.arange(DIM)),
"Exp": lambda: Exp((DIM,)),
"Sigmoid": lambda: Sigmoid((DIM,)),
"SoftPlus": lambda: SoftPlus((DIM,)),
"TriangularAffine (lower)": lambda: TriangularAffine(
jnp.arange(DIM),
2 changes: 2 additions & 0 deletions tests/test_distributions.py
@@ -13,6 +13,7 @@
from flowjax.distributions import (
AbstractDistribution,
AbstractTransformed,
Beta,
Cauchy,
Exponential,
Gamma,
@@ -62,6 +63,7 @@
eqx.filter_vmap(Normal)(jnp.arange(3 * prod(shape)).reshape(3, *shape)),
weights=jnp.arange(3) + 1,
),
"Beta": lambda shape: Beta(jnp.ones(shape), jnp.ones(shape)),
}
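The new entry constructs a Beta with unit parameters, i.e. a uniform distribution on (0, 1). A standalone sanity check in the same spirit (the repository's shared test harness is not shown here, so this is only an illustrative sketch):

import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Beta

dist = Beta(jnp.ones(()), jnp.ones(()))  # Beta(1, 1) is Uniform(0, 1)
samples = dist.sample(jr.PRNGKey(0), (100,))
assert jnp.all((samples > 0) & (samples < 1))
# The uniform log-density is 0 everywhere on the support.
assert jnp.allclose(dist.log_prob(jnp.array(0.5)), 0.0, atol=1e-5)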

