Skip to content

Commit

Permalink
Refactor proposal.py (#603)
Browse files Browse the repository at this point in the history
* Refactor proposal.py

* Fix test

* Fix test 2

* Fix test 3
  • Loading branch information
junpenglao authored Dec 4, 2023
1 parent 41f47d5 commit f49945d
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 284 deletions.
4 changes: 2 additions & 2 deletions blackjax/adaptation/chees_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ def one_step(carry, rng_key):
new_states, info = jax.vmap(_step_fn)(keys, states)
new_adaptation_state = update(
adaptation_state,
info.proposal.state.position,
info.proposal.state.momentum,
info.proposal.position,
info.proposal.momentum,
states.position,
info.acceptance_rate,
info.is_divergent,
Expand Down
18 changes: 8 additions & 10 deletions blackjax/mcmc/barker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from jax.tree_util import tree_leaves, tree_map

from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.proposal import static_binomial_sampling
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey

__all__ = ["BarkerState", "BarkerInfo", "init", "build_kernel", "barker_proposal"]
Expand Down Expand Up @@ -99,9 +100,7 @@ def ratio_proposal_nd(y, x, log_y, log_x):
state.logdensity_grad,
)
ratio_proposal = sum(tree_leaves(ratios_proposals))
log_p_accept = proposal.logdensity - state.logdensity + ratio_proposal
p_accept = jnp.exp(log_p_accept)
return jnp.minimum(1.0, p_accept)
return proposal.logdensity - state.logdensity + ratio_proposal

def kernel(
rng_key: PRNGKey, state: BarkerState, logdensity_fn: Callable, step_size: float
Expand All @@ -119,13 +118,12 @@ def kernel(
proposed_pos, proposed_logdensity, proposed_logdensity_grad
)

p_accept = _compute_acceptance_probability(state, proposed_state)

accept = jax.random.uniform(key_rmh) < p_accept

state = jax.lax.cond(accept, lambda: proposed_state, lambda: state)
info = BarkerInfo(p_accept, accept, proposed_state)
return state, info
log_p_accept = _compute_acceptance_probability(state, proposed_state)
accepted_state, info = static_binomial_sampling(
key_rmh, log_p_accept, state, proposed_state
)
do_accept, p_accept, _ = info
return accepted_state, BarkerInfo(p_accept, do_accept, proposed_state)

return kernel

Expand Down
9 changes: 4 additions & 5 deletions blackjax/mcmc/ghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
import blackjax.mcmc.hmc as hmc
import blackjax.mcmc.integrators as integrators
import blackjax.mcmc.metrics as metrics
import blackjax.mcmc.proposal as proposal
from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.proposal import nonreversible_slice_sampling
from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
from blackjax.util import generate_gaussian_noise

Expand Down Expand Up @@ -94,7 +94,6 @@ def build_kernel(
returns a new state of the chain along with information about the transition.
"""
sample_proposal = proposal.nonreversible_slice_sampling

def kernel(
rng_key: PRNGKey,
Expand Down Expand Up @@ -143,7 +142,7 @@ def kernel(
kinetic_energy_fn,
step_size,
divergence_threshold=divergence_threshold,
sample_proposal=sample_proposal,
sample_proposal=nonreversible_slice_sampling,
)

key_momentum, key_noise = jax.random.split(rng_key)
Expand All @@ -158,14 +157,14 @@ def kernel(
)
# Note that ghmc use nonreversible_slice_sampling, which overloads the pattern
# of SampleProposal and do not actually return the acceptance rate.
proposal, info = proposal_generator(slice, integrator_state)
proposal, info, slice_next = proposal_generator(slice, integrator_state)
proposal = hmc.flip_momentum(proposal)
state = GHMCState(
position=proposal.position,
momentum=proposal.momentum,
logdensity=proposal.logdensity,
logdensity_grad=proposal.logdensity_grad,
slice=info.acceptance_rate,
slice=slice_next,
)

return state, info
Expand Down
29 changes: 14 additions & 15 deletions blackjax/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@

import blackjax.mcmc.integrators as integrators
import blackjax.mcmc.metrics as metrics
import blackjax.mcmc.proposal as proposal
import blackjax.mcmc.trajectory as trajectory
from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.proposal import safe_energy_diff, static_binomial_sampling
from blackjax.mcmc.trajectory import hmc_energy
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey

Expand Down Expand Up @@ -166,7 +166,7 @@ def kernel(
integrator_state = integrators.IntegratorState(
position, momentum, logdensity, logdensity_grad
)
proposal, info = proposal_generator(key_integrator, integrator_state)
proposal, info, _ = proposal_generator(key_integrator, integrator_state)
proposal = HMCState(
proposal.position, proposal.logdensity, proposal.logdensity_grad
)
Expand Down Expand Up @@ -404,7 +404,7 @@ def hmc_proposal(
num_integration_steps: int = 1,
divergence_threshold: float = 1000,
*,
sample_proposal: Callable = proposal.static_binomial_sampling,
sample_proposal: Callable = static_binomial_sampling,
) -> Callable:
"""Vanilla HMC algorithm.
Expand Down Expand Up @@ -433,33 +433,32 @@ def hmc_proposal(
"""
build_trajectory = trajectory.static_integration(integrator)
init_proposal, generate_proposal = proposal.proposal_generator(
hmc_energy(kinetic_energy)
)
hmc_energy_fn = hmc_energy(kinetic_energy)

def generate(
rng_key, state: integrators.IntegratorState
) -> tuple[integrators.IntegratorState, HMCInfo]:
) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]:
"""Generate a new chain state."""
end_state = build_trajectory(state, step_size, num_integration_steps)
end_state = flip_momentum(end_state)
proposal = init_proposal(state)
new_proposal = generate_proposal(proposal.energy, end_state)
is_diverging = -new_proposal.weight > divergence_threshold
sampled_proposal, *info = sample_proposal(rng_key, proposal, new_proposal)
do_accept, p_accept = info
proposal_energy = hmc_energy_fn(state)
new_energy = hmc_energy_fn(end_state)
delta_energy = safe_energy_diff(proposal_energy, new_energy)
is_diverging = -delta_energy > divergence_threshold
sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state)
do_accept, p_accept, other_proposal_info = info

info = HMCInfo(
state.momentum,
p_accept,
do_accept,
is_diverging,
new_proposal.energy,
new_proposal,
new_energy,
end_state,
num_integration_steps,
)

return sampled_proposal.state, info
return sampled_state, info, other_proposal_info

return generate

Expand Down
12 changes: 5 additions & 7 deletions blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def transition_energy(state, new_state, step_size):
)
return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot

init_proposal, generate_proposal = proposal.asymmetric_proposal_generator(
compute_acceptance_ratio = proposal.compute_asymmetric_acceptance_ratio(
transition_energy
)
sample_proposal = proposal.static_binomial_sampling
Expand All @@ -106,15 +106,13 @@ def kernel(
new_state = integrator(key_integrator, state, step_size)
new_state = MALAState(*new_state)

proposal = init_proposal(state)
new_proposal = generate_proposal(state, new_state, step_size=step_size)
sampled_proposal, do_accept, p_accept = sample_proposal(
key_rmh, proposal, new_proposal
)
log_p_accept = compute_acceptance_ratio(state, new_state, step_size=step_size)
accepted_state, info = sample_proposal(key_rmh, log_p_accept, state, new_state)
do_accept, p_accept, _ = info

info = MALAInfo(p_accept, do_accept)

return sampled_proposal.state, info
return accepted_state, info

return kernel

Expand Down
14 changes: 8 additions & 6 deletions blackjax/mcmc/marginal_latent_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import jax.scipy.linalg as linalg

from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.proposal import static_binomial_sampling
from blackjax.types import Array, PRNGKey

__all__ = ["MarginalState", "MarginalInfo", "init_and_kernel", "mgrad_gaussian"]
Expand Down Expand Up @@ -121,13 +122,14 @@ def step(key: PRNGKey, state: MarginalState, delta):
hxy = jnp.dot(U_x - temp_y, Gamma_3 * U_grad_y)
hyx = jnp.dot(U_y - temp_x, Gamma_3 * U_grad_x)

alpha = jnp.minimum(1, jnp.exp(log_p_y - logdensity + hxy - hyx))
accept = jax.random.uniform(u_key) < alpha

log_p_accept = log_p_y - logdensity + hxy - hyx
proposed_state = MarginalState(y, log_p_y, grad_y, U_y, U_grad_y)
state = jax.lax.cond(accept, lambda _: proposed_state, lambda _: state, None)
info = MarginalInfo(alpha, accept, proposed_state)
return state, info
accepted_state, info = static_binomial_sampling(
u_key, log_p_accept, state, proposed_state
)
do_accept, p_accept, _ = info
info = MarginalInfo(p_accept, do_accept, proposed_state)
return accepted_state, info

def init(position):
logdensity, logdensity_grad = val_and_grad(position)
Expand Down
Loading

0 comments on commit f49945d

Please sign in to comment.