diff --git a/blackjax/util.py b/blackjax/util.py index 1a7ebcd09..a3a7226a6 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -82,6 +82,28 @@ def generate_gaussian_noise( return unravel_fn(mu + linear_map(sigma, sample)) +def generate_unit_vector( + rng_key: PRNGKey, + position: ArrayLikeTree, +) -> Array: + """Generate a random unit vector with output structure that match a given PyTree. + + Parameters + ---------- + rng_key: + The pseudo-random number generator key used to generate random numbers. + position: + PyTree that the structure the output should to match. + + Returns + ------- + Random unit vector that match the structure of position. + """ + p, unravel_fn = ravel_pytree(position) + sample = normal(rng_key, shape=p.shape, dtype=p.dtype) + return unravel_fn(sample / jnp.linalg.norm(sample)) + + def pytree_size(pytree: ArrayLikeTree) -> int: """Return the dimension of the flatten PyTree.""" return sum(jnp.size(value) for value in tree_leaves(pytree)) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index a41877c13..2f5020d00 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -10,6 +10,7 @@ import blackjax.mcmc.integrators as integrators from blackjax.mcmc.integrators import esh_dynamics_momentum_update_one_step +from blackjax.util import generate_unit_vector def HarmonicOscillator(inv_mass_matrix, k=1.0, m=1.0): @@ -130,6 +131,9 @@ def kinetic_energy(p): "velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4}, "mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5}, "yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6}, + "noneuclidean_leapfrog": {"algorithm": integrators.noneuclidean_leapfrog}, + "noneuclidean_mclachlan": {"algorithm": integrators.noneuclidean_mclachlan}, + "noneuclidean_yoshida": {"algorithm": integrators.noneuclidean_yoshida}, } @@ -224,6 +228,44 @@ def test_esh_momentum_update(self, dims): next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) + @chex.all_variants(with_pmap=False) + @parameterized.parameters( + [ + "noneuclidean_leapfrog", + "noneuclidean_mclachlan", + "noneuclidean_yoshida", + ], + ) + def test_noneuclidean_integrator(self, integrator_name): + integrator = algorithms[integrator_name] + cov = jnp.asarray([[1.0, 0.5], [0.5, 2.0]]) + logdensity_fn = lambda x: stats.multivariate_normal.logpdf( + x, jnp.zeros([2]), cov + ) + + step = self.variant(integrator["algorithm"](logdensity_fn)) + + rng = jax.random.key(4263456) + key0, key1 = jax.random.split(rng, 2) + position_init = jax.random.normal(key0, (2,)) + momentum_init = generate_unit_vector(key1, position_init) + step_size = 0.0001 + initial_state = integrators.new_integrator_state( + logdensity_fn, position_init, momentum_init + ) + + final_state, kinetic_energy_change = jax.lax.scan( + lambda state, _: step(state, step_size), + initial_state, + xs=None, + length=15, + ) + + # Check the conservation of energy. + potential_energy_change = final_state.logdensity - initial_state.logdensity + energy_change = kinetic_energy_change[-1] + potential_energy_change + self.assertAlmostEqual(energy_change, 0, delta=1e-3) + if __name__ == "__main__": absltest.main()