
Merge pull request #140 from danielward27/min_scale_affine_transformer
min scale for affine transformer
danielward27 authored Feb 29, 2024
2 parents 0818696 + 4f61548 commit 65fcdf1
Showing 2 changed files with 38 additions and 4 deletions.
2 changes: 2 additions & 0 deletions flowjax/bijections/affine.py
@@ -1,4 +1,5 @@
"""Affine bijections."""

from __future__ import annotations

from collections.abc import Callable
@@ -137,6 +138,7 @@ class TriangularAffine(AbstractBijection):
from an unbounded domain to the positive domain. Also used for weight
normalisation parameters, if used. Defaults to SoftPlus.
"""

shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
loc: Array
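
The positivity_constraint described in the docstring above maps an unconstrained parameter onto a strictly positive scale. As a rough illustration (not part of the commit), the default SoftPlus behaves numerically like jax.nn.softplus: it keeps the scale positive, but lets it collapse arbitrarily close to zero for very negative inputs.

import jax.numpy as jnp
import jax.nn as jnn

# Unconstrained conditioner outputs can be any real number.
unconstrained = jnp.array([-10.0, 0.0, 3.0])

# softplus keeps the scale strictly positive, so the affine map stays invertible,
# but very negative inputs still give scales arbitrarily close to zero.
scale = jnn.softplus(unconstrained)
print(scale)  # approx [4.5e-05, 0.693, 3.049]
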
40 changes: 36 additions & 4 deletions flowjax/flows.py
@@ -2,12 +2,14 @@
All these functions return a :class:`~flowjax.distributions.Transformed` distribution.
"""

# Note that although we could chain arbitrary bijections using `Chain`, here
# we generally opt to use `Scan`, which avoids excessive compilation
# when the flow layers share the same structure.

from collections.abc import Callable
from functools import partial
from typing import ClassVar

import equinox as eqx
import jax.nn as jnn
@@ -32,6 +34,7 @@
Planar,
RationalQuadraticSpline,
Scan,
SoftPlus,
TriangularAffine,
Vmap,
)
@@ -56,7 +59,7 @@ def coupling_flow(
key: Jax random number generator key.
base_dist: Base distribution, with ``base_dist.ndim==1``.
transformer: Bijection to be parameterised by conditioner. Defaults to
``Affine()``.
affine.
cond_dim: Dimension of conditioning variables. Defaults to None.
flow_layers: Number of coupling layers. Defaults to 8.
nn_width: Conditioner hidden layer size. Defaults to 50.
@@ -66,7 +69,11 @@
`inverse` methods, leading to faster `log_prob`, False will prioritise
faster `transform` methods, leading to faster `sample`. Defaults to True.
"""
transformer = Affine() if transformer is None else transformer
if transformer is None:
transformer = Affine(
positivity_constraint=Chain([SoftPlus(), _PlusConst(1e-2)]),
)

dim = base_dist.shape[-1]

def make_layer(key): # coupling layer + permutation
@@ -110,7 +117,7 @@ def masked_autoregressive_flow(
key: Random seed.
base_dist: Base distribution, with ``base_dist.ndim==1``.
transformer: Bijection parameterised by autoregressive network. Defaults to
``Affine()``.
affine.
cond_dim: Dimension of the conditioning variable. Defaults to None.
flow_layers: Number of flow layers. Defaults to 8.
nn_width: Number of hidden layers in neural network. Defaults to 50.
@@ -120,7 +127,10 @@
inverse, leading to faster `log_prob`, False will prioritise faster forward,
leading to faster `sample`. Defaults to True.
"""
transformer = Affine() if transformer is None else transformer
if transformer is None:
transformer = Affine(
positivity_constraint=Chain([SoftPlus(), _PlusConst(1e-2)]),
)
dim = base_dist.shape[-1]

def make_layer(key): # masked autoregressive layer + permutation
@@ -326,3 +336,25 @@ def _add_default_permute(bijection: AbstractBijection, dim: int, key: Array):

perm = Permute(jr.permutation(key, jnp.arange(dim)))
return Chain([bijection, perm]).merge_chains()


class _PlusConst(AbstractBijection):
"""Adds a constant."""

    # We use this to add a small constant to the affine scale parameter,
    # which seems to improve the stability of masked autoregressive and coupling flows.
const: float
shape: tuple[int, ...] = ()
cond_shape: ClassVar[None] = None

def transform(self, x, condition=None):
return x + self.const

def transform_and_log_det(self, x, condition=None):
return x + self.const, jnp.array(0)

def inverse_and_log_det(self, y, condition=None):
return y - self.const, jnp.array(0)

def inverse(self, y, condition=None):
return y - self.const
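
As a rough sketch (not part of the commit) of what the new default achieves: assuming Chain applies SoftPlus first and then _PlusConst(1e-2), as the ordering in the diff suggests, the affine scale is bounded below by 0.01. The helper name constrain_scale is illustrative, and jax.nn.softplus stands in numerically for the SoftPlus bijection.

import jax.numpy as jnp
import jax.nn as jnn

MIN_SCALE = 1e-2  # the constant added by _PlusConst in the new defaults

def constrain_scale(unconstrained):
    # SoftPlus maps to (0, inf); adding the constant bounds the scale below by
    # MIN_SCALE, avoiding numerically degenerate (near-zero) scales during training.
    return jnn.softplus(unconstrained) + MIN_SCALE

print(constrain_scale(jnp.array([-20.0, 0.0, 3.0])))  # approx [0.01, 0.703, 3.059]

# The inverse mirrors _PlusConst.inverse followed by the softplus inverse:
y = constrain_scale(jnp.array(0.5))
recovered = jnp.log(jnp.expm1(y - MIN_SCALE))
print(recovered)  # approx 0.5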
