diff --git a/blackjax/mcmc/barker.py b/blackjax/mcmc/barker.py index c3534139b..8a7e50f50 100644 --- a/blackjax/mcmc/barker.py +++ b/blackjax/mcmc/barker.py @@ -257,20 +257,10 @@ def _barker_sample(key, mean, a, scale): """ - flat_mean, tree_def = tree_flatten(mean) - flat_a, _ = tree_flatten(a) - n_keys = len(flat_mean) - - keys = jax.random.split(key, n_keys) - keys = [k for k in keys] - sample = tree_map( - lambda k, m, a: _barker_sample_nd(k, m, a, scale), keys, flat_mean, flat_a - ) - # check that the pytrees have the same structure - - sample = tree_unflatten(tree_def, sample) - - return sample + flat_mean, unravel_fn = ravel_pytree(mean) + flat_a, _ = ravel_pytree(a) + flat_sample = _barker_sample_nd(keys, flat_mean, flat_a, scale) + return unravel_fn(flat_sample) def _log1pexp(a):