Remove Neg and AsymmetricAffine and minor edits
aseyboldt committed Feb 14, 2025
1 parent 328de18 commit 1002c9a
Showing 5 changed files with 24 additions and 206 deletions.
8 changes: 3 additions & 5 deletions flowjax/bijections/__init__.py
@@ -13,7 +13,7 @@
 from .power import Power
 from .rational_quadratic_spline import RationalQuadraticSpline
 from .sigmoid import Sigmoid
-from .softplus import SoftPlus, AsymmetricAffine
+from .softplus import SoftPlus
 from .tanh import LeakyTanh, Tanh
 from .utils import (
     EmbedCondition,
@@ -27,18 +27,17 @@
     Sandwich,
 )
 from .utils import EmbedCondition, Flip, Identity, Invert, Permute, Reshape, Sandwich
-from .orthogonal import Householder, DCT, Neg
+from .orthogonal import Householder, DiscreteCosine

 __all__ = [
     "AdditiveCondition",
     "Affine",
     "AbstractBijection",
-    "AsymmetricAffine",
     "BlockAutoregressiveNetwork",
     "Chain",
     "Concatenate",
     "Coupling",
-    "DCT",
+    "DiscreteCosine",
     "EmbedCondition",
     "Exp",
     "Flip",
@@ -49,7 +48,6 @@
"Loc",
"MaskedAutoregressive",
"Indexed",
"Neg",
"Permute",
"Power",
"Planar",
66 changes: 10 additions & 56 deletions flowjax/bijections/orthogonal.py
@@ -1,32 +1,11 @@
+from paramax import AbstractUnwrappable, Parameterize
 from flowjax.bijections.bijection import AbstractBijection
 from jax import Array
 import jax.numpy as jnp
 import jax.nn as jnn
 from jax.scipy import fft


-class Neg(AbstractBijection):
-    """A bijection that negates its input (multiplies by -1).
-
-    This is a simple bijection that flips the sign of all elements in the input array.
-
-    Attributes:
-        shape: Shape of the input/output arrays
-        cond_shape: Shape of conditional inputs (None as this bijection is unconditional)
-    """
-
-    shape: tuple[int, ...]
-    cond_shape = None
-
-    def __init__(self, shape):
-        self.shape = shape
-
-    def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None):
-        return -x, jnp.zeros(())
-
-    def inverse_and_log_det(self, y: Array, condition: Array | None = None):
-        return -y, jnp.zeros(())


 class Householder(AbstractBijection):
     """A Householder reflection bijection.
@@ -46,39 +25,24 @@ class Householder(AbstractBijection):
     on the bijection.
     """

     shape: tuple[int, ...]
-    params: Array
+    unit_vec: Array | AbstractUnwrappable
     cond_shape = None

     def __init__(self, params: Array):
         self.shape = (params.shape[-1],)
-        self.params = params
-
-    def _householder(self, x: Array, params: Array) -> Array:
-        norm_sq = params @ params
-        norm = jnp.sqrt(norm_sq)
-
-        vec = params / norm
-        return x - 2 * vec * (x @ vec)
+        self.unit_vec = Parameterize(lambda x: x / jnp.linalg.norm(x), params)
+
+    def _householder(self, x: Array) -> Array:
+        return x - 2 * self.unit_vec * (x @ self.unit_vec)

     def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None):
-        return self._householder(x, self.params), jnp.zeros(())
+        return self._householder(x), jnp.zeros(())

     def inverse_and_log_det(self, y: Array, condition: Array | None = None):
-        return self._householder(y, self.params), jnp.zeros(())
-
-    def inverse_gradient_and_val(
-        self,
-        y: Array,
-        y_grad: Array,
-        y_logp: Array,
-        condition: Array | None = None,
-    ) -> tuple[Array, Array, Array]:
-        x, logdet = self.inverse_and_log_det(y)
-        x_grad = self._householder(y_grad, params=self.params)
-        return (x, x_grad, y_logp - logdet)
+        return self._householder(y), jnp.zeros(())


-class DCT(AbstractBijection):
+class DiscreteCosine(AbstractBijection):
     """Discrete Cosine Transform (DCT) bijection.

     This bijection applies the DCT or its inverse along a specified axis.
@@ -93,25 +57,15 @@ class DCT(AbstractBijection):
     shape: tuple[int, ...]
     cond_shape = None
     axis: int
-    norm: str

     def __init__(self, shape, *, axis: int = -1):
         self.shape = shape
         self.axis = axis
-        self.norm = "ortho"
-
-    def _dct(self, x: Array, inverse: bool = False) -> Array:
-        if inverse:
-            z = fft.idct(x, norm=self.norm, axis=self.axis)
-        else:
-            z = fft.dct(x, norm=self.norm, axis=self.axis)
-
-        return z

     def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None):
-        y = self._dct(x)
+        y = fft.dct(x, norm="ortho", axis=self.axis)
         return y, jnp.zeros(())

     def inverse_and_log_det(self, y: Array, condition: Array | None = None):
-        x = self._dct(y, inverse=True)
+        x = fft.idct(y, norm="ortho", axis=self.axis)
         return x, jnp.zeros(())
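
Both bijections in this file are orthogonal maps, so their Jacobian log-determinant is identically zero, which is why the methods above return jnp.zeros(()). A minimal standalone sketch in plain JAX (illustrative names, not flowjax API) of the two properties this relies on: a Householder reflection with a unit vector is its own inverse, and the orthonormal DCT roundtrips exactly.

import jax.numpy as jnp
from jax.scipy import fft

def householder(x, vec):
    # Reflect x across the hyperplane orthogonal to vec.
    unit = vec / jnp.linalg.norm(vec)
    return x - 2 * unit * (x @ unit)

v = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.5, -1.0, 2.0])
y = householder(x, v)
print(jnp.allclose(householder(y, v), x))  # True: reflecting twice is the identity

# The orthonormal DCT is also orthogonal, so idct undoes dct exactly.
z = fft.dct(x, norm="ortho")
print(jnp.allclose(fft.idct(z, norm="ortho"), x))  # True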
117 changes: 0 additions & 117 deletions flowjax/bijections/softplus.py
@@ -24,120 +24,3 @@ def transform_and_log_det(self, x, condition=None):
     def inverse_and_log_det(self, y, condition=None):
         x = jnp.log(-jnp.expm1(-y)) + y
         return x, softplus(-x).sum()
-
-
-class AsymmetricAffine(AbstractBijection):
-    """An asymmetric bijection that applies different scaling factors for
-    positive and negative inputs.
-
-    This bijection implements a continuous, differentiable transformation that
-    scales positive and negative inputs differently while maintaining smoothness
-    at zero. It's particularly useful for modeling data with different variances
-    in positive and negative regions.
-
-    The forward transformation is defined as:
-
-        y = σ θ x      for x ≥ 0
-        y = σ x / θ    for x < 0
-
-    where:
-
-    - σ (scale) controls the overall scaling
-    - θ (theta) controls the asymmetry between positive and negative regions
-    - μ (loc) controls the location shift
-
-    The transformation uses a smooth transition between the two regions to
-    maintain differentiability.
-
-    For θ = 1, this is exactly an affine function with the specified location
-    and scale.
-
-    Attributes:
-        shape: The shape of the transformation parameters
-        cond_shape: Shape of conditional inputs (None as this bijection is
-            unconditional)
-        loc: Location parameter μ for shifting the distribution
-        scale: Scale parameter σ (positive)
-        theta: Asymmetry parameter θ (positive)
-    """
-
-    shape: tuple[int, ...] = ()
-    cond_shape: ClassVar[None] = None
-    loc: Array
-    scale: Array | AbstractUnwrappable[Array]
-    theta: Array | AbstractUnwrappable[Array]
-
-    def __init__(
-        self,
-        loc: ArrayLike = 0,
-        scale: ArrayLike = 1,
-        theta: ArrayLike = 1,
-    ):
-        self.loc, scale, theta = jnp.broadcast_arrays(
-            *(arraylike_to_array(a, dtype=float) for a in (loc, scale, theta)),
-        )
-        self.shape = scale.shape
-        self.scale = Parameterize(softplus, inv_softplus(scale))
-        self.theta = Parameterize(softplus, inv_softplus(theta))
-
-    def _log_derivative_f(self, x, mu, sigma, theta):
-        abs_x = jnp.abs(x)
-        theta = jnp.log(theta)
-
-        sinh_theta = jnp.sinh(theta)
-        # sinh_theta = (theta - 1 / theta) / 2
-        cosh_theta = jnp.cosh(theta)
-        # cosh_theta = (theta + 1 / theta) / 2
-        numerator = sinh_theta * x * (abs_x + 2.0)
-        denominator = (abs_x + 1.0) ** 2
-        term = numerator / denominator
-        dy_dx = sigma * (cosh_theta + term)
-        return jnp.log(dy_dx)
-
-    def transform_and_log_det(self, x: ArrayLike, condition: ArrayLike | None = None) -> tuple[Array, Array]:
-        def transform(x, mu, sigma, theta):
-            weight = (soft_sign(x) + 1) / 2
-            z = x * sigma
-            y_pos = z * theta
-            y_neg = z / theta
-            y = weight * y_pos + (1.0 - weight) * y_neg + mu
-            return y
-
-        mu, sigma, theta = self.loc, self.scale, self.theta
-
-        y = transform(x, mu, sigma, theta)
-        logjac = self._log_derivative_f(x, mu, sigma, theta)
-        return y, logjac.sum()
-
-    def inverse_and_log_det(self, y: ArrayLike, condition: ArrayLike | None = None) -> tuple[Array, Array]:
-        def inverse(y, mu, sigma, theta):
-            delta = y - mu
-            inv_theta = 1 / theta
-
-            # Case 1: y >= mu (delta >= 0)
-            a = sigma * (theta + inv_theta)
-            discriminant_pos = jnp.square(a - 2.0 * delta) + 16.0 * sigma * theta * delta
-            discriminant_pos = jnp.where(discriminant_pos < 0, 1., discriminant_pos)
-            sqrt_pos = jnp.sqrt(discriminant_pos)
-            numerator_pos = 2.0 * delta - a + sqrt_pos
-            denominator_pos = 4.0 * sigma * theta
-            x_pos = numerator_pos / denominator_pos
-
-            # Case 2: y < mu (delta < 0)
-            sigma_part = sigma * (1.0 + theta * theta)
-            term2 = 2.0 * delta * theta
-            inside_sqrt_neg = jnp.square(sigma_part + term2) - 16.0 * sigma * delta * theta
-            inside_sqrt_neg = jnp.where(inside_sqrt_neg < 0, 1., inside_sqrt_neg)
-            sqrt_neg = jnp.sqrt(inside_sqrt_neg)
-            numerator_neg = sigma_part + term2 - sqrt_neg
-            denominator_neg = 4.0 * sigma
-            x_neg = numerator_neg / denominator_neg
-
-            # Combine cases based on delta
-            x = jnp.where(delta >= 0.0, x_pos, x_neg)
-            return x
-
-        mu, sigma, theta = self.loc, self.scale, self.theta
-
-        x = inverse(y, mu, sigma, theta)
-        logjac = self._log_derivative_f(x, mu, sigma, theta)
-        return x, -logjac.sum()
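
The surviving SoftPlus inverse above relies on the numerically stable identity log(e^y - 1) = y + log(-expm1(-y)). For the removed AsymmetricAffine, the forward map blends the two linear regimes with a soft-sign weight, and _log_derivative_f is its exact log-Jacobian. A self-contained sketch (hypothetical helper names, not flowjax API) checking that closed form against autodiff:

import jax
import jax.numpy as jnp

def soft_sign(x):
    return x / (1 + jnp.abs(x))

def forward(x, mu=0.0, sigma=1.0, theta=2.0):
    # Smooth blend of the x >= 0 branch (sigma * theta * x)
    # and the x < 0 branch (sigma * x / theta).
    weight = (soft_sign(x) + 1) / 2
    z = x * sigma
    return weight * (z * theta) + (1 - weight) * (z / theta) + mu

def log_deriv(x, sigma=1.0, theta=2.0):
    # Closed form of log dy/dx, as in _log_derivative_f above.
    log_theta = jnp.log(theta)
    term = jnp.sinh(log_theta) * x * (jnp.abs(x) + 2.0) / (jnp.abs(x) + 1.0) ** 2
    return jnp.log(sigma * (jnp.cosh(log_theta) + term))

x = 0.7
print(jnp.allclose(jnp.log(jax.grad(forward)(x)), log_deriv(x)))  # True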

27 changes: 9 additions & 18 deletions flowjax/bijections/utils.py
@@ -10,7 +10,8 @@
 from jaxtyping import Array, Int

 from flowjax.bijections.bijection import AbstractBijection
-from flowjax.utils import arraylike_to_array
+from flowjax.bijections.chain import Chain
+from flowjax.utils import arraylike_to_array, check_shapes_match, merge_cond_shapes


 class Invert(AbstractBijection):
Expand Down Expand Up @@ -332,26 +333,16 @@ class Sandwich(AbstractBijection):
     inner: AbstractBijection

     def __init__(self, outer: AbstractBijection, inner: AbstractBijection):
-        shape = inner.shape
-        if outer.shape != shape:
-            raise ValueError("Inner and outer transformations are incompatible")
-        self.cond_shape = inner.cond_shape
-        if outer.cond_shape != self.cond_shape:
-            raise ValueError("Inner and outer transformations are incompatible")
-        self.shape = shape
+        check_shapes_match([outer.shape, inner.shape])
+        self.cond_shape = merge_cond_shapes([outer.cond_shape, inner.cond_shape])
+        self.shape = inner.shape
         self.outer = outer
         self.inner = inner

     def transform_and_log_det(self, x: Array, condition=None) -> tuple[Array, Array]:
-        z1, logdet1 = self.outer.transform_and_log_det(x, condition)
-        z2, logdet2 = self.inner.transform_and_log_det(z1, condition)
-        y, logdet3 = self.outer.inverse_and_log_det(z2, condition)
-
-        return y, logdet1 + logdet2 + logdet3
+        chain = Chain([self.outer, self.inner, Invert(self.outer)])
+        return chain.transform_and_log_det(x, condition)

     def inverse_and_log_det(self, y: Array, condition=None) -> tuple[Array, Array]:
-        z1, logdet1 = self.outer.transform_and_log_det(y, condition)
-        z2, logdet2 = self.inner.inverse_and_log_det(z1, condition)
-        x, logdet3 = self.outer.inverse_and_log_det(z2, condition)
-
-        return x, logdet1 + logdet2 + logdet3
+        chain = Chain([self.outer, self.inner, Invert(self.outer)])
+        return chain.inverse_and_log_det(y, condition)
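
The rewritten methods make the structure explicit: Sandwich is the conjugation outer⁻¹ ∘ inner ∘ outer, with Chain handling the sequencing and log-determinant accumulation. A hedged usage sketch, assuming the post-commit names in this repository: apply an elementwise Affine in DCT space, then map back.

import jax.numpy as jnp
from flowjax.bijections import Affine, DiscreteCosine, Sandwich

dim = 4
bij = Sandwich(
    outer=DiscreteCosine(shape=(dim,)),                # move to frequency space
    inner=Affine(jnp.zeros(dim), jnp.full(dim, 2.0)),  # act elementwise there
)

x = jnp.arange(float(dim))
y, log_det = bij.transform_and_log_det(x)              # log_det comes from inner only
x_back, neg_log_det = bij.inverse_and_log_det(y)
print(jnp.allclose(x, x_back))  # True: the outer transform cancels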
12 changes: 2 additions & 10 deletions tests/test_bijections/test_bijections.py
@@ -14,12 +14,11 @@
     AbstractBijection,
     AdditiveCondition,
     Affine,
-    AsymmetricAffine,
     BlockAutoregressiveNetwork,
     Chain,
     Concatenate,
     Coupling,
-    DCT,
+    DiscreteCosine,
     EmbedCondition,
     Exp,
     Flip,
@@ -30,7 +29,6 @@
     Loc,
     MaskedAutoregressive,
     NumericalInverse,
-    Neg,
     Permute,
     Planar,
     Power,
@@ -97,11 +95,6 @@
         ),
         jnp.diag(jnp.array([-1, 2, -3])),
     ),
-    "AsymmetricAffine": lambda: AsymmetricAffine(
-        jnp.ones(DIM),
-        jnp.full(DIM, 2.6),
-        jnp.full(DIM, 0.1),
-    ),
     "RationalQuadraticSpline": lambda: RationalQuadraticSpline(knots=4, interval=1),
     "Coupling (unconditional)": lambda: Coupling(
         KEY,
@@ -144,7 +137,6 @@
             nn_depth=2,
         )
     ),
-    "Neg": lambda: Neg(shape=(DIM,)),
     "BlockAutoregressiveNetwork (unconditional)": lambda: BlockAutoregressiveNetwork(
         KEY,
         dim=DIM,
@@ -234,7 +226,7 @@
         Exp(),
         Affine(0.1, 0.5),
     ),
-    "DCT": lambda: DCT(shape=(3, 4)),
+    "DiscreteCosine": lambda: DiscreteCosine(shape=(3, 4)),
     "Householder": lambda: Householder(jnp.ones(3)),
 }

