
Commit

Merge pull request #142 from danielward27/wrap
Wrap
danielward27 authored Mar 12, 2024
2 parents 8948273 + de82680 commit 7ca547e
Showing 41 changed files with 953 additions and 1,087 deletions.
5 changes: 5 additions & 0 deletions docs/api/wrappers.rst
@@ -0,0 +1,5 @@
Wrappers
=============================================
.. automodule:: flowjax.wrappers
:members:
:undoc-members:
38 changes: 8 additions & 30 deletions docs/examples/bounded.ipynb

Large diffs are not rendered by default.

29 changes: 8 additions & 21 deletions docs/examples/conditional.ipynb

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions docs/examples/snpe.ipynb

Large diffs are not rendered by default.

20 changes: 6 additions & 14 deletions docs/examples/variational_inference.ipynb

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions docs/faq.rst
@@ -3,9 +3,9 @@ FAQ

Freezing parameters
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Often it is useful to not train particular parameters. To achieve this we can provide a
``filter_spec`` to :func:`~flowjax.train.fit_to_data`. For example, to avoid
training the base distribution, we could create a ``filter_spec`` as follows
Often it is useful to not train particular parameters. The easiest way to achieve this
is to use the :class:`flowjax.wrappers.NonTrainable` wrapper class. For example, to
avoid training the base distribution of a transformed distribution:

.. testsetup::

@@ -15,11 +15,11 @@ training the base distribution, we could create a ``filter_spec`` as follows
.. doctest::

>>> import equinox as eqx
>>> import jax.tree_util as jtu
>>> filter_spec = jtu.tree_map(lambda x: eqx.is_inexact_array(x), flow)
>>> filter_spec = eqx.tree_at(lambda tree: tree.base_dist, filter_spec, replace=False)
>>> from flowjax.wrappers import NonTrainable
>>> flow = eqx.tree_at(lambda flow: flow.base_dist, flow, replace_fn=NonTrainable)

For more information about filtering, see the `equinox documentation <https://docs.kidger.site/equinox/all-of-equinox/>`_.
If you wish to avoid training, e.g., all components of a specific type, it may be
easier to use ``jax.tree_map`` to apply the ``NonTrainable`` wrapper as required
(see the sketch below).
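
A rough sketch of this approach, here via ``jax.tree_util.tree_map`` (the
``freeze_type`` helper and the choice of ``Affine`` are illustrative, not part of
flowjax):

    import jax.tree_util as jtu
    from flowjax.bijections import Affine
    from flowjax.wrappers import NonTrainable

    def freeze_type(tree, cls):
        # Treat instances of ``cls`` as leaves and wrap each in NonTrainable,
        # so their parameters are ignored during training.
        return jtu.tree_map(
            lambda leaf: NonTrainable(leaf) if isinstance(leaf, cls) else leaf,
            tree,
            is_leaf=lambda leaf: isinstance(leaf, cls),
        )

    # e.g. flow = freeze_type(flow, Affine)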

Standardising variables
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
3 changes: 2 additions & 1 deletion flowjax/bijections/__init__.py
@@ -1,6 +1,6 @@
"""Bijections from ``flowjax.bijections``."""

from .affine import AdditiveCondition, Affine, Scale, TriangularAffine
from .affine import AdditiveCondition, Affine, Loc, Scale, TriangularAffine
from .bijection import AbstractBijection
from .block_autoregressive_network import BlockAutoregressiveNetwork
from .chain import Chain
@@ -29,6 +29,7 @@
"Identity",
"Invert",
"LeakyTanh",
"Loc",
"MaskedAutoregressive",
"Partial",
"Permute",
199 changes: 59 additions & 140 deletions flowjax/bijections/affine.py
@@ -2,261 +2,180 @@

from __future__ import annotations

import warnings
from collections.abc import Callable
from typing import ClassVar

import equinox as eqx
import jax.numpy as jnp
from jax import Array
from jax.nn import softplus
from jax.scipy.linalg import solve_triangular
from jax.typing import ArrayLike

from flowjax import wrappers
from flowjax.bijections.bijection import AbstractBijection
from flowjax.bijections.softplus import SoftPlus
from flowjax.utils import arraylike_to_array


def _deprecate_positivty(scale_constraint, positivity_constraint):
if positivity_constraint is not None:
if scale_constraint is not None:
raise ValueError(
"Provide only scale_constraint (or diag_constraint for "
"TriangularAffine. positivity_constraint is deprecated.",
)
warnings.warn(
"positivity_constraint has been renamed to scale_contraint (or "
"diag_constraint for TriangularAffine) and will be removed in the next "
"major release.",
)
scale_constraint = positivity_constraint

return scale_constraint


def _argcheck_and_reparam_scale(scale, constraint, error_name: str = "scale"):
scale = eqx.error_if(scale, scale == 0, "Scale must not equal zero.")
_scale = constraint.inverse(scale)
return eqx.error_if(
_scale,
~jnp.isfinite(_scale),
f"Non-finite value(s) in {error_name} when reparameterizing. Check "
f"{error_name} and constraint ({type(constraint).__name__}) are compatible.",
)


class Affine(AbstractBijection):
"""Elementwise affine transformation ``y = a*x + b``.
``loc`` and ``scale`` should broadcast to the desired shape of the bijection.
    By default, we constrain the scale parameter to be positive using ``SoftPlus``, but
    other parameterizations can be achieved by replacing the scale parameter after
    construction, e.g. using ``eqx.tree_at``.
Args:
loc: Location parameter. Defaults to 0.
scale: Scale parameter. Defaults to 1.
scale_constraint: Bijection with shape matching the Affine bijection for
reparameterizing the scale parameter. Defaults to
:class:`~flowjax.bijections.SoftPlus`.
positivity_constraint: Deprecated alternative to scale_constraint.
"""

shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
loc: Array
_scale: Array
scale_constraint: AbstractBijection
scale: Array | wrappers.AbstractUnwrappable[Array]

def __init__(
self,
loc: ArrayLike = 0,
scale: ArrayLike = 1,
scale_constraint: AbstractBijection | None = None,
positivity_constraint: AbstractBijection | None = None,
):
scale_constraint = _deprecate_positivty(scale_constraint, positivity_constraint)

self.loc, scale = jnp.broadcast_arrays(
*(arraylike_to_array(a, dtype=float) for a in (loc, scale)),
)
self.shape = scale.shape

if scale_constraint is None:
scale_constraint = SoftPlus(self.shape)

self.scale_constraint = scale_constraint
self._scale = _argcheck_and_reparam_scale(scale, scale_constraint)
self.scale = wrappers.BijectionReparam(scale, SoftPlus())

def transform(self, x, condition=None):
return x * self.scale + self.loc

def transform_and_log_det(self, x, condition=None):
scale = self.scale
return x * scale + self.loc, jnp.log(jnp.abs(scale)).sum()
return x * self.scale + self.loc, jnp.log(jnp.abs(self.scale)).sum()

def inverse(self, y, condition=None):
return (y - self.loc) / self.scale

def inverse_and_log_det(self, y, condition=None):
scale = self.scale
return (y - self.loc) / scale, -jnp.log(jnp.abs(scale)).sum()
return (y - self.loc) / self.scale, -jnp.log(jnp.abs(self.scale)).sum()
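
# The class docstring above notes that alternative scale parameterizations can be
# swapped in after construction. A rough usage sketch (values illustrative): the
# reparameterized scale is resolved with flowjax.wrappers.unwrap (as used in this
# diff), and the location parameter is frozen with NonTrainable.
#
#     import equinox as eqx
#     import jax.numpy as jnp
#     from flowjax.bijections import Affine
#     from flowjax.wrappers import NonTrainable, unwrap
#
#     affine = Affine(loc=jnp.zeros(2), scale=2 * jnp.ones(2))
#     scale_value = unwrap(affine).scale  # ~[2., 2.]: SoftPlus reparam resolved
#     # Freeze the location parameter so it is ignored during training.
#     affine = eqx.tree_at(lambda a: a.loc, affine, replace_fn=NonTrainable)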


class Loc(AbstractBijection):
"""Location transformation ``y = x + c``.
Args:
        loc: Location parameter.
"""

loc: Array
shape: tuple[int, ...]
cond_shape: ClassVar[None] = None

def __init__(self, loc: ArrayLike):
self.loc = arraylike_to_array(loc)
self.shape = self.loc.shape

def transform(self, x, condition=None):
return x + self.loc

def transform_and_log_det(self, x, condition=None):
return x + self.loc, jnp.zeros(())

def inverse(self, y, condition=None):
return y - self.loc

@property
def scale(self):
"""The scale parameter of the affine transformation."""
return self.scale_constraint.transform(self._scale)
def inverse_and_log_det(self, y, condition=None):
return y - self.loc, jnp.zeros(())


class Scale(AbstractBijection):
"""Scale transformation ``y = a*x``.
Args:
        scale: Scale parameter.
scale_constraint: Bijection with shape matching the Affine bijection for
reparameterizing the scale parameter. Defaults to
:class:`~flowjax.bijections.SoftPlus`.
positivity_constraint: Deprecated alternative to scale_constraint.
"""

_scale: Array
scale_constraint: AbstractBijection
shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
scale: Array | wrappers.AbstractUnwrappable[Array]

def __init__(
self,
scale: ArrayLike,
scale_constraint: AbstractBijection | None = None,
positivity_constraint: AbstractBijection | None = None,
scale: ArrayLike | wrappers.AbstractUnwrappable[Array],
):
scale_constraint = _deprecate_positivty(scale_constraint, positivity_constraint)

if scale_constraint is None:
scale_constraint = SoftPlus(jnp.shape(scale))
self.shape = jnp.shape(scale)
self.scale_constraint = scale_constraint
self._scale = _argcheck_and_reparam_scale(scale, scale_constraint)
self.scale = wrappers.BijectionReparam(scale, SoftPlus())
self.shape = jnp.shape(wrappers.unwrap(scale))

def transform(self, x, condition=None):
return x * self.scale

def transform_and_log_det(self, x, condition=None):
scale = self.scale
return x * scale, jnp.log(jnp.abs(scale)).sum()
return x * self.scale, jnp.log(jnp.abs(self.scale)).sum()

def inverse(self, y, condition=None):
return y / self.scale

def inverse_and_log_det(self, y, condition=None):
scale = self.scale
return y / scale, -jnp.log(jnp.abs(scale)).sum()

@property
def scale(self):
"""The scale parameter of the affine transformation."""
return self.scale_constraint.transform(self._scale)
return y / self.scale, -jnp.log(jnp.abs(self.scale)).sum()
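
# As with Affine, the scale is stored via a SoftPlus reparameterization. A brief
# sketch of direct use after resolving the wrappers with unwrap (values
# illustrative):
#
#     import jax.numpy as jnp
#     from flowjax.bijections import Scale
#     from flowjax.wrappers import unwrap
#
#     bijection = unwrap(Scale(jnp.full(3, 2.0)))
#     y, log_det = bijection.transform_and_log_det(jnp.ones(3))
#     # y = [2., 2., 2.] and log_det = 3 * log(2), summed over the elementwise scales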


class TriangularAffine(AbstractBijection):
r"""A triangular affine transformation.
Transformation has the form :math:`Ax + b`, where :math:`A` is a lower or upper
triangular matrix, and :math:`b` is the bias vector.
triangular matrix, and :math:`b` is the bias vector. We assume the diagonal
entries are positive, and constrain the values using SoftPlus. Other
parameterizations can be achieved by e.g. replacing ``self.triangular``
after construction.
Args:
loc: Location parameter. If this is scalar, it is broadcast to the dimension
inferred from arr.
arr: Triangular matrix.
lower: Whether the mask should select the lower or upper
triangular matrix (other elements ignored). Defaults to True (lower).
weight_normalisation: If true, carry out weight normalisation.
diag_constraint: Bijection with shape matching diag(arr) for reparameterizing
the diagonal elements. Defaults to :class:`~flowjax.bijections.SoftPlus`.
positivity_constraint: Deprecated alternative to diag_constraint.
"""

shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
loc: Array
diag_idxs: Array
tri_mask: Array
triangular: Array | wrappers.AbstractUnwrappable[Array]
lower: bool
diag_constraint: AbstractBijection
_arr: Array
_diag: Array
_weight_scale: Array | None

def __init__(
self,
loc: ArrayLike,
arr: ArrayLike,
*,
lower: bool = True,
weight_normalisation: bool = False,
diag_constraint: AbstractBijection | None = None,
positivity_constraint: AbstractBijection | None = None,
):
diag_constraint = _deprecate_positivty(diag_constraint, positivity_constraint)
loc, arr = (arraylike_to_array(a, dtype=float) for a in (loc, arr))
if (arr.ndim != 2) or (arr.shape[0] != arr.shape[1]):
raise ValueError("arr must be a square, 2-dimensional matrix.")

dim = arr.shape[0]
self.diag_idxs = jnp.diag_indices(dim)
tri_mask = jnp.tril(jnp.ones((dim, dim), dtype=jnp.int32), k=-1)
self.tri_mask = tri_mask if lower else tri_mask.T
self.lower = lower

self.shape = (dim,)

if diag_constraint is None:
diag_constraint = SoftPlus(self.shape)

self.diag_constraint = diag_constraint
def _to_triangular(diag, arr):
tri = jnp.tril(arr, k=-1) if lower else jnp.triu(arr, k=1)
return jnp.diag(diag) + tri

# inexact arrays
diag = wrappers.BijectionReparam(jnp.diag(arr), SoftPlus())
self.triangular = wrappers.Lambda(_to_triangular, diag=diag, arr=arr)
self.lower = lower
self.shape = (dim,)
self.loc = jnp.broadcast_to(loc, (dim,))
self._arr = arr
self._diag = _argcheck_and_reparam_scale(
jnp.diag(arr), diag_constraint, "diagonal values"
)

if weight_normalisation:
self._weight_scale = jnp.full(dim, SoftPlus().inverse(1))
else:
self._weight_scale = None

@property
def arr(self):
"""Get the triangular array.
Applies masking, constrains the diagonal to be positive and (possibly)
applies weight normalisation.
"""
diag = self.diag_constraint.transform(self._diag)
off_diag = self.tri_mask * self._arr
arr = off_diag.at[self.diag_idxs].set(diag)

if self._weight_scale is not None:
norms = jnp.linalg.norm(arr, axis=1, keepdims=True)
scale = softplus(self._weight_scale)[:, None]
arr = scale * arr / norms

return arr

def transform(self, x, condition=None):
return self.arr @ x + self.loc
return self.triangular @ x + self.loc

def transform_and_log_det(self, x, condition=None):
arr = self.arr
return arr @ x + self.loc, jnp.log(jnp.abs(jnp.diag(arr))).sum()
y = self.triangular @ x + self.loc
return y, jnp.log(jnp.abs(jnp.diag(self.triangular))).sum()

def inverse(self, y, condition=None):
return solve_triangular(self.arr, y - self.loc, lower=self.lower)
return solve_triangular(self.triangular, y - self.loc, lower=self.lower)

def inverse_and_log_det(self, y, condition=None):
arr = self.arr
x = solve_triangular(arr, y - self.loc, lower=self.lower)
return x, -jnp.log(jnp.abs(jnp.diag(arr))).sum()
x = solve_triangular(self.triangular, y - self.loc, lower=self.lower)
return x, -jnp.log(jnp.abs(jnp.diag(self.triangular))).sum()


class AdditiveCondition(AbstractBijection):