Skip to content

Commit

Permalink
Still cosmetic fixed as suggested by Junpeng
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrienCorenflos committed Dec 1, 2023
1 parent 14bd402 commit 6108637
Showing 1 changed file with 13 additions and 33 deletions.
46 changes: 13 additions & 33 deletions blackjax/mcmc/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jax
import jax.numpy as jnp
from jax.scipy import stats
from jax.tree_util import tree_flatten
from jax.tree_util import tree_flatten, tree_leaves, tree_map, tree_unflatten

from blackjax.base import SamplingAlgorithm
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
Expand Down Expand Up @@ -79,36 +79,26 @@ def build_kernel():
"""

def _compute_acceptance_probability(
logdensity: float,
logdensity_proposal: float,
logdensity_grad: ArrayTree,
logdensity_grad_proposal: ArrayTree,
position: ArrayTree,
position_proposal: ArrayTree,
scale: float,
state: BarkerState,
proposal: BarkerState,
) -> float:
"""Compute the acceptance probability of the Barker's proposal kernel."""

logdensity_grad, _ = tree_flatten(logdensity_grad)
logdensity_grad_proposal, _ = tree_flatten(logdensity_grad_proposal)
position, _ = tree_flatten(position)
position_proposal, _ = tree_flatten(position_proposal)

def ratio_proposal_nd(y, x, log_y, log_x):
num = -_log1pexp(-log_y * (x - y))
den = -_log1pexp(-log_x * (y - x))

return jnp.sum(num - den)

ratios_proposals = map(
ratios_proposals = tree_map(
ratio_proposal_nd,
position_proposal,
position,
logdensity_grad_proposal,
logdensity_grad,
proposal.position,
state.position,
proposal.logdensity_grad,
state.logdensity_grad,
)
ratio_proposal = sum(ratios_proposals)
log_p_accept = logdensity_proposal - logdensity + ratio_proposal
ratio_proposal = sum(tree_leaves(ratios_proposals))
log_p_accept = proposal.logdensity - state.logdensity + ratio_proposal
p_accept = jnp.exp(log_p_accept)
return jnp.minimum(1.0, p_accept)

Expand All @@ -124,20 +114,12 @@ def kernel(
key_sample, state.position, state.logdensity_grad, step_size
)
proposed_logdensity, proposed_logdensity_grad = grad_fn(proposed_pos)

p_accept = _compute_acceptance_probability(
state.logdensity,
proposed_logdensity,
state.logdensity_grad,
proposed_logdensity_grad,
state.position,
proposed_pos,
step_size,
)

proposed_state = BarkerState(
proposed_pos, proposed_logdensity, proposed_logdensity_grad
)

p_accept = _compute_acceptance_probability(state, proposed_state)

accept = jax.random.uniform(key_rmh) < p_accept

state = jax.lax.cond(accept, lambda: proposed_state, lambda: state)
Expand Down Expand Up @@ -275,8 +257,6 @@ def _barker_sample(key, mean, a, scale):
"""

from jax.tree_util import tree_flatten, tree_map, tree_unflatten

flat_mean, tree_def = tree_flatten(mean)
flat_a, _ = tree_flatten(a)
n_keys = len(flat_mean)
Expand Down

0 comments on commit 6108637

Please sign in to comment.