diff --git a/blackjax/__init__.py b/blackjax/__init__.py
index be4ee924b..e3b323738 100644
--- a/blackjax/__init__.py
+++ b/blackjax/__init__.py
@@ -6,6 +6,7 @@
 from .adaptation.window_adaptation import window_adaptation
 from .diagnostics import effective_sample_size as ess
 from .diagnostics import potential_scale_reduction as rhat
+from .mcmc.barker import barker_proposal
 from .mcmc.elliptical_slice import elliptical_slice
 from .mcmc.ghmc import ghmc
 from .mcmc.hmc import dynamic_hmc, hmc
@@ -40,6 +41,7 @@
     "irmh",
     "elliptical_slice",
     "ghmc",
+    "barker_proposal",
     "sgld",  # stochastic gradient mcmc
     "sghmc",
     "sgnht",
diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py
index ced412517..a1e1a42e0 100644
--- a/blackjax/mcmc/__init__.py
+++ b/blackjax/mcmc/__init__.py
@@ -1,4 +1,5 @@
 from . import (
+    barker,
     elliptical_slice,
     ghmc,
     hmc,
@@ -10,6 +11,7 @@
 )
 
 __all__ = [
+    "barker",
     "elliptical_slice",
     "ghmc",
     "hmc",
diff --git a/blackjax/mcmc/barker.py b/blackjax/mcmc/barker.py
new file mode 100644
index 000000000..f2095853d
--- /dev/null
+++ b/blackjax/mcmc/barker.py
@@ -0,0 +1,277 @@
+# Copyright 2020- The Blackjax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Public API for Barker's proposal with a Gaussian base kernel."""
+from typing import Callable, NamedTuple
+
+import jax
+import jax.numpy as jnp
+from jax.flatten_util import ravel_pytree
+from jax.scipy import stats
+from jax.tree_util import tree_leaves, tree_map
+
+from blackjax.base import SamplingAlgorithm
+from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
+
+__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "barker_proposal"]
+
+
+class BarkerState(NamedTuple):
+    """State of the Barker's proposal algorithm.
+
+    The Barker algorithm takes one position of the chain and returns another
+    position. In order to make computations more efficient, we also store
+    the current log-probability density as well as the current gradient of the
+    log-probability density.
+
+    """
+
+    position: ArrayTree
+    logdensity: float
+    logdensity_grad: ArrayTree
+
+
+class BarkerInfo(NamedTuple):
+    """Additional information on the Barker's proposal kernel transition.
+
+    This additional information can be used for debugging or computing
+    diagnostics.
+
+    acceptance_rate
+        The acceptance rate of the transition.
+    is_accepted
+        Whether the proposed position was accepted or the original position
+        was returned.
+    proposal
+        The proposal that was sampled.
+
+    """
+
+    acceptance_rate: float
+    is_accepted: bool
+    proposal: BarkerState
+
+
+def init(position: ArrayLikeTree, logdensity_fn: Callable) -> BarkerState:
+    grad_fn = jax.value_and_grad(logdensity_fn)
+    logdensity, logdensity_grad = grad_fn(position)
+    return BarkerState(position, logdensity, logdensity_grad)
+
+
+def build_kernel():
+    """Build a Barker's proposal kernel.
+
+    Returns
+    -------
+    A kernel that takes a rng_key, the current state of the chain, the
+    log-density function and a step size, and returns a new state of the
+    chain along with information about the transition.
+
+    """
+
+    def _compute_acceptance_probability(
+        state: BarkerState,
+        proposal: BarkerState,
+    ) -> float:
+        """Compute the acceptance probability of the Barker's proposal kernel."""
+
+        def ratio_proposal_nd(y, x, log_y, log_x):
+            num = -_log1pexp(-log_y * (x - y))
+            den = -_log1pexp(-log_x * (y - x))
+
+            return jnp.sum(num - den)
+
+        ratios_proposals = tree_map(
+            ratio_proposal_nd,
+            proposal.position,
+            state.position,
+            proposal.logdensity_grad,
+            state.logdensity_grad,
+        )
+        ratio_proposal = sum(tree_leaves(ratios_proposals))
+        log_p_accept = proposal.logdensity - state.logdensity + ratio_proposal
+        p_accept = jnp.exp(log_p_accept)
+        return jnp.minimum(1.0, p_accept)
+
+    def kernel(
+        rng_key: PRNGKey, state: BarkerState, logdensity_fn: Callable, step_size: float
+    ) -> tuple[BarkerState, BarkerInfo]:
+        """Generate a new sample with the Barker kernel."""
+        grad_fn = jax.value_and_grad(logdensity_fn)
+
+        key_sample, key_rmh = jax.random.split(rng_key)
+
+        proposed_pos = _barker_sample(
+            key_sample, state.position, state.logdensity_grad, step_size
+        )
+        proposed_logdensity, proposed_logdensity_grad = grad_fn(proposed_pos)
+        proposed_state = BarkerState(
+            proposed_pos, proposed_logdensity, proposed_logdensity_grad
+        )
+
+        p_accept = _compute_acceptance_probability(state, proposed_state)
+
+        accept = jax.random.uniform(key_rmh) < p_accept
+
+        state = jax.lax.cond(accept, lambda: proposed_state, lambda: state)
+        info = BarkerInfo(p_accept, accept, proposed_state)
+        return state, info
+
+    return kernel
+
+
+class barker_proposal:
+    """Implements the (basic) user interface for the Barker's proposal kernel with a Gaussian base kernel.
+
+    The general Barker kernel builder (:meth:`blackjax.mcmc.barker.build_kernel`, alias `blackjax.barker.build_kernel`) can be
+    cumbersome to manipulate. Since most users only need to specify the kernel
+    parameters at initialization time, we provide a helper function that
+    specializes the general kernel.
+
+    We also add the general kernel and state generator as an attribute to this class so
+    users only need to pass `blackjax.barker` to SMC, adaptation, etc. algorithms.
+
+    Examples
+    --------
+
+    A new Barker kernel can be initialized and used with the following code:
+
+    .. code::
+
+        barker = blackjax.barker(logdensity_fn, step_size)
+        state = barker.init(position)
+        new_state, info = barker.step(rng_key, state)
+
+    Kernels are not jit-compiled by default so you will need to do it manually:
+
+    .. code::
+
+        step = jax.jit(barker.step)
+        new_state, info = step(rng_key, state)
+
+    Should you need to, you can always use the base kernel directly:
+
+    .. code::
+
+        kernel = blackjax.barker.build_kernel()
+        state = blackjax.barker.init(position, logdensity_fn)
+        state, info = kernel(rng_key, state, logdensity_fn, step_size)
+
+    Parameters
+    ----------
+    logdensity_fn
+        The log-density function we wish to draw samples from.
+    step_size
+        The standard deviation of the Gaussian base proposal; it plays the
+        role of the step size of the algorithm.
+
+    Returns
+    -------
+    A ``SamplingAlgorithm``.
+
+    """
+
+    init = staticmethod(init)
+    build_kernel = staticmethod(build_kernel)
+
+    def __new__(  # type: ignore[misc]
+        cls,
+        logdensity_fn: Callable,
+        step_size: float,
+    ) -> SamplingAlgorithm:
+        kernel = cls.build_kernel()
+
+        def init_fn(position: ArrayLikeTree):
+            return cls.init(position, logdensity_fn)
+
+        def step_fn(rng_key: PRNGKey, state):
+            return kernel(rng_key, state, logdensity_fn, step_size)
+
+        return SamplingAlgorithm(init_fn, step_fn)
+
+
+def _barker_sample_nd(key, mean, a, scale):
+    r"""
+    Sample from a multivariate Barker's proposal distribution. In 1D, this has the following probability density function:
+
+    .. math::
+        p(x; \mu, a, \sigma) = 2 \frac{N(x; \mu, \sigma^2)}{1 + \exp(-a (x - \mu))}
+
+    where :math:`N(x; \mu, \sigma^2)` is the normal distribution with mean :math:`\mu` and standard deviation :math:`\sigma`.
+    The multivariate Barker's proposal distribution is the product of one-dimensional Barker's proposal distributions.
+
+    Parameters
+    ----------
+    key
+        A PRNG key.
+    mean
+        The mean of the normal distribution, an Array. This corresponds to :math:`\mu` in the equation above.
+    a
+        The parameter :math:`a` in the equation above, an Array. This is a skewness parameter.
+    scale
+        The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\sigma` in the equation above.
+        It encodes the step size of the proposal.
+
+    Returns
+    -------
+    A sample from the Barker's multidimensional proposal distribution.
+
+    """
+
+    key1, key2 = jax.random.split(key)
+    z = scale * jax.random.normal(key1, shape=mean.shape)
+
+    # Sample b=1 with probability p and b=0 with probability 1 - p where
+    # p = 1 / (1 + exp(-a * z))
+    log_p = -_log1pexp(-a * z)
+    b = jax.random.bernoulli(key2, p=jnp.exp(log_p), shape=mean.shape)
+
+    # return mean + z if b == 1 else mean - z
+    return mean + b * z - (1 - b) * z
+
+
+def _barker_sample(key, mean, a, scale):
+    r"""
+    Sample from a multivariate Barker's proposal distribution for PyTrees.
+
+    Parameters
+    ----------
+    key
+        A PRNG key.
+    mean
+        The mean of the normal distribution, a PyTree. This corresponds to :math:`\mu` in the 1D density given in `_barker_sample_nd`.
+    a
+        The parameter :math:`a` in that density, a PyTree with the same structure as `mean`. This is a skewness parameter.
+    scale
+        The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\sigma` in that density.
+        It encodes the step size of the proposal.
+
+    """
+
+    flat_mean, unravel_fn = ravel_pytree(mean)
+    flat_a, _ = ravel_pytree(a)
+    flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale)
+    return unravel_fn(flat_sample)
+
+
+def _log1pexp(a):
+    return jnp.logaddexp(0.0, a)  # log(1 + exp(a)) without overflow for large a
+
+
+def _barker_logpdf(x, mean, a, scale):
+    logpdf = jnp.log(2) + stats.norm.logpdf(x, mean, scale) - _log1pexp(-a * (x - mean))
+    return logpdf
+
+
+def _barker_pdf(x, mean, a, scale):
+    return jnp.exp(_barker_logpdf(x, mean, a, scale))
diff --git a/tests/mcmc/test_barker.py b/tests/mcmc/test_barker.py
new file mode 100644
index 000000000..5c227c4cb
--- /dev/null
+++ b/tests/mcmc/test_barker.py
@@ -0,0 +1,55 @@
+import chex
+import jax
+import jax.numpy as jnp
+from absl.testing import absltest, parameterized
+
+from blackjax.mcmc.barker import _barker_pdf, _barker_sample_nd
+
+
+class BarkerSamplingTest(chex.TestCase):
+    @parameterized.parameters([1234, 5678])
+    def test_nd(self, seed):
+        n_samples = 100_000
+
+        key = jax.random.key(seed)
+        m, a, scale = (
+            jnp.array([1.0, 0.5, 0.0, 0.0]),
+            jnp.array([1.0, -2.0, 10.0, 0.0]),
+            0.5,
+        )
+
+        keys = jax.random.split(key, n_samples)
+        samples = jax.vmap(lambda k: _barker_sample_nd(k, m, a, scale))(keys)
+        # Check that the empirical moments and the moments computed as sum(x^k * p(x) dx) are close
+        _test_samples_vs_pdf(samples, lambda x: _barker_pdf(x, m, a, scale))
+
+
+def _test_samples_vs_pdf(samples, pdf):
+    samples_mean = jnp.mean(samples, 0)
+    samples_squared_mean = jnp.mean(samples**2, 0)
+    linspace = jnp.linspace(-10, 10, 50_000)
+
+    diff = jnp.diff(linspace, axis=0)
+
+    # trapezoidal rule
+    pdf_mean = 0.5 * jnp.sum(
+        linspace[1:, None] * pdf(linspace[1:, None]) * diff[:, None], 0
+    )
+    pdf_mean += 0.5 * jnp.sum(
+        linspace[:-1, None] * pdf(linspace[:-1, None]) * diff[:, None], 0
+    )
+    pdf_squared_mean = 0.5 * jnp.sum(
+        linspace[1:, None] ** 2 * pdf(linspace[1:, None]) * diff[:, None], 0
+    )
+    pdf_squared_mean += 0.5 * jnp.sum(
+        linspace[:-1, None] ** 2 * pdf(linspace[:-1, None]) * diff[:, None], 0
+    )
+
+    chex.assert_trees_all_close(samples_mean, pdf_mean, atol=1e-2, rtol=1e-2)
+    chex.assert_trees_all_close(
+        samples_squared_mean, pdf_squared_mean, atol=1e-2, rtol=1e-2
+    )
+
+
+if __name__ == "__main__":
+    absltest.main()
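For intuition, the skew correction that `_compute_acceptance_probability` accumulates in `ratio_proposal_nd` is exactly the log-ratio of the full Barker proposal densities: the Gaussian factor and the constant 2 are symmetric in the current and proposed positions, so they cancel. A minimal sketch checking this identity against `_barker_logpdf` (all values below are made up for illustration, not part of the patch):

    import jax.numpy as jnp
    from blackjax.mcmc.barker import _barker_logpdf

    x = jnp.array([0.1, -0.3, 2.0])       # current position
    y = jnp.array([0.4, -0.1, 1.5])       # proposed position
    grad_x = jnp.array([1.0, 0.5, -2.0])  # grad of log-density at x
    grad_y = jnp.array([0.7, 0.4, -1.4])  # grad of log-density at y
    step_size = 0.5

    # log q(x | y) - log q(y | x): the N(.; ., step_size^2) and log(2) terms
    # cancel, leaving sum(-log1pexp(-grad_y * (x - y)) + log1pexp(-grad_x * (y - x))),
    # i.e. ratio_proposal_nd(y, x, grad_y, grad_x) from the kernel above.
    log_ratio = jnp.sum(
        _barker_logpdf(x, y, grad_y, step_size)
        - _barker_logpdf(y, x, grad_x, step_size)
    )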
diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index 7770b55a1..ebffcffc7 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -266,6 +266,27 @@ def test_chees(self, jitter_generator):
         np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
         np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
 
+    def test_barker(self):
+        """Test the Barker kernel."""
+        init_key0, init_key1, inference_key = jax.random.split(self.key, 3)
+        x_data = jax.random.normal(init_key0, shape=(1000, 1))
+        y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape)
+
+        logposterior_fn_ = functools.partial(
+            self.regression_logprob, x=x_data, preds=y_data
+        )
+        logposterior_fn = lambda x: logposterior_fn_(**x)
+
+        barker = blackjax.barker_proposal(logposterior_fn, 1e-1)
+        state = barker.init({"coefs": 1.0, "log_scale": 1.0})
+        states = inference_loop(barker.step, 10_000, inference_key, state)
+
+        coefs_samples = states.position["coefs"][3000:]
+        scale_samples = np.exp(states.position["log_scale"][3000:])
+
+        np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2)
+        np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2)
+
 
 class SGMCMCTest(chex.TestCase):
     """Test sampling of a linear regression model."""
@@ -526,6 +547,13 @@ def test_latent_gaussian(self):
             "num_sampling_steps": 6000,
             "burnin": 1_000,
         },
+        {
+            "algorithm": blackjax.barker_proposal,
+            "initial_position": 1.0,
+            "parameters": {"step_size": 1.5},
+            "num_sampling_steps": 20_000,
+            "burnin": 2_000,
+        },
     ]
@@ -575,7 +603,6 @@ def test_univariate_normal(
             )
         else:
             samples = states.position[burnin:]
-
         np.testing.assert_allclose(np.mean(samples), 1.0, rtol=1e-1)
         np.testing.assert_allclose(np.var(samples), 4.0, rtol=1e-1)
@@ -607,6 +634,11 @@ def test_univariate_normal(
             "parameters": {"step_size": 0.85},
             "is_mass_matrix_diagonal": False,
         },
+        {
+            "algorithm": blackjax.barker_proposal,
+            "parameters": {"step_size": 0.5},
+            "is_mass_matrix_diagonal": None,
+        },
     ]
@@ -660,15 +692,18 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal):
             true_rho,
             true_cov,
         ) = self.generate_multivariate_target(None)
-        if is_mass_matrix_diagonal:
-            inverse_mass_matrix = true_scale**2
+        if is_mass_matrix_diagonal is not None:
+            if is_mass_matrix_diagonal:
+                inverse_mass_matrix = true_scale**2
+            else:
+                inverse_mass_matrix = true_cov
+            kernel = algorithm(
+                logdensity_fn,
+                inverse_mass_matrix=inverse_mass_matrix,
+                **parameters,
+            )
         else:
-            inverse_mass_matrix = true_cov
-        kernel = algorithm(
-            logdensity_fn,
-            inverse_mass_matrix=inverse_mass_matrix,
-            **parameters,
-        )
+            kernel = algorithm(logdensity_fn, **parameters)
 
         num_chains = 10
         initial_positions = jax.random.normal(pos_init_key, [num_chains, 2])
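For completeness, a minimal end-to-end sketch of the new user-facing API added by this diff, following the `barker_proposal` class docstring; the standard-normal target, seed and step size are illustrative choices, not part of the patch:

    import jax
    import jax.numpy as jnp
    import blackjax

    # Log-density of a standard 2D normal target (illustrative).
    logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)

    barker = blackjax.barker_proposal(logdensity_fn, step_size=0.5)
    state = barker.init(jnp.zeros(2))

    step = jax.jit(barker.step)  # kernels are not jit-compiled by default
    rng_key = jax.random.key(0)
    for _ in range(1_000):
        rng_key, subkey = jax.random.split(rng_key)
        state, info = step(subkey, state)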