Skip to content

Commit

Permalink
Add AsymmetricAffine and add some docs
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Jan 30, 2025
1 parent 95b9780 commit 27482e6
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 17 deletions.
3 changes: 2 additions & 1 deletion flowjax/bijections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .power import Power
from .rational_quadratic_spline import RationalQuadraticSpline
from .sigmoid import Sigmoid
from .softplus import SoftPlus
from .softplus import SoftPlus, AsymmetricAffine
from .tanh import LeakyTanh, Tanh
from .utils import (
EmbedCondition,
Expand All @@ -33,6 +33,7 @@
"AdditiveCondition",
"Affine",
"AbstractBijection",
"AsymmetricAffine",
"BlockAutoregressiveNetwork",
"Chain",
"Concatenate",
Expand Down
1 change: 1 addition & 0 deletions flowjax/bijections/bijection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from equinox import AbstractVar
from jaxtyping import Array, ArrayLike
from paramax import unwrap
import jax

from flowjax.utils import _get_ufunc_signature, arraylike_to_array

Expand Down
1 change: 1 addition & 0 deletions flowjax/bijections/coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import equinox as eqx
import jax.nn as jnn
import jax.numpy as jnp
import jax
import paramax
from jaxtyping import PRNGKeyArray

Expand Down
1 change: 1 addition & 0 deletions flowjax/bijections/jax_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable

import equinox as eqx
from jax import Array
import jax.numpy as jnp
from jax.lax import scan
from jax.tree_util import tree_leaves, tree_map
Expand Down
49 changes: 47 additions & 2 deletions flowjax/bijections/orthogonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@


class Neg(AbstractBijection):
"""A bijection that negates its input (multiplies by -1).
This is a simple bijection that flips the sign of all elements in the input array.
Attributes:
shape: Shape of the input/output arrays
cond_shape: Shape of conditional inputs (None as this bijection is unconditional)
"""
shape: tuple[int, ...]
cond_shape = None

def __init__(self, shape):
"""Initialize the MvScale bijection with `params`."""
self.shape = shape

def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None):
Expand All @@ -21,12 +28,28 @@ def inverse_and_log_det(self, y: Array, condition: Array | None = None):


class Householder(AbstractBijection):
"""A Householder reflection bijection.
This bijection implements a Householder reflection, which is a linear
transformation that reflects vectors across a hyperplane defined by a normal
vector (params). The transformation is its own inverse and volume-preserving
(determinant = ±1).
Given a unit vector v, the transformation is:
x → x - 2(x·v)v
Attributes:
shape: Shape of the input/output vectors
cond_shape: Shape of conditional inputs (None as this bijection is unconditional)
params: Normal vector defining the reflection hyperplane. The vector is
normalized in the transformation, so scaling params will have no effect
on the bijection.
"""
shape: tuple[int, ...]
params: Array
cond_shape = None

def __init__(self, params: Array):
"""Initialize the MvScale bijection with `params`."""
self.shape = (params.shape[-1],)
self.params = params

Expand All @@ -43,8 +66,30 @@ def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None):
def inverse_and_log_det(self, y: Array, condition: Array | None = None):
return self._householder(y, self.params), jnp.zeros(())

def inverse_gradient_and_val(
self,
y: Array,
y_grad: Array,
y_logp: Array,
condition: Array | None = None,
) -> tuple[Array, Array, Array]:
x, logdet = self.inverse_and_log_det(y)
x_grad = self._householder(y_grad, params=self.params)
return (x, x_grad, y_logp - logdet)


class DCT(AbstractBijection):
"""Discrete Cosine Transform (DCT) bijection.
This bijection applies the DCT or its inverse along a specified axis.
Attributes:
shape: Shape of the input/output arrays
cond_shape: Shape of conditional inputs (None as this bijection is unconditional)
axis: Axis along which to apply the DCT
norm: Normalization method, fixed to 'ortho' to ensure bijectivity
"""

shape: tuple[int, ...]
cond_shape = None
axis: int
Expand Down
125 changes: 123 additions & 2 deletions flowjax/bijections/softplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

from typing import ClassVar

import jax
import jax.numpy as jnp
from jax.nn import softplus
from jax.nn import softplus, soft_sign
from jaxtyping import Array, ArrayLike
from paramax import AbstractUnwrappable, Parameterize, unwrap
from paramax.utils import inv_softplus

from flowjax.bijections.bijection import AbstractBijection

from flowjax.utils import arraylike_to_array

class SoftPlus(AbstractBijection):
r"""Transforms to positive domain using softplus :math:`y = \log(1 + \exp(x))`."""
Expand All @@ -20,3 +24,120 @@ def transform_and_log_det(self, x, condition=None):
def inverse_and_log_det(self, y, condition=None):
x = jnp.log(-jnp.expm1(-y)) + y
return x, softplus(-x).sum()


class AsymmetricAffine(AbstractBijection):
"""An asymmetric bijection that applies different scaling factors for
positive and negative inputs.
This bijection implements a continuous, differentiable transformation that
scales positive and negative inputs differently while maintaining smoothness
at zero. It's particularly useful for modeling data with different variances
in positive and negative regions.
The forward transformation is defined as:
y = σ θ x for x ≥ 0
y = σ x/θ for x < 0
where:
- σ (scale) controls the overall scaling
- θ (theta) controls the asymmetry between positive and negative regions
- μ (loc) controls the location shift
The transformation uses a smooth transition between the two regions to
maintain differentiability.
For θ = 0, this is exactly an affine function with the specified location
and scale.
Attributes:
shape: The shape of the transformation parameters
cond_shape: Shape of conditional inputs (None as this bijection is
unconditional)
loc: Location parameter μ for shifting the distribution
scale: Scale parameter σ (positive)
theta: Asymmetry parameter θ (positive)
"""
shape: tuple[int, ...] = ()
cond_shape: ClassVar[None] = None
loc: Array
scale: Array | AbstractUnwrappable[Array]
theta: Array | AbstractUnwrappable[Array]

def __init__(
self,
loc: ArrayLike = 0,
scale: ArrayLike = 1,
theta: ArrayLike = 1,
):
self.loc, scale, theta = jnp.broadcast_arrays(
*(arraylike_to_array(a, dtype=float) for a in (loc, scale, theta)),
)
self.shape = scale.shape
self.scale = Parameterize(softplus, inv_softplus(scale))
self.theta = Parameterize(softplus, inv_softplus(theta))

def _log_derivative_f(self, x, mu, sigma, theta):
abs_x = jnp.abs(x)
theta = jnp.log(theta)

sinh_theta = jnp.sinh(theta)
#sinh_theta = (theta - 1 / theta) / 2
cosh_theta = jnp.cosh(theta)
#cosh_theta = (theta + 1 / theta) / 2
numerator = sinh_theta * x * (abs_x + 2.0)
denominator = (abs_x + 1.0)**2
term = numerator / denominator
dy_dx = sigma * (cosh_theta + term)
return jnp.log(dy_dx)

def transform_and_log_det(self, x: ArrayLike, condition: ArrayLike | None = None) -> tuple[Array, Array]:

def transform(x, mu, sigma, theta):
weight = (soft_sign(x) + 1) / 2
z = x * sigma
y_pos = z * theta
y_neg = z / theta
y = weight * y_pos + (1.0 - weight) * y_neg + mu
return y

mu, sigma, theta = self.loc, self.scale, self.theta

y = transform(x, mu, sigma, theta)
logjac = self._log_derivative_f(x, mu, sigma, theta)
return y, logjac.sum()

def inverse_and_log_det(self, y: ArrayLike, condition: ArrayLike | None = None) -> tuple[Array, Array]:

def inverse(y, mu, sigma, theta):
delta = y - mu
inv_theta = 1 / theta

# Case 1: y >= mu (delta >= 0)
a = sigma * (theta + inv_theta)
discriminant_pos = jnp.square(a - 2.0 * delta) + 16.0 * sigma * theta * delta
discriminant_pos = jnp.where(discriminant_pos < 0, 1., discriminant_pos)
sqrt_pos = jnp.sqrt(discriminant_pos)
numerator_pos = 2.0 * delta - a + sqrt_pos
denominator_pos = 4.0 * sigma * theta
x_pos = numerator_pos / denominator_pos

# Case 2: y < mu (delta < 0)
sigma_part = sigma * (1.0 + theta * theta)
term2 = 2.0 * delta * theta
inside_sqrt_neg = jnp.square(sigma_part + term2) - 16.0 * sigma * delta * theta
inside_sqrt_neg = jnp.where(inside_sqrt_neg < 0, 1., inside_sqrt_neg)
sqrt_neg = jnp.sqrt(inside_sqrt_neg)
numerator_neg = sigma_part + term2 - sqrt_neg
denominator_neg = 4.0 * sigma
x_neg = numerator_neg / denominator_neg

# Combine cases based on delta
x = jnp.where(delta >= 0.0, x_pos, x_neg)
return x

mu, sigma, theta = self.loc, self.scale, self.theta

x = inverse(y, mu, sigma, theta)
logjac = self._log_derivative_f(x, mu, sigma, theta)
return x, -logjac.sum()

30 changes: 18 additions & 12 deletions flowjax/bijections/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,18 +307,24 @@ def inverse_and_log_det(self, y, condition=None):


class Sandwich(AbstractBijection):
"""
A bijection that sandwiches one transformation inside another.
The `Sandwich` bijection applies an "outer" transformation, followed by an
"inner" transformation, and then the inverse of the "outer" transformation.
This allows for the composition of transformations in a nested structure.
Args:
outer (AbstractBijection): The outer transformation applied first and
inverted last.
inner (AbstractBijection): The inner transformation applied between
the forward and inverse outer transformations.
"""A bijection that composes bijections in a nested structure: g⁻¹ ∘ f ∘ g.
The Sandwich bijection creates a new transformation by "sandwiching" one
bijection between the forward and inverse applications of another. Given
bijections f and g, it computes:
Forward: x → g⁻¹(f(g(x)))
Inverse: y → g⁻¹(f⁻¹(g(y)))
This composition pattern is useful for:
- Creating symmetries in the transformation
- Applying a transformation in a different coordinate system
- Building more complex bijections from simpler ones
Attributes:
shape: Shape of the input/output arrays
cond_shape: Shape of conditional inputs
outer: Transformation g applied first and inverted last
inner: Transformation f applied in the middle
"""
shape: tuple[int, ...]
cond_shape: tuple[int, ...] | None
Expand Down
8 changes: 8 additions & 0 deletions tests/test_bijections/test_bijections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
import jax.numpy as jnp
import jax.random as jr
import pytest
import numpy as np
from scipy import stats

from flowjax.bijections import (
AbstractBijection,
AdditiveCondition,
Affine,
AsymmetricAffine,
BlockAutoregressiveNetwork,
Chain,
Concatenate,
Expand Down Expand Up @@ -94,6 +97,11 @@
),
jnp.diag(jnp.array([-1, 2, -3])),
),
"AsymmetricAffine": lambda: AsymmetricAffine(
jnp.ones(DIM),
jnp.full(DIM, 2.6),
jnp.full(DIM, 0.1),
),
"RationalQuadraticSpline": lambda: RationalQuadraticSpline(knots=4, interval=1),
"Coupling (unconditional)": lambda: Coupling(
KEY,
Expand Down

0 comments on commit 27482e6

Please sign in to comment.