Merge pull request #184 from danielward27/beta_and_sigmoid
Add Beta distribution and Sigmoid transform
danielward27 authored Oct 7, 2024
2 parents b72252b + 8382970 commit e932a30
Showing 6 changed files with 73 additions and 5 deletions.
10 changes: 5 additions & 5 deletions docs/examples/constrained.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions flowjax/bijections/__init__.py
@@ -12,6 +12,7 @@
from .planar import Planar
from .power import Power
from .rational_quadratic_spline import RationalQuadraticSpline
from .sigmoid import Sigmoid
from .softplus import SoftPlus
from .tanh import LeakyTanh, Tanh
from .utils import EmbedCondition, Flip, Identity, Invert, Partial, Permute, Reshape
@@ -40,6 +41,7 @@
"Reshape",
"Scale",
"Scan",
"Sigmoid",
"SoftPlus",
"Stack",
"Tanh",
36 changes: 36 additions & 0 deletions flowjax/bijections/sigmoid.py
@@ -0,0 +1,36 @@
"""Sigmoid bijection."""

from typing import ClassVar

import jax.numpy as jnp
from jax import nn
from jax.scipy.special import logit

from flowjax.bijections.bijection import AbstractBijection


class Sigmoid(AbstractBijection):
    r"""Sigmoid bijection :math:`y = \sigma(x) = \frac{1}{1 + \exp(-x)}`.

    Args:
        shape: The shape of the transform.
    """

    shape: tuple[int, ...] = ()
    cond_shape: ClassVar[None] = None

    def transform(self, x, condition=None):
        return nn.sigmoid(x)

    def transform_and_log_det(self, x, condition=None):
        y = nn.sigmoid(x)
        log_det = jnp.sum(nn.log_sigmoid(x) + nn.log_sigmoid(-x))
        return y, log_det

    def inverse(self, y, condition=None):
        return logit(y)

    def inverse_and_log_det(self, y, condition=None):
        x = logit(y)
        log_det = -jnp.sum(nn.log_sigmoid(x) + nn.log_sigmoid(-x))
        return x, log_det
26 changes: 26 additions & 0 deletions flowjax/distributions.py
@@ -780,3 +780,29 @@ def __init__(self, concentration: ArrayLike, scale: ArrayLike):
        concentration, scale = jnp.broadcast_arrays(concentration, scale)
        self.base_dist = _StandardGamma(concentration)
        self.bijection = Scale(scale)


class Beta(AbstractDistribution):
    """Beta distribution.

    Args:
        alpha: The alpha shape parameter.
        beta: The beta shape parameter.
    """

    alpha: Array | AbstractUnwrappable[Array]
    beta: Array | AbstractUnwrappable[Array]
    shape: tuple[int, ...]
    cond_shape: ClassVar[None] = None

    def __init__(self, alpha: ArrayLike, beta: ArrayLike):
        alpha, beta = jnp.broadcast_arrays(alpha, beta)
        self.alpha = Parameterize(softplus, inv_softplus(alpha))
        self.beta = Parameterize(softplus, inv_softplus(beta))
        self.shape = alpha.shape

    def _sample(self, key, condition=None):
        return jr.beta(key, self.alpha, self.beta)

    def _log_prob(self, x, condition=None):
        return jstats.beta.logpdf(x, self.alpha, self.beta).sum()
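
A minimal usage sketch (not part of this diff), assuming the sample(key, sample_shape) and log_prob interface that flowjax distributions expose. The Parameterize(softplus, inv_softplus(...)) fields above store alpha and beta on an unconstrained scale so they stay positive during optimization:

import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Beta

dist = Beta(alpha=2.0 * jnp.ones(2), beta=5.0 * jnp.ones(2))

key = jr.PRNGKey(0)
samples = dist.sample(key, (100,))   # expected shape (100, 2), values in (0, 1)
log_probs = dist.log_prob(samples)   # expected shape (100,)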
2 changes: 2 additions & 0 deletions tests/test_bijections/test_bijections.py
@@ -30,6 +30,7 @@
    Reshape,
    Scale,
    Scan,
    Sigmoid,
    SoftPlus,
    Stack,
    Tanh,
@@ -67,6 +68,7 @@
"LeakyTanh (broadcast max_val)": lambda: LeakyTanh(1, (2, 3)),
"Loc": lambda: Loc(jnp.arange(DIM)),
"Exp": lambda: Exp((DIM,)),
"Sigmoid": lambda: Sigmoid((DIM,)),
"SoftPlus": lambda: SoftPlus((DIM,)),
"TriangularAffine (lower)": lambda: TriangularAffine(
jnp.arange(DIM),
2 changes: 2 additions & 0 deletions tests/test_distributions.py
@@ -13,6 +13,7 @@
from flowjax.distributions import (
    AbstractDistribution,
    AbstractTransformed,
    Beta,
    Cauchy,
    Exponential,
    Gamma,
@@ -62,6 +63,7 @@
        eqx.filter_vmap(Normal)(jnp.arange(3 * prod(shape)).reshape(3, *shape)),
        weights=jnp.arange(3) + 1,
    ),
    "Beta": lambda shape: Beta(jnp.ones(shape), jnp.ones(shape)),
}


