Skip to content

Commit

Permalink
Update blackjax/mcmc/barker.py
Browse files Browse the repository at this point in the history
Co-authored-by: Junpeng Lao <[email protected]>
  • Loading branch information
AdrienCorenflos and junpenglao authored Dec 1, 2023
1 parent 6108637 commit 60b7f9e
Showing 1 changed file with 4 additions and 14 deletions.
18 changes: 4 additions & 14 deletions blackjax/mcmc/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,20 +257,10 @@ def _barker_sample(key, mean, a, scale):
"""

flat_mean, tree_def = tree_flatten(mean)
flat_a, _ = tree_flatten(a)
n_keys = len(flat_mean)

keys = jax.random.split(key, n_keys)
keys = [k for k in keys]
sample = tree_map(
lambda k, m, a: _barker_sample_nd(k, m, a, scale), keys, flat_mean, flat_a
)
# check that the pytrees have the same structure

sample = tree_unflatten(tree_def, sample)

return sample
flat_mean, unravel_fn = ravel_pytree(mean)
flat_a, _ = ravel_pytree(a)
flat_sample = _barker_sample_nd(keys, flat_mean, flat_a, scale)
return unravel_fn(flat_sample)


def _log1pexp(a):
Expand Down

0 comments on commit 60b7f9e

Please sign in to comment.