Merge pull request #141 from danielward27/negative_scales
Negative scales
danielward27 authored Mar 1, 2024
2 parents 65fcdf1 + 38f3a56 commit 8948273
Showing 7 changed files with 116 additions and 68 deletions.
132 changes: 84 additions & 48 deletions flowjax/bijections/affine.py
@@ -2,12 +2,14 @@

from __future__ import annotations

+import warnings
from collections.abc import Callable
from typing import ClassVar

+import equinox as eqx
import jax.numpy as jnp
from jax import Array
-from jax.experimental import checkify
+from jax.nn import softplus
from jax.scipy.linalg import solve_triangular
from jax.typing import ArrayLike

@@ -16,6 +18,34 @@
from flowjax.utils import arraylike_to_array


+def _deprecate_positivity(scale_constraint, positivity_constraint):
+    if positivity_constraint is not None:
+        if scale_constraint is not None:
+            raise ValueError(
+                "Provide only scale_constraint (or diag_constraint for "
+                "TriangularAffine). positivity_constraint is deprecated.",
+            )
+        warnings.warn(
+            "positivity_constraint has been renamed to scale_constraint (or "
+            "diag_constraint for TriangularAffine) and will be removed in the next "
+            "major release.",
+        )
+        scale_constraint = positivity_constraint
+
+    return scale_constraint
+
+
+def _argcheck_and_reparam_scale(scale, constraint, error_name: str = "scale"):
+    scale = eqx.error_if(scale, scale == 0, "Scale must not equal zero.")
+    _scale = constraint.inverse(scale)
+    return eqx.error_if(
+        _scale,
+        ~jnp.isfinite(_scale),
+        f"Non-finite value(s) in {error_name} when reparameterizing. Check "
+        f"{error_name} and constraint ({type(constraint).__name__}) are compatible.",
+    )


class Affine(AbstractBijection):
"""Elementwise affine transformation ``y = a*x + b``.
@@ -24,100 +54,106 @@ class Affine(AbstractBijection):
Args:
loc: Location parameter. Defaults to 0.
scale: Scale parameter. Defaults to 1.
-        positivity_constraint: Bijection with shape matching the Affine bijection, that
-            maps the scale parameter from an unbounded domain to the positive domain.
-            Defaults to :class:`~flowjax.bijections.SoftPlus`.
+        scale_constraint: Bijection with shape matching the Affine bijection for
+            reparameterizing the scale parameter. Defaults to
+            :class:`~flowjax.bijections.SoftPlus`.
+        positivity_constraint: Deprecated alternative to scale_constraint.
"""

shape: tuple[int, ...]
cond_shape: ClassVar[None] = None
loc: Array
_scale: Array
-    positivity_constraint: AbstractBijection
+    scale_constraint: AbstractBijection

def __init__(
self,
loc: ArrayLike = 0,
scale: ArrayLike = 1,
+        scale_constraint: AbstractBijection | None = None,
positivity_constraint: AbstractBijection | None = None,
):
-        loc, scale = (arraylike_to_array(a, dtype=float) for a in (loc, scale))
-        self.shape = jnp.broadcast_shapes(loc.shape, scale.shape)
-        self.loc = jnp.broadcast_to(loc, self.shape)
+        scale_constraint = _deprecate_positivity(scale_constraint, positivity_constraint)
+
+        self.loc, scale = jnp.broadcast_arrays(
+            *(arraylike_to_array(a, dtype=float) for a in (loc, scale)),
+        )
+        self.shape = scale.shape

-        if positivity_constraint is None:
-            positivity_constraint = SoftPlus(self.shape)
+        if scale_constraint is None:
+            scale_constraint = SoftPlus(self.shape)

-        self.positivity_constraint = positivity_constraint
-        self._scale = positivity_constraint.inverse(jnp.broadcast_to(scale, self.shape))
+        self.scale_constraint = scale_constraint
+        self._scale = _argcheck_and_reparam_scale(scale, scale_constraint)

def transform(self, x, condition=None):
return x * self.scale + self.loc

def transform_and_log_det(self, x, condition=None):
scale = self.scale
-        return x * scale + self.loc, jnp.log(scale).sum()
+        return x * scale + self.loc, jnp.log(jnp.abs(scale)).sum()

def inverse(self, y, condition=None):
return (y - self.loc) / self.scale

def inverse_and_log_det(self, y, condition=None):
scale = self.scale
-        return (y - self.loc) / scale, -jnp.log(scale).sum()
+        return (y - self.loc) / scale, -jnp.log(jnp.abs(scale)).sum()

@property
def scale(self):
"""The scale parameter of the affine transformation."""
-        return self.positivity_constraint.transform(self._scale)
+        return self.scale_constraint.transform(self._scale)


class Scale(AbstractBijection):
"""Scale transformation ``y = a*x``.
Args:
scale: Scale parameter. Defaults to 1.
-        positivity_constraint: Bijection with shape matching the Affine bijection, that
-            maps the scale parameter from an unbounded domain to the positive domain.
-            Defaults to :class:`~flowjax.bijections.SoftPlus`.
+        scale_constraint: Bijection with shape matching the Scale bijection for
+            reparameterizing the scale parameter. Defaults to
+            :class:`~flowjax.bijections.SoftPlus`.
+        positivity_constraint: Deprecated alternative to scale_constraint.
"""

_scale: Array
-    positivity_constraint: AbstractBijection
+    scale_constraint: AbstractBijection
+    shape: tuple[int, ...]
cond_shape: ClassVar[None] = None

def __init__(
self,
scale: ArrayLike,
+        scale_constraint: AbstractBijection | None = None,
positivity_constraint: AbstractBijection | None = None,
):
-        if positivity_constraint is None:
-            positivity_constraint = SoftPlus(jnp.shape(scale))
-
-        self.positivity_constraint = positivity_constraint
-        self._scale = positivity_constraint.inverse(scale)
+        scale_constraint = _deprecate_positivity(scale_constraint, positivity_constraint)
+
+        if scale_constraint is None:
+            scale_constraint = SoftPlus(jnp.shape(scale))
+        self.shape = jnp.shape(scale)
+        self.scale_constraint = scale_constraint
+        self._scale = _argcheck_and_reparam_scale(scale, scale_constraint)

def transform(self, x, condition=None):
return x * self.scale

def transform_and_log_det(self, x, condition=None):
scale = self.scale
-        return x * scale, jnp.log(scale).sum()
+        return x * scale, jnp.log(jnp.abs(scale)).sum()

def inverse(self, y, condition=None):
return y / self.scale

def inverse_and_log_det(self, y, condition=None):
scale = self.scale
-        return y / scale, -jnp.log(scale).sum()
+        return y / scale, -jnp.log(jnp.abs(scale)).sum()

@property
def scale(self):
"""The scale parameter of the affine transformation."""
-        return self.positivity_constraint.transform(self._scale)
-
-    @property
-    def shape(self):
-        return self._scale.shape
+        return self.scale_constraint.transform(self._scale)


class TriangularAffine(AbstractBijection):
@@ -133,10 +169,9 @@ class TriangularAffine(AbstractBijection):
lower: Whether the mask should select the lower or upper
triangular matrix (other elements ignored). Defaults to True (lower).
weight_normalisation: If true, carry out weight normalisation.
-        positivity_constraint: Bijection with shape matching the dimension of the
-            triangular affine bijection, that maps the diagonal entries of the array
-            from an unbounded domain to the positive domain. Also used for weight
-            normalisation parameters, if used. Defaults to SoftPlus.
+        diag_constraint: Bijection with shape matching diag(arr) for reparameterizing
+            the diagonal elements. Defaults to :class:`~flowjax.bijections.SoftPlus`.
+        positivity_constraint: Deprecated alternative to diag_constraint.
"""

shape: tuple[int, ...]
@@ -145,7 +180,7 @@ class TriangularAffine(AbstractBijection):
diag_idxs: Array
tri_mask: Array
lower: bool
-    positivity_constraint: AbstractBijection
+    diag_constraint: AbstractBijection
_arr: Array
_diag: Array
_weight_scale: Array | None
@@ -157,15 +192,14 @@ def __init__(
*,
lower: bool = True,
weight_normalisation: bool = False,
+        diag_constraint: AbstractBijection | None = None,
positivity_constraint: AbstractBijection | None = None,
):
+        diag_constraint = _deprecate_positivity(diag_constraint, positivity_constraint)
        loc, arr = (arraylike_to_array(a, dtype=float) for a in (loc, arr))
        if (arr.ndim != 2) or (arr.shape[0] != arr.shape[1]):
            raise ValueError("arr must be a square, 2-dimensional matrix.")
-        checkify.check(
-            jnp.all(jnp.diag(arr) > 0),
-            "arr diagonal entries must be positive",
-        )

dim = arr.shape[0]
self.diag_idxs = jnp.diag_indices(dim)
tri_mask = jnp.tril(jnp.ones((dim, dim), dtype=jnp.int32), k=-1)
@@ -174,18 +208,20 @@ def __init__(

self.shape = (dim,)

-        if positivity_constraint is None:
-            positivity_constraint = SoftPlus(self.shape)
+        if diag_constraint is None:
+            diag_constraint = SoftPlus(self.shape)

-        self.positivity_constraint = positivity_constraint
-        self._diag = positivity_constraint.inverse(jnp.diag(arr))
+        self.diag_constraint = diag_constraint

# inexact arrays
self.loc = jnp.broadcast_to(loc, (dim,))
self._arr = arr
+        self._diag = _argcheck_and_reparam_scale(
+            jnp.diag(arr), diag_constraint, "diagonal values"
+        )

if weight_normalisation:
-            self._weight_scale = positivity_constraint.inverse(jnp.ones((dim,)))
+            self._weight_scale = jnp.full(dim, SoftPlus().inverse(1))
else:
self._weight_scale = None

@@ -196,13 +232,13 @@ def arr(self):
        Applies masking, reparameterizes the diagonal via ``diag_constraint``, and
        (possibly) applies weight normalisation.
"""
-        diag = self.positivity_constraint.transform(self._diag)
+        diag = self.diag_constraint.transform(self._diag)
off_diag = self.tri_mask * self._arr
arr = off_diag.at[self.diag_idxs].set(diag)

if self._weight_scale is not None:
norms = jnp.linalg.norm(arr, axis=1, keepdims=True)
-        scale = self.positivity_constraint.transform(self._weight_scale)[:, None]
+        scale = softplus(self._weight_scale)[:, None]
arr = scale * arr / norms

return arr
@@ -212,15 +248,15 @@ def transform(self, x, condition=None):

def transform_and_log_det(self, x, condition=None):
arr = self.arr
-        return arr @ x + self.loc, jnp.log(jnp.diag(arr)).sum()
+        return arr @ x + self.loc, jnp.log(jnp.abs(jnp.diag(arr))).sum()

def inverse(self, y, condition=None):
return solve_triangular(self.arr, y - self.loc, lower=self.lower)

def inverse_and_log_det(self, y, condition=None):
arr = self.arr
x = solve_triangular(arr, y - self.loc, lower=self.lower)
-        return x, -jnp.log(jnp.diag(arr)).sum()
+        return x, -jnp.log(jnp.abs(jnp.diag(arr))).sum()


class AdditiveCondition(AbstractBijection):
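Taken together, the affine.py changes allow negative scales: log-determinants now use log|scale|, and reparameterization is controlled by ``scale_constraint`` (``diag_constraint`` for TriangularAffine). A minimal sketch of the resulting usage, assuming flowjax exposes an ``Identity`` bijection whose inverse accepts negative values; the numbers are illustrative:

```python
import jax.numpy as jnp
from flowjax.bijections import Affine, Identity

# Default behaviour is unchanged: SoftPlus keeps the scale positive.
affine = Affine(loc=1.0, scale=2.0)

# With a constraint whose inverse is defined for negatives (Identity here),
# a negative scale is now representable; the log-det uses log(abs(scale)).
neg = Affine(loc=0.0, scale=-2.0, scale_constraint=Identity())
y, log_det = neg.transform_and_log_det(jnp.array(3.0))  # y = -6.0, log_det = log(2)

# A zero scale is rejected by the new eqx.error_if argcheck:
# Affine(scale=0.0)  # "Scale must not equal zero."
```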
7 changes: 4 additions & 3 deletions flowjax/bijections/utils.py
@@ -6,9 +6,9 @@
from math import prod
from typing import ClassVar

+import equinox as eqx
import jax.numpy as jnp
from jax import Array
-from jax.experimental import checkify
from jax.typing import ArrayLike

from flowjax.bijections.bijection import AbstractBijection
@@ -68,8 +68,9 @@ class Permute(AbstractBijection):

def __init__(self, permutation: ArrayLike):
permutation = arraylike_to_array(permutation)
-        checkify.check(
-            (permutation.ravel().sort() == jnp.arange(permutation.size)).all(),
+        permutation = eqx.error_if(
+            permutation,
+            permutation.ravel().sort() != jnp.arange(permutation.size),
"Invalid permutation array provided.",
)
self.shape = permutation.shape
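The Permute validation gets the same checkify-to-``eqx.error_if`` migration. A quick illustrative sketch of the checked behaviour:

```python
import jax.numpy as jnp
from flowjax.bijections import Permute

Permute(jnp.array([1, 0]))    # a valid permutation constructs fine
# Permute(jnp.array([0, 0]))  # raises at run time: "Invalid permutation array provided."
```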
8 changes: 4 additions & 4 deletions flowjax/distributions.py
@@ -1,4 +1,5 @@
"""Distributions, including the abstract and concrete classes."""
+
import inspect
from abc import abstractmethod
from functools import wraps
@@ -10,7 +11,6 @@
import jax.random as jr
from equinox import AbstractVar
from jax import Array
-from jax.experimental import checkify
from jax.lax import stop_gradient
from jax.numpy import linalg
from jax.scipy import stats as jstats
@@ -469,6 +469,7 @@ def covariance(self):

class _StandardUniform(AbstractDistribution):
r"""Standard Uniform distribution."""
+
shape: tuple[int, ...] = ()
cond_shape: ClassVar[None] = None

@@ -494,9 +495,8 @@ class Uniform(AbstractTransformed):

def __init__(self, minval: ArrayLike, maxval: ArrayLike):
minval, maxval = arraylike_to_array(minval), arraylike_to_array(maxval)
-        checkify.check(
-            jnp.all(maxval >= minval),
-            "Minimums must be less than the maximums.",
+        minval, maxval = eqx.error_if(
+            (minval, maxval), maxval <= minval, "minval must be less than maxval."
)
self.base_dist = _StandardUniform(
jnp.broadcast_shapes(minval.shape, maxval.shape),
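Sketch of the corresponding Uniform bounds check; note the new predicate ``maxval <= minval`` also rejects equal bounds (values illustrative):

```python
from flowjax.distributions import Uniform

Uniform(minval=0.0, maxval=1.0)    # valid bounds
# Uniform(minval=1.0, maxval=1.0)  # raises: equal bounds fail the maxval <= minval check
```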
8 changes: 2 additions & 6 deletions flowjax/flows.py
@@ -70,9 +70,7 @@ def coupling_flow(
faster `transform` methods, leading to faster `sample`. Defaults to True.
"""
if transformer is None:
-        transformer = Affine(
-            positivity_constraint=Chain([SoftPlus(), _PlusConst(1e-2)]),
-        )
+        transformer = Affine(scale_constraint=Chain([SoftPlus(), _PlusConst(1e-2)]))

dim = base_dist.shape[-1]

@@ -128,9 +126,7 @@ def masked_autoregressive_flow(
leading to faster `sample`. Defaults to True.
"""
if transformer is None:
-        transformer = Affine(
-            positivity_constraint=Chain([SoftPlus(), _PlusConst(1e-2)]),
-        )
+        transformer = Affine(scale_constraint=Chain([SoftPlus(), _PlusConst(1e-2)]))
dim = base_dist.shape[-1]

def make_layer(key): # masked autoregressive layer + permutation
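For callers, only the keyword changes. A hedged sketch of passing a custom transformer to ``coupling_flow`` under the new API (assuming the ``key``/``base_dist``/``transformer`` parameter names; ``_PlusConst`` is internal to flowjax.flows, so plain SoftPlus is used here):

```python
import jax.numpy as jnp
import jax.random as jr
from flowjax.bijections import Affine, SoftPlus
from flowjax.distributions import Normal
from flowjax.flows import coupling_flow

# scale_constraint replaces positivity_constraint in the transformer.
transformer = Affine(scale_constraint=SoftPlus())
flow = coupling_flow(
    jr.PRNGKey(0),
    base_dist=Normal(jnp.zeros(2)),
    transformer=transformer,
)
```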
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -48,7 +48,7 @@ include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"]

[tool.ruff.lint]
select = ["E", "F", "B", "D", "COM", "I", "UP", "TRY004", "RET", "PT", "FBT"]
ignore = ["D102", "D105", "D107"]
ignore = ["D102", "D105", "D107", "B028", "COM812"]


[tool.ruff.lint.pydocstyle]
4 changes: 2 additions & 2 deletions tests/test_bijections/test_bijection_utils.py
@@ -1,6 +1,6 @@
+import jax
import jax.numpy as jnp
import pytest
-from jax.experimental.checkify import JaxRuntimeError

from flowjax.bijections import Affine, Partial, Permute

@@ -30,5 +30,5 @@ def test_partial(idx, expected):


def test_Permute_argcheck():
-    with pytest.raises(JaxRuntimeError):
+    with pytest.raises(jax.lib.xla_extension.XlaRuntimeError):
Permute(jnp.array([0, 0]))
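Since ``eqx.error_if`` surfaces failures as XLA runtime errors rather than checkify errors, a hedged variant of the updated test could also match on the message (assuming the error text propagates into the raised error; the test name is hypothetical):

```python
import jax
import jax.numpy as jnp
import pytest

from flowjax.bijections import Permute

def test_Permute_argcheck_message():
    # Mirrors the updated test above, additionally asserting the message.
    with pytest.raises(jax.lib.xla_extension.XlaRuntimeError, match="Invalid permutation"):
        Permute(jnp.array([0, 0]))
```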