Skip to content

Commit

Permalink
Implements the Schrödinger-Föllmer algorithm. (#602)
Browse files Browse the repository at this point in the history
* Implements the Schrödinger-Föllmer algorithm.

* Delete unused code

* Fix untested code

* Fix untested code

* make sure code is covered. There's no logic tested.

* Update blackjax/vi/schrodinger_follmer.py

Co-authored-by: Junpeng Lao <[email protected]>

* Update blackjax/vi/schrodinger_follmer.py

Co-authored-by: Junpeng Lao <[email protected]>

* Update blackjax/vi/schrodinger_follmer.py

Co-authored-by: Junpeng Lao <[email protected]>

* Update blackjax/vi/schrodinger_follmer.py

Co-authored-by: Junpeng Lao <[email protected]>

* Cosmetic changes

* add chex variants

* add chex variants

* add chex variants

* adding dtype to be sure.

---------

Co-authored-by: Junpeng Lao <[email protected]>
  • Loading branch information
AdrienCorenflos and junpenglao authored Dec 4, 2023
1 parent 3845635 commit 41f47d5
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 2 deletions.
2 changes: 2 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .smc.tempered import tempered_smc
from .vi.meanfield_vi import meanfield_vi
from .vi.pathfinder import pathfinder
from .vi.schrodinger_follmer import schrodinger_follmer
from .vi.svgd import svgd

__all__ = [
Expand Down Expand Up @@ -54,6 +55,7 @@
"tempered_smc",
"meanfield_vi", # variational inference
"pathfinder",
"schrodinger_follmer",
"svgd",
"ess", # diagnostics
"rhat",
Expand Down
4 changes: 2 additions & 2 deletions blackjax/vi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import meanfield_vi, pathfinder, svgd
from . import meanfield_vi, pathfinder, schrodinger_follmer, svgd

__all__ = ["pathfinder", "meanfield_vi", "svgd"]
__all__ = ["pathfinder", "meanfield_vi", "svgd", "schrodinger_follmer"]
223 changes: 223 additions & 0 deletions blackjax/vi/schrodinger_follmer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp
import jax.random
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_leaves
from jax.typing import ArrayLike

from blackjax.base import VIAlgorithm
from blackjax.types import ArrayLikeTree, PRNGKey

__all__ = ["SchrodingerFollmerState", "sample", "init", "step"]


class SchrodingerFollmerState(NamedTuple):
"""State of the Schrödinger-Föllmer algorithm.
The Schrödinger-Föllmer algorithm gets samples from the target distribution by
approximating the target distribution as the terminal value of a stochastic differential
equation (SDE) with a drift term that is evaluated under the running samples.
position:
position of the sample
time:
Current integration time of the SDE
"""

position: ArrayLikeTree
time: ArrayLike


class SchrodingerFollmerInfo(NamedTuple):
"""Extra information returned by the Schrodinger Follmer algorithm.
drift:
Approximation of the drift term of the SDE
"""

drift: ArrayLikeTree


def init(example_position: ArrayLikeTree) -> SchrodingerFollmerState:
zero = jax.tree_map(jnp.zeros_like, example_position)
return SchrodingerFollmerState(zero, 0.0)


def step(
rng_key: PRNGKey,
state: SchrodingerFollmerState,
logdensity_fn: Callable,
step_size: float,
n_samples: int,
) -> Tuple[SchrodingerFollmerState, SchrodingerFollmerInfo]:
"""
Runs one step of the Schrödinger-Föllmer algorithm. As per the paper, we only allow for Euler-Maruyama integration.
It is likely possible to generalize this to other integration schemes but is not considered in the original work
and we therefore do not consider it here.
Note that we use the version with Stein's lemma as computing the gradient of the *density* is typically unstable.
Parameters
----------
rng_key
PRNG key
state
Current state of the algorithm
logdensity_fn
Log-density of the target distribution
step_size
Step size of the integration scheme
n_samples
Number of samples to use to approximate the drift term
"""

drift_key, sde_key = jax.random.split(rng_key)

ravelled_position, unravel_fn = ravel_pytree(state.position)
scale = jnp.sqrt(1 - state.time)

eps_drift = jax.random.normal(drift_key, (n_samples,) + ravelled_position.shape)
eps_drift = jax.vmap(unravel_fn)(eps_drift)

perturbed_position = jax.tree_map(
lambda a, b: a[None, ...] + scale * b, state.position, eps_drift
)

log_pdf = jax.vmap(_log_fn_corrected, in_axes=[0, None])(
perturbed_position, logdensity_fn
)
log_pdf -= jnp.max(log_pdf, axis=0, keepdims=True)
pdf = jnp.exp(log_pdf)

num = jax.tree_map(lambda a: pdf @ a, eps_drift)
den = scale * jnp.sum(pdf, axis=0)

drift = jax.tree_map(lambda a: a / den, num)

eps_sde = jax.random.normal(sde_key, ravelled_position.shape)
eps_sde = unravel_fn(eps_sde)
next_position = jax.tree_map(
lambda a, b, c: a + step_size * b + step_size**0.5 * c,
state.position,
drift,
eps_sde,
)
next_state = SchrodingerFollmerState(next_position, state.time + step_size)
return next_state, SchrodingerFollmerInfo(drift)


def sample(
rng_key: PRNGKey,
initial_state: SchrodingerFollmerState,
log_density_fn: Callable,
n_steps: int,
n_inner_samples,
n_samples: int = 1,
):
"""
Samples from the target distribution using the Schrödinger-Föllmer algorithm.
Parameters
----------
rng_key
PRNG key
initial_state
Current state of the algorithm
log_density_fn
Log-density of the target distribution
n_steps
Number of steps to run the algorithm for
n_inner_samples
Number of samples to use to approximate the drift term
n_samples
Number of samples to draw
"""
dt = 1.0 / n_steps

initial_position = initial_state.position
initial_positions = jax.tree_map(
lambda a: jnp.zeros([n_samples, *a.shape], dtype=a.dtype), initial_position
)
initial_states = SchrodingerFollmerState(initial_positions, jnp.zeros((n_samples,)))

def body(_, carry):
key, states = carry
keys = jax.random.split(key, 1 + n_samples)
states, _ = jax.vmap(step, [0, 0, None, None, None])(
keys[1:], states, log_density_fn, dt, n_inner_samples
)
return keys[0], states

_, final_states = jax.lax.fori_loop(0, n_steps, body, (rng_key, initial_states))

return final_states


def _log_fn_corrected(position, logdensity_fn):
"""
The Schrödinger-Föllmer algorithm requires the log-density to be given with respect to a standard Gaussian base measure
but the log-density function passed to the algorithm in BlackJAX is typically given with respect to the Borel measure.
This corrects the gradient of the log-density function to account for this.
"""
log_pdf_val = logdensity_fn(position)
norm = jax.tree_map(lambda a: 0.5 * jnp.sum(a**2), position)
norm = sum(tree_leaves(norm))
return log_pdf_val + norm


class schrodinger_follmer:
"""Implements the (basic) user interface for the Schrödinger-Föllmer algortithm :cite:p:`huang2021schrodingerfollmer`.
The Schrödinger-Föllmer algorithm obtains (approximate) samples from the target distribution by means of a diffusion with
approximated drifts.
Parameters
----------
logdensity_fn
A function that represents the log-density of the model we want
to sample from.
n_steps
Number of steps used in the SDE
n_inner_samples
Number of samples used to approximate the drift term
Returns
-------
A ``VIAlgorithm``.
"""

init = staticmethod(init)
step = staticmethod(step)
sample = staticmethod(sample)

def __new__(cls, logdensity_fn: Callable, n_steps: int, n_inner_samples: int) -> VIAlgorithm: # type: ignore[misc]
def init_fn(position: ArrayLikeTree):
return cls.init(position)

def step_fn(
rng_key: PRNGKey, state: SchrodingerFollmerState
) -> tuple[SchrodingerFollmerState, SchrodingerFollmerInfo]:
return cls.step(rng_key, state, logdensity_fn, 1 / n_steps, n_inner_samples)

def sample_fn(rng_key: PRNGKey, state: SchrodingerFollmerState, n_samples: int):
return cls.sample(
rng_key, state, logdensity_fn, n_steps, n_inner_samples, n_samples
)

return VIAlgorithm(init_fn, step_fn, sample_fn)
10 changes: 10 additions & 0 deletions docs/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,13 @@ @article{Livingstone2022Barker
url = {https://doi.org/10.1111/rssb.12482},
eprint = {https://academic.oup.com/jrsssb/article-pdf/84/2/496/49322274/jrsssb\_84\_2\_496.pdf},
}


@misc{huang2021schrodingerfollmer,
title={Schr{\"o}dinger-F{\"o}llmer Sampler: Sampling without Ergodicity},
author={Jian Huang and Yuling Jiao and Lican Kang and Xu Liao and Jin Liu and Yanyan Liu},
year={2021},
eprint={2106.10880},
archivePrefix={arXiv},
primaryClass={stat.CO}
}
87 changes: 87 additions & 0 deletions tests/vi/test_schrodinger_follmer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import functools

import chex
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
from absl.testing import absltest

from blackjax.vi.schrodinger_follmer import schrodinger_follmer


class SchrodingerFollmerTest(chex.TestCase):
def setUp(self):
super().setUp()
self.key = jax.random.key(1)

@chex.all_variants(with_pmap=True)
def test_recover_posterior(self):
"""Simple Normal mean test"""

ndim = 2

rng_key_chol, rng_key_observed, rng_key_init = jax.random.split(self.key, 3)
L = jnp.tril(jax.random.normal(rng_key_chol, (ndim, ndim)))
true_mu = jnp.arange(ndim)
true_cov = L @ L.T
true_prec = jnp.linalg.pinv(true_cov)

def logp_posterior_conjugate_normal_model(
observed, prior_mu, prior_prec, true_prec
):
n = observed.shape[0]
posterior_cov = jnp.linalg.inv(prior_prec + n * true_prec)
posterior_mu = (
posterior_cov
@ (
prior_prec @ prior_mu[:, None]
+ n * true_prec @ observed.mean(0)[:, None]
)
)[:, 0]
return posterior_mu

def logp_unnormalized_posterior(x, observed, prior_mu, prior_prec, true_cov):
logp = 0.0
logp += stats.multivariate_normal.logpdf(x, prior_mu, prior_prec)
logp += stats.multivariate_normal.logpdf(observed, x, true_cov).sum()
return logp

prior_mu = jnp.zeros(ndim)
prior_prec = jnp.eye(ndim)

# Simulate the data
observed = jax.random.multivariate_normal(
rng_key_observed, true_mu, true_cov, shape=(10_000,)
)

logp_model = functools.partial(
logp_unnormalized_posterior,
observed=observed,
prior_mu=prior_mu,
prior_prec=prior_prec,
true_cov=true_cov,
)

initial_position = jnp.zeros((ndim,))
posterior_mu = logp_posterior_conjugate_normal_model(
observed, prior_mu, prior_prec, true_prec
)

schrodinger_follmer_algo = schrodinger_follmer(logp_model, 50, 25)

initial_state = schrodinger_follmer_algo.init(initial_position)
schrodinger_follmer_algo_sample = self.variant(
lambda k, s: schrodinger_follmer_algo.sample(k, s, 100)
)
sampled_states = schrodinger_follmer_algo_sample(rng_key_init, initial_state)
sampled_position = sampled_states.position
chex.assert_trees_all_close(
sampled_position.mean(0), posterior_mu, rtol=1e-2, atol=1e-1
)

# make sure basic interface is independently covered
_ = schrodinger_follmer_algo.step(rng_key_init, initial_state)


if __name__ == "__main__":
absltest.main()

0 comments on commit 41f47d5

Please sign in to comment.