diff --git a/flowjax/bijections/affine.py b/flowjax/bijections/affine.py
index 6f511397..d507037b 100644
--- a/flowjax/bijections/affine.py
+++ b/flowjax/bijections/affine.py
@@ -1,4 +1,5 @@
 """Affine bijections."""
+
 from __future__ import annotations
 
 from collections.abc import Callable
@@ -137,6 +138,7 @@ class TriangularAffine(AbstractBijection):
             from an unbounded domain to the positive domain. Also used for weight
             normalisation parameters, if used. Defaults to SoftPlus.
     """
+
     shape: tuple[int, ...]
     cond_shape: ClassVar[None] = None
     loc: Array
diff --git a/flowjax/flows.py b/flowjax/flows.py
index aa0af633..acc12ace 100644
--- a/flowjax/flows.py
+++ b/flowjax/flows.py
@@ -2,12 +2,14 @@
 
 All these functions return a :class:`~flowjax.distributions.Transformed`
 distribution.
 """
+
 # Note that here although we could chain arbitrary bijections using `Chain`, here,
 # we generally opt to use `Scan`, which avoids excessive compilation
 # when the flow layers share the same structure.
 from collections.abc import Callable
 from functools import partial
+from typing import ClassVar
 
 import equinox as eqx
 import jax.nn as jnn
@@ -32,6 +34,7 @@
     Planar,
     RationalQuadraticSpline,
     Scan,
+    SoftPlus,
     TriangularAffine,
     Vmap,
 )
@@ -56,7 +59,7 @@ def coupling_flow(
         key: Jax random number generator key.
        base_dist: Base distribution, with ``base_dist.ndim==1``.
         transformer: Bijection to be parameterised by conditioner. Defaults to
-            ``Affine()``.
+            affine.
         cond_dim: Dimension of conditioning variables. Defaults to None.
         flow_layers: Number of coupling layers. Defaults to 8.
         nn_width: Conditioner hidden layer size. Defaults to 50.
@@ -66,7 +69,11 @@ def coupling_flow(
            `inverse` methods, leading to faster `log_prob`, False will prioritise
            faster `transform` methods, leading to faster `sample`. Defaults to True.
     """
-    transformer = Affine() if transformer is None else transformer
+    if transformer is None:
+        transformer = Affine(
+            positivity_constraint=Chain([SoftPlus(), _PlusConst(1e-2)]),
+        )
+
     dim = base_dist.shape[-1]
 
     def make_layer(key):  # coupling layer + permutation
@@ -110,7 +117,7 @@ def masked_autoregressive_flow(
         key: Random seed.
         base_dist: Base distribution, with ``base_dist.ndim==1``.
         transformer: Bijection parameterised by autoregressive network. Defaults to
-            ``Affine()``.
+            affine.
         cond_dim: Dimension of the conditioning variable. Defaults to None.
         flow_layers: Number of flow layers. Defaults to 8.
         nn_width: Number of hidden layers in neural network. Defaults to 50.
@@ -120,7 +127,10 @@ def masked_autoregressive_flow(
            inverse, leading to faster `log_prob`, False will prioritise faster
            forward, leading to faster `sample`. Defaults to True.
     """
-    transformer = Affine() if transformer is None else transformer
+    if transformer is None:
+        transformer = Affine(
+            positivity_constraint=Chain([SoftPlus(), _PlusConst(1e-2)]),
+        )
     dim = base_dist.shape[-1]
 
     def make_layer(key):  # masked autoregressive layer + permutation
@@ -326,3 +336,25 @@ def _add_default_permute(bijection: AbstractBijection, dim: int, key: Array):
 
     perm = Permute(jr.permutation(key, jnp.arange(dim)))
     return Chain([bijection, perm]).merge_chains()
+
+
+class _PlusConst(AbstractBijection):
+    """Adds a constant."""
+
+    # We use this to add a small constant to the affine scale parameter, which
+    # seems to improve the stability of masked autoregressive and coupling flows.
+    const: float
+    shape: tuple[int, ...] = ()
+    cond_shape: ClassVar[None] = None
+
+    def transform(self, x, condition=None):
+        return x + self.const
+
+    def transform_and_log_det(self, x, condition=None):
+        return x + self.const, jnp.array(0)
+
+    def inverse_and_log_det(self, y, condition=None):
+        return y - self.const, jnp.array(0)
+
+    def inverse(self, y, condition=None):
+        return y - self.const
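
For reference, a minimal standalone sketch (not part of the diff, and not importing flowjax) of what the new default positivity constraint computes: the conditioner's raw scale output is passed through softplus and then shifted by `1e-2`, mirroring `Chain([SoftPlus(), _PlusConst(1e-2)])`, so the affine scale is bounded away from zero.

```python
# Illustrative sketch only: reproduce the scale parameterisation numerically.
import jax.numpy as jnp
from jax.nn import softplus

raw_scale = jnp.linspace(-10.0, 2.0, 5)  # unconstrained conditioner outputs
scale = softplus(raw_scale) + 1e-2       # what SoftPlus followed by _PlusConst(1e-2) computes
print(scale)                             # every entry is >= 1e-2, so the scale cannot collapse to ~0
```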