Merge pull request #188 from danielward27/key_based_loss
key_based_loss update
danielward27 authored Oct 14, 2024
2 parents f73c906 + 83c1787 commit c006585
Showing 11 changed files with 292 additions and 191 deletions.
13 changes: 4 additions & 9 deletions docs/examples/conditional.ipynb
@@ -146,7 +146,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "tailflows",
"display_name": "flowjax_env",
"language": "python",
"name": "python3"
},
@@ -160,15 +160,10 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.12.2"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "e239a143cedf191ff119c9fed46f72dee75a7abaaa2b0d4972963f7b379ced22"
}
}
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
}
16 changes: 8 additions & 8 deletions docs/examples/snpe.ipynb

Large diffs are not rendered by default.

38 changes: 13 additions & 25 deletions docs/examples/unconditional.ipynb

Large diffs are not rendered by default.

27 changes: 11 additions & 16 deletions docs/examples/variational_inference.ipynb

Large diffs are not rendered by default.

8 changes: 2 additions & 6 deletions flowjax/train/__init__.py
@@ -1,11 +1,7 @@
"""Utilities for training flows, fitting to samples or ysing variational inference."""

from .data_fit import fit_to_data
from .loops import fit_to_data, fit_to_key_based_loss
from .train_utils import step
from .variational_fit import fit_to_variational_target

__all__ = [
"fit_to_data",
"fit_to_variational_target",
"step",
]
__all__ = ["fit_to_key_based_loss", "fit_to_data", "fit_to_variational_target", "step"]
135 changes: 10 additions & 125 deletions flowjax/train/data_fit.py
@@ -1,131 +1,16 @@
"""Function to fit flows to samples from a distribution."""

from collections.abc import Callable
import warnings

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import optax
from jaxtyping import ArrayLike, PRNGKeyArray, PyTree
from tqdm import tqdm
from flowjax import train

from flowjax import wrappers
from flowjax.train.losses import MaximumLikelihoodLoss
from flowjax.train.train_utils import (
count_fruitless,
get_batches,
step,
train_val_split,
)


def fit_to_data(
key: PRNGKeyArray,
dist: PyTree, # Custom losses may support broader types than AbstractDistribution
x: ArrayLike,
*,
condition: ArrayLike | None = None,
loss_fn: Callable | None = None,
max_epochs: int = 100,
max_patience: int = 5,
batch_size: int = 100,
val_prop: float = 0.1,
learning_rate: float = 5e-4,
optimizer: optax.GradientTransformation | None = None,
return_best: bool = True,
show_progress: bool = True,
) -> tuple[PyTree, dict[str, list]]:
r"""Train a distribution (e.g. a flow) to samples from the target distribution.
The distribution can be unconditional :math:`p(x)` or conditional
:math:`p(x|\text{condition})`. Note that the last batch in each epoch is dropped
if truncated (to avoid recompilation). This function can also be used to fit
non-distribution pytrees as long as a compatible loss function is provided.
Args:
key: Jax random key.
dist: The distribution to train.
x: Samples from target distribution.
condition: Conditioning variables. Defaults to None.
loss_fn: Loss function. Defaults to MaximumLikelihoodLoss.
max_epochs: Maximum number of epochs. Defaults to 100.
max_patience: Number of consecutive epochs with no validation loss improvement
after which training is terminated. Defaults to 5.
batch_size: Batch size. Defaults to 100.
val_prop: Proportion of data to use in validation set. Defaults to 0.1.
learning_rate: Adam learning rate. Defaults to 5e-4.
optimizer: Optax optimizer. If provided, this overrides the default Adam
optimizer, and the learning_rate is ignored. Defaults to None.
return_best: Whether the result should use the parameters where the minimum loss
was reached (when True), or the parameters after the last update (when
False). Defaults to True.
show_progress: Whether to show progress bar. Defaults to True.

Returns:
A tuple containing the trained distribution and the losses.
"""
data = (x,) if condition is None else (x, condition)
data = tuple(jnp.asarray(a) for a in data)

if optimizer is None:
optimizer = optax.adam(learning_rate)

if loss_fn is None:
loss_fn = MaximumLikelihoodLoss()

params, static = eqx.partition(
dist,
eqx.is_inexact_array,
is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
def fit_to_data(*args, **kwargs): # TODO deprecate
"""Deprecated import for fit_to_data."""
warnings.warn(
"Importing from data_fit will be deprecated in 17.0.0.. Please import from "
"``flowjax.train.loops`` or ``flowjax.train``.",
DeprecationWarning,
stacklevel=2,
)
best_params = params
opt_state = optimizer.init(params)

# train val split
key, subkey = jr.split(key)
train_data, val_data = train_val_split(subkey, data, val_prop=val_prop)
losses = {"train": [], "val": []}

loop = tqdm(range(max_epochs), disable=not show_progress)

for _ in loop:
# Shuffle data
key, *subkeys = jr.split(key, 3)
train_data = [jr.permutation(subkeys[0], a) for a in train_data]
val_data = [jr.permutation(subkeys[1], a) for a in val_data]

# Train epoch
batch_losses = []
for batch in zip(*get_batches(train_data, batch_size), strict=True):
key, subkey = jr.split(key)
params, opt_state, loss_i = step(
params,
static,
*batch,
optimizer=optimizer,
opt_state=opt_state,
loss_fn=loss_fn,
key=subkey,
)
batch_losses.append(loss_i)
losses["train"].append((sum(batch_losses) / len(batch_losses)).item())

# Val epoch
batch_losses = []
for batch in zip(*get_batches(val_data, batch_size), strict=True):
key, subkey = jr.split(key)
loss_i = loss_fn(params, static, *batch, key=subkey)
batch_losses.append(loss_i)
losses["val"].append((sum(batch_losses) / len(batch_losses)).item())

loop.set_postfix({k: v[-1] for k, v in losses.items()})
if losses["val"][-1] == min(losses["val"]):
best_params = params

elif count_fruitless(losses["val"]) > max_patience:
loop.set_postfix_str(f"{loop.postfix} (Max patience reached)")
break

params = best_params if return_best else params
dist = eqx.combine(params, static)
return dist, losses
return train.loops.fit_to_data(*args, **kwargs)
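Because the shim simply forwards *args and **kwargs to flowjax.train.loops.fit_to_data, existing call sites keep working unchanged; only the old import path triggers the warning at call time. A minimal before/after sketch (illustrative, not part of the diff):

# Old location: still importable, but calls now emit a DeprecationWarning
from flowjax.train.data_fit import fit_to_data
# Recommended going forward (either of these):
from flowjax.train import fit_to_data
# from flowjax.train.loops import fit_to_data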
187 changes: 187 additions & 0 deletions flowjax/train/loops.py
@@ -0,0 +1,187 @@
"""Training loops."""

from collections.abc import Callable

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
import optax
from jaxtyping import ArrayLike, PRNGKeyArray, PyTree, Scalar
from tqdm import tqdm

from flowjax import wrappers
from flowjax.train.losses import MaximumLikelihoodLoss
from flowjax.train.train_utils import (
count_fruitless,
get_batches,
step,
train_val_split,
)


def fit_to_key_based_loss(
key: PRNGKeyArray,
tree: PyTree,
*,
loss_fn: Callable[[PyTree, PyTree, PRNGKeyArray], Scalar],
steps: int,
learning_rate: float = 5e-4,
optimizer: optax.GradientTransformation | None = None,
show_progress: bool = True,
):
"""Train a pytree, using a loss with params, static and key as arguments.
This can be used e.g. to fit a distribution using a variational objective, such as
the evidence lower bound.
Args:
key: Jax random key.
tree: PyTree, from which trainable parameters are found using
``equinox.is_inexact_array``.
loss_fn: The loss function to optimize.
steps: The number of optimization steps.
learning_rate: The adam learning rate. Ignored if optimizer is provided.
optimizer: Optax optimizer. Defaults to None.
show_progress: Whether to show progress bar. Defaults to True.

Returns:
A tuple containing the trained pytree and the losses.
"""
if optimizer is None:
optimizer = optax.adam(learning_rate)

params, static = eqx.partition(
tree,
eqx.is_inexact_array,
is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
)
opt_state = optimizer.init(params)

losses = []

keys = tqdm(jr.split(key, steps), disable=not show_progress)

for key in keys:
params, opt_state, loss = step(
params,
static,
key=key,
optimizer=optimizer,
opt_state=opt_state,
loss_fn=loss_fn,
)
losses.append(loss.item())
keys.set_postfix({"loss": loss.item()})
return eqx.combine(params, static), losses


def fit_to_data(
key: PRNGKeyArray,
dist: PyTree, # Custom losses may support broader types than AbstractDistribution
x: ArrayLike,
*,
condition: ArrayLike | None = None,
loss_fn: Callable | None = None,
learning_rate: float = 5e-4,
optimizer: optax.GradientTransformation | None = None,
max_epochs: int = 100,
max_patience: int = 5,
batch_size: int = 100,
val_prop: float = 0.1,
return_best: bool = True,
show_progress: bool = True,
):
r"""Train a PyTree (e.g. a distribution) to samples from the target.
The model can be unconditional :math:`p(x)` or conditional
:math:`p(x|\text{condition})`. Note that the last batch in each epoch is dropped
if truncated (to avoid recompilation). This function can also be used to fit
non-distribution pytrees as long as a compatible loss function is provided.
Args:
key: Jax random seed.
dist: The pytree to train (usually a distribution).
x: Samples from target distribution.
learning_rate: The learning rate for adam optimizer. Ignored if optimizer is
provided.
optimizer: Optax optimizer. Defaults to None.
condition: Conditioning variables. Defaults to None.
loss_fn: Loss function. Defaults to MaximumLikelihoodLoss.
max_epochs: Maximum number of epochs. Defaults to 100.
max_patience: Number of consecutive epochs with no validation loss improvement
after which training is terminated. Defaults to 5.
batch_size: Batch size. Defaults to 100.
val_prop: Proportion of data to use in validation set. Defaults to 0.1.
return_best: Whether the result should use the parameters where the minimum loss
was reached (when True), or the parameters after the last update (when
False). Defaults to True.
show_progress: Whether to show progress bar. Defaults to True.

Returns:
A tuple containing the trained distribution and the losses.
"""
data = (x,) if condition is None else (x, condition)
data = tuple(jnp.asarray(a) for a in data)

if loss_fn is None:
loss_fn = MaximumLikelihoodLoss()

if optimizer is None:
optimizer = optax.adam(learning_rate)

params, static = eqx.partition(
dist,
eqx.is_inexact_array,
is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable),
)
best_params = params
opt_state = optimizer.init(params)

# train val split
key, subkey = jr.split(key)
train_data, val_data = train_val_split(subkey, data, val_prop=val_prop)
losses = {"train": [], "val": []}

loop = tqdm(range(max_epochs), disable=not show_progress)

for _ in loop:
# Shuffle data
key, *subkeys = jr.split(key, 3)
train_data = [jr.permutation(subkeys[0], a) for a in train_data]
val_data = [jr.permutation(subkeys[1], a) for a in val_data]

# Train epoch
batch_losses = []
for batch in zip(*get_batches(train_data, batch_size), strict=True):
key, subkey = jr.split(key)
params, opt_state, loss_i = step(
params,
static,
*batch,
optimizer=optimizer,
opt_state=opt_state,
loss_fn=loss_fn,
key=subkey,
)
batch_losses.append(loss_i)
losses["train"].append((sum(batch_losses) / len(batch_losses)).item())

# Val epoch
batch_losses = []
for batch in zip(*get_batches(val_data, batch_size), strict=True):
key, subkey = jr.split(key)
loss_i = loss_fn(params, static, *batch, key=subkey)
batch_losses.append(loss_i)
losses["val"].append((sum(batch_losses) / len(batch_losses)).item())

loop.set_postfix({k: v[-1] for k, v in losses.items()})
if losses["val"][-1] == min(losses["val"]):
best_params = params

elif count_fruitless(losses["val"]) > max_patience:
loop.set_postfix_str(f"{loop.postfix} (Max patience reached)")
break

params = best_params if return_best else params
dist = eqx.combine(params, static)
return dist, losses
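As a usage sketch for the relocated fit_to_data (not part of the diff; it assumes masked_autoregressive_flow and Normal behave as in the library's documented examples, and the data below is a stand-in):

import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal
from flowjax.flows import masked_autoregressive_flow
from flowjax.train import fit_to_data

data_key, flow_key, train_key = jr.split(jr.key(0), 3)
x = jr.normal(data_key, (1000, 2))  # stand-in for samples from the target

flow = masked_autoregressive_flow(flow_key, base_dist=Normal(jnp.zeros(2)))
flow, losses = fit_to_data(train_key, flow, x, max_epochs=10, batch_size=128)
print(losses["val"][-1])  # losses is a dict with "train" and "val" lists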
7 changes: 7 additions & 0 deletions flowjax/train/variational_fit.py
@@ -1,5 +1,6 @@
"""Basic training script for fitting a flow using variational inference."""

import warnings
from collections.abc import Callable

import equinox as eqx
@@ -42,6 +43,12 @@ def fit_to_variational_target(
Returns:
A tuple containing the trained distribution and the losses.
"""
warnings.warn(
"This function will be deprecated in 17.0.0. Please switch to using "
"``flowjax.train.loops.fit_to_key_based_loss``.",
DeprecationWarning,
stacklevel=2,
) # TODO deprecate
if optimizer is None:
optimizer = optax.adam(learning_rate)
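A sketch of the migration the warning points to, replacing fit_to_variational_target with fit_to_key_based_loss (illustrative; it assumes ElboLoss in flowjax.train.losses accepts a target log-density and a number of samples, and the target below is a placeholder):

import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal
from flowjax.flows import masked_autoregressive_flow
from flowjax.train import fit_to_key_based_loss
from flowjax.train.losses import ElboLoss

def target_log_prob(x):
    # Placeholder unnormalized target density (standard normal up to a constant)
    return -0.5 * jnp.sum(x**2, axis=-1)

key, flow_key = jr.split(jr.key(0))
flow = masked_autoregressive_flow(flow_key, base_dist=Normal(jnp.zeros(2)))
loss = ElboLoss(target=target_log_prob, num_samples=100)

# Previously this would have been routed through fit_to_variational_target.
flow, losses = fit_to_key_based_loss(key, flow, loss_fn=loss, steps=500)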
