Skip to content

Commit

Permalink
Merge branch 'main' into refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao authored Nov 27, 2023
2 parents 6ea5320 + e0f107f commit a66af60
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
22 changes: 22 additions & 0 deletions blackjax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,28 @@ def generate_gaussian_noise(
return unravel_fn(mu + linear_map(sigma, sample))


def generate_unit_vector(
rng_key: PRNGKey,
position: ArrayLikeTree,
) -> Array:
"""Generate a random unit vector with output structure that match a given PyTree.
Parameters
----------
rng_key:
The pseudo-random number generator key used to generate random numbers.
position:
PyTree that the structure the output should to match.
Returns
-------
Random unit vector that match the structure of position.
"""
p, unravel_fn = ravel_pytree(position)
sample = normal(rng_key, shape=p.shape, dtype=p.dtype)
return unravel_fn(sample / jnp.linalg.norm(sample))


def pytree_size(pytree: ArrayLikeTree) -> int:
"""Return the dimension of the flatten PyTree."""
return sum(jnp.size(value) for value in tree_leaves(pytree))
Expand Down
42 changes: 42 additions & 0 deletions tests/mcmc/test_integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import blackjax.mcmc.integrators as integrators
from blackjax.mcmc.integrators import esh_dynamics_momentum_update_one_step
from blackjax.util import generate_unit_vector


def HarmonicOscillator(inv_mass_matrix, k=1.0, m=1.0):
Expand Down Expand Up @@ -130,6 +131,9 @@ def kinetic_energy(p):
"velocity_verlet": {"algorithm": integrators.velocity_verlet, "precision": 1e-4},
"mclachlan": {"algorithm": integrators.mclachlan, "precision": 1e-5},
"yoshida": {"algorithm": integrators.yoshida, "precision": 1e-6},
"noneuclidean_leapfrog": {"algorithm": integrators.noneuclidean_leapfrog},
"noneuclidean_mclachlan": {"algorithm": integrators.noneuclidean_mclachlan},
"noneuclidean_yoshida": {"algorithm": integrators.noneuclidean_yoshida},
}


Expand Down Expand Up @@ -224,6 +228,44 @@ def test_esh_momentum_update(self, dims):
next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0)
np.testing.assert_array_almost_equal(next_momentum, next_momentum1)

@chex.all_variants(with_pmap=False)
@parameterized.parameters(
[
"noneuclidean_leapfrog",
"noneuclidean_mclachlan",
"noneuclidean_yoshida",
],
)
def test_noneuclidean_integrator(self, integrator_name):
integrator = algorithms[integrator_name]
cov = jnp.asarray([[1.0, 0.5], [0.5, 2.0]])
logdensity_fn = lambda x: stats.multivariate_normal.logpdf(
x, jnp.zeros([2]), cov
)

step = self.variant(integrator["algorithm"](logdensity_fn))

rng = jax.random.key(4263456)
key0, key1 = jax.random.split(rng, 2)
position_init = jax.random.normal(key0, (2,))
momentum_init = generate_unit_vector(key1, position_init)
step_size = 0.0001
initial_state = integrators.new_integrator_state(
logdensity_fn, position_init, momentum_init
)

final_state, kinetic_energy_change = jax.lax.scan(
lambda state, _: step(state, step_size),
initial_state,
xs=None,
length=15,
)

# Check the conservation of energy.
potential_energy_change = final_state.logdensity - initial_state.logdensity
energy_change = kinetic_energy_change[-1] + potential_energy_change
self.assertAlmostEqual(energy_change, 0, delta=1e-3)


if __name__ == "__main__":
absltest.main()

0 comments on commit a66af60

Please sign in to comment.