From 1002c9a5e7e92da102ca4cb713be7a801cef9810 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 12 Feb 2025 19:04:06 +0100 Subject: [PATCH] Remove Neg and AssymetricAffine and minor edits --- flowjax/bijections/__init__.py | 8 +- flowjax/bijections/orthogonal.py | 66 ++----------- flowjax/bijections/softplus.py | 117 ----------------------- flowjax/bijections/utils.py | 27 ++---- tests/test_bijections/test_bijections.py | 12 +-- 5 files changed, 24 insertions(+), 206 deletions(-) diff --git a/flowjax/bijections/__init__.py b/flowjax/bijections/__init__.py index 4f69851..8fe2bcf 100644 --- a/flowjax/bijections/__init__.py +++ b/flowjax/bijections/__init__.py @@ -13,7 +13,7 @@ from .power import Power from .rational_quadratic_spline import RationalQuadraticSpline from .sigmoid import Sigmoid -from .softplus import SoftPlus, AsymmetricAffine +from .softplus import SoftPlus from .tanh import LeakyTanh, Tanh from .utils import ( EmbedCondition, @@ -27,18 +27,17 @@ Sandwich, ) from .utils import EmbedCondition, Flip, Identity, Invert, Permute, Reshape, Sandwich -from .orthogonal import Householder, DCT, Neg +from .orthogonal import Householder, DiscreteCosine __all__ = [ "AdditiveCondition", "Affine", "AbstractBijection", - "AsymmetricAffine", "BlockAutoregressiveNetwork", "Chain", "Concatenate", "Coupling", - "DCT", + "DiscreteCosine", "EmbedCondition", "Exp", "Flip", @@ -49,7 +48,6 @@ "Loc", "MaskedAutoregressive", "Indexed", - "Neg", "Permute", "Power", "Planar", diff --git a/flowjax/bijections/orthogonal.py b/flowjax/bijections/orthogonal.py index d79a6be..4d06534 100644 --- a/flowjax/bijections/orthogonal.py +++ b/flowjax/bijections/orthogonal.py @@ -1,3 +1,4 @@ +from paramax import AbstractUnwrappable, Parameterize from flowjax.bijections.bijection import AbstractBijection from jax import Array import jax.numpy as jnp @@ -5,28 +6,6 @@ from jax.scipy import fft -class Neg(AbstractBijection): - """A bijection that negates its input (multiplies by -1). - - This is a simple bijection that flips the sign of all elements in the input array. - - Attributes: - shape: Shape of the input/output arrays - cond_shape: Shape of conditional inputs (None as this bijection is unconditional) - """ - shape: tuple[int, ...] - cond_shape = None - - def __init__(self, shape): - self.shape = shape - - def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): - return -x, jnp.zeros(()) - - def inverse_and_log_det(self, y: Array, condition: Array | None = None): - return -y, jnp.zeros(()) - - class Householder(AbstractBijection): """A Householder reflection bijection. @@ -46,39 +25,24 @@ class Householder(AbstractBijection): on the bijection. """ shape: tuple[int, ...] - params: Array + unit_vec: Array | AbstractUnwrappable cond_shape = None def __init__(self, params: Array): self.shape = (params.shape[-1],) - self.params = params - - def _householder(self, x: Array, params: Array) -> Array: - norm_sq = params @ params - norm = jnp.sqrt(norm_sq) + self.unit_vec = Parameterize(lambda x: x / jnp.linalg.norm(x), params) - vec = params / norm - return x - 2 * vec * (x @ vec) + def _householder(self, x: Array) -> Array: + return x - 2 * self.unit_vec * (x @ self.unit_vec) def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): - return self._householder(x, self.params), jnp.zeros(()) + return self._householder(x), jnp.zeros(()) def inverse_and_log_det(self, y: Array, condition: Array | None = None): - return self._householder(y, self.params), jnp.zeros(()) - - def inverse_gradient_and_val( - self, - y: Array, - y_grad: Array, - y_logp: Array, - condition: Array | None = None, - ) -> tuple[Array, Array, Array]: - x, logdet = self.inverse_and_log_det(y) - x_grad = self._householder(y_grad, params=self.params) - return (x, x_grad, y_logp - logdet) + return self._householder(y), jnp.zeros(()) -class DCT(AbstractBijection): +class DiscreteCosine(AbstractBijection): """Discrete Cosine Transform (DCT) bijection. This bijection applies the DCT or its inverse along a specified axis. @@ -93,25 +57,15 @@ class DCT(AbstractBijection): shape: tuple[int, ...] cond_shape = None axis: int - norm: str def __init__(self, shape, *, axis: int = -1): self.shape = shape self.axis = axis - self.norm = "ortho" - - def _dct(self, x: Array, inverse: bool = False) -> Array: - if inverse: - z = fft.idct(x, norm=self.norm, axis=self.axis) - else: - z = fft.dct(x, norm=self.norm, axis=self.axis) - - return z def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): - y = self._dct(x) + y = fft.dct(x, norm="ortho", axis=self.axis) return y, jnp.zeros(()) def inverse_and_log_det(self, y: Array, condition: Array | None = None): - x = self._dct(y, inverse=True) + x = fft.idct(y, norm="ortho", axis=self.axis) return x, jnp.zeros(()) diff --git a/flowjax/bijections/softplus.py b/flowjax/bijections/softplus.py index c45b903..90c3315 100644 --- a/flowjax/bijections/softplus.py +++ b/flowjax/bijections/softplus.py @@ -24,120 +24,3 @@ def transform_and_log_det(self, x, condition=None): def inverse_and_log_det(self, y, condition=None): x = jnp.log(-jnp.expm1(-y)) + y return x, softplus(-x).sum() - - -class AsymmetricAffine(AbstractBijection): - """An asymmetric bijection that applies different scaling factors for - positive and negative inputs. - - This bijection implements a continuous, differentiable transformation that - scales positive and negative inputs differently while maintaining smoothness - at zero. It's particularly useful for modeling data with different variances - in positive and negative regions. - - The forward transformation is defined as: - y = σ θ x for x ≥ 0 - y = σ x/θ for x < 0 - where: - - σ (scale) controls the overall scaling - - θ (theta) controls the asymmetry between positive and negative regions - - μ (loc) controls the location shift - - The transformation uses a smooth transition between the two regions to - maintain differentiability. - - For θ = 0, this is exactly an affine function with the specified location - and scale. - - Attributes: - shape: The shape of the transformation parameters - cond_shape: Shape of conditional inputs (None as this bijection is - unconditional) - loc: Location parameter μ for shifting the distribution - scale: Scale parameter σ (positive) - theta: Asymmetry parameter θ (positive) - """ - shape: tuple[int, ...] = () - cond_shape: ClassVar[None] = None - loc: Array - scale: Array | AbstractUnwrappable[Array] - theta: Array | AbstractUnwrappable[Array] - - def __init__( - self, - loc: ArrayLike = 0, - scale: ArrayLike = 1, - theta: ArrayLike = 1, - ): - self.loc, scale, theta = jnp.broadcast_arrays( - *(arraylike_to_array(a, dtype=float) for a in (loc, scale, theta)), - ) - self.shape = scale.shape - self.scale = Parameterize(softplus, inv_softplus(scale)) - self.theta = Parameterize(softplus, inv_softplus(theta)) - - def _log_derivative_f(self, x, mu, sigma, theta): - abs_x = jnp.abs(x) - theta = jnp.log(theta) - - sinh_theta = jnp.sinh(theta) - #sinh_theta = (theta - 1 / theta) / 2 - cosh_theta = jnp.cosh(theta) - #cosh_theta = (theta + 1 / theta) / 2 - numerator = sinh_theta * x * (abs_x + 2.0) - denominator = (abs_x + 1.0)**2 - term = numerator / denominator - dy_dx = sigma * (cosh_theta + term) - return jnp.log(dy_dx) - - def transform_and_log_det(self, x: ArrayLike, condition: ArrayLike | None = None) -> tuple[Array, Array]: - - def transform(x, mu, sigma, theta): - weight = (soft_sign(x) + 1) / 2 - z = x * sigma - y_pos = z * theta - y_neg = z / theta - y = weight * y_pos + (1.0 - weight) * y_neg + mu - return y - - mu, sigma, theta = self.loc, self.scale, self.theta - - y = transform(x, mu, sigma, theta) - logjac = self._log_derivative_f(x, mu, sigma, theta) - return y, logjac.sum() - - def inverse_and_log_det(self, y: ArrayLike, condition: ArrayLike | None = None) -> tuple[Array, Array]: - - def inverse(y, mu, sigma, theta): - delta = y - mu - inv_theta = 1 / theta - - # Case 1: y >= mu (delta >= 0) - a = sigma * (theta + inv_theta) - discriminant_pos = jnp.square(a - 2.0 * delta) + 16.0 * sigma * theta * delta - discriminant_pos = jnp.where(discriminant_pos < 0, 1., discriminant_pos) - sqrt_pos = jnp.sqrt(discriminant_pos) - numerator_pos = 2.0 * delta - a + sqrt_pos - denominator_pos = 4.0 * sigma * theta - x_pos = numerator_pos / denominator_pos - - # Case 2: y < mu (delta < 0) - sigma_part = sigma * (1.0 + theta * theta) - term2 = 2.0 * delta * theta - inside_sqrt_neg = jnp.square(sigma_part + term2) - 16.0 * sigma * delta * theta - inside_sqrt_neg = jnp.where(inside_sqrt_neg < 0, 1., inside_sqrt_neg) - sqrt_neg = jnp.sqrt(inside_sqrt_neg) - numerator_neg = sigma_part + term2 - sqrt_neg - denominator_neg = 4.0 * sigma - x_neg = numerator_neg / denominator_neg - - # Combine cases based on delta - x = jnp.where(delta >= 0.0, x_pos, x_neg) - return x - - mu, sigma, theta = self.loc, self.scale, self.theta - - x = inverse(y, mu, sigma, theta) - logjac = self._log_derivative_f(x, mu, sigma, theta) - return x, -logjac.sum() - diff --git a/flowjax/bijections/utils.py b/flowjax/bijections/utils.py index 1ff766e..3654937 100644 --- a/flowjax/bijections/utils.py +++ b/flowjax/bijections/utils.py @@ -10,7 +10,8 @@ from jaxtyping import Array, Int from flowjax.bijections.bijection import AbstractBijection -from flowjax.utils import arraylike_to_array +from flowjax.bijections.chain import Chain +from flowjax.utils import arraylike_to_array, check_shapes_match, merge_cond_shapes class Invert(AbstractBijection): @@ -332,26 +333,16 @@ class Sandwich(AbstractBijection): inner: AbstractBijection def __init__(self, outer: AbstractBijection, inner: AbstractBijection): - shape = inner.shape - if outer.shape != shape: - raise ValueError("Inner and outer transformations are incompatible") - self.cond_shape = inner.cond_shape - if outer.cond_shape != self.cond_shape: - raise ValueError("Inner and outer transformations are incompatible") - self.shape = shape + check_shapes_match([outer.shape, inner.shape]) + self.cond_shape = merge_cond_shapes([outer.cond_shape, inner.cond_shape]) + self.shape = inner.shape self.outer = outer self.inner = inner def transform_and_log_det(self, x: Array, condition=None) -> tuple[Array, Array]: - z1, logdet1 = self.outer.transform_and_log_det(x, condition) - z2, logdet2 = self.inner.transform_and_log_det(z1, condition) - y, logdet3 = self.outer.inverse_and_log_det(z2, condition) - - return y, logdet1 + logdet2 + logdet3 + chain = Chain([self.outer, self.inner, Invert(self.outer)]) + return chain.transform_and_log_det(x, condition) def inverse_and_log_det(self, y: Array, condition=None) -> tuple[Array, Array]: - z1, logdet1 = self.outer.transform_and_log_det(y, condition) - z2, logdet2 = self.inner.inverse_and_log_det(z1, condition) - x, logdet3 = self.outer.inverse_and_log_det(z2, condition) - - return x, logdet1 + logdet2 + logdet3 + chain = Chain([self.outer, self.inner, Invert(self.outer)]) + return chain.inverse_and_log_det(y, condition) diff --git a/tests/test_bijections/test_bijections.py b/tests/test_bijections/test_bijections.py index 380c766..4ff1561 100644 --- a/tests/test_bijections/test_bijections.py +++ b/tests/test_bijections/test_bijections.py @@ -14,12 +14,11 @@ AbstractBijection, AdditiveCondition, Affine, - AsymmetricAffine, BlockAutoregressiveNetwork, Chain, Concatenate, Coupling, - DCT, + DiscreteCosine, EmbedCondition, Exp, Flip, @@ -30,7 +29,6 @@ Loc, MaskedAutoregressive, NumericalInverse, - Neg, Permute, Planar, Power, @@ -97,11 +95,6 @@ ), jnp.diag(jnp.array([-1, 2, -3])), ), - "AsymmetricAffine": lambda: AsymmetricAffine( - jnp.ones(DIM), - jnp.full(DIM, 2.6), - jnp.full(DIM, 0.1), - ), "RationalQuadraticSpline": lambda: RationalQuadraticSpline(knots=4, interval=1), "Coupling (unconditional)": lambda: Coupling( KEY, @@ -144,7 +137,6 @@ nn_depth=2, ) ), - "Neg": lambda: Neg(shape=(DIM,)), "BlockAutoregressiveNetwork (unconditional)": lambda: BlockAutoregressiveNetwork( KEY, dim=DIM, @@ -234,7 +226,7 @@ Exp(), Affine(0.1, 0.5), ), - "DCT": lambda: DCT(shape=(3, 4)), + "DiscreteCosine": lambda: DiscreteCosine(shape=(3, 4)), "Householder": lambda: Householder(jnp.ones(3)), }