Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao authored Dec 1, 2023
2 parents 3a27622 + 9713452 commit d5bd76c
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 23 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# BlackJAX
![CI](https://github.com/blackjax-devs/blackjax/workflows/Run%20tests/badge.svg?branch=main)
[![codecov](https://codecov.io/gh/blackjax-devs/blackjax/branch/main/graph/badge.svg)](https://codecov.io/gh/blackjax-devs/blackjax)
![Continuous integration](https://github.com/blackjax-devs/blackjax/actions/workflows/test.yml/badge.svg)
![codecov](https://codecov.io/gh/blackjax-devs/blackjax/branch/main/graph/badge.svg)
![PyPI version](https://img.shields.io/pypi/v/blackjax)


## What is BlackJAX?
Expand Down
20 changes: 9 additions & 11 deletions blackjax/adaptation/chees_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,9 @@ def compute_parameters(
trajectory_gradients = (
jitter_generator(random_generator_arg)
* trajectory_length
* (
jax.vmap(lambda p: jnp.dot(p, p))(proposals_matrix)
- jax.vmap(lambda p: jnp.dot(p, p))(initials_matrix)
)
* jax.vmap(lambda p, m: jnp.dot(p, m))(proposals_matrix, momentums_matrix)
* jax.vmap(
lambda pm, im, mm: (jnp.dot(pm, pm) - jnp.dot(im, im)) * jnp.dot(pm, mm)
)(proposals_matrix, initials_matrix, momentums_matrix)
)
trajectory_gradient = jnp.sum(
acceptance_probabilities * trajectory_gradients, where=~is_divergent
Expand Down Expand Up @@ -273,7 +271,7 @@ def update(


def chees_adaptation(
logprob_fn: Callable,
logdensity_fn: Callable,
num_chains: int,
*,
jitter_generator: Optional[Callable] = None,
Expand Down Expand Up @@ -308,7 +306,7 @@ def chees_adaptation(
.. code::
warmup = blackjax.chees_adaptation(logprob_fn, num_chains)
warmup = blackjax.chees_adaptation(logdensity_fn, num_chains)
key_warmup, key_sample = jax.random.split(rng_key)
optim = optax.adam(learning_rate)
(last_states, parameters), _ = warmup.run(
Expand All @@ -318,12 +316,12 @@ def chees_adaptation(
optim,
num_warmup_steps,
)
kernel = blackjax.dynamic_hmc(logprob_fn, **parameters).step
kernel = blackjax.dynamic_hmc(logdensity_fn, **parameters).step
new_states, info = jax.vmap(kernel)(key_sample, last_states)
Parameters
----------
logprob_fn
logdensity_fn
The log density probability density function from which we wish to sample.
num_chains
Number of chains used for cross-chain warm-up training.
Expand Down Expand Up @@ -399,7 +397,7 @@ def one_step(carry, rng_key):
keys = jax.random.split(rng_key, num_chains)
_step_fn = partial(
step_fn,
logdensity_fn=logprob_fn,
logdensity_fn=logdensity_fn,
step_size=adaptation_state.step_size,
inverse_mass_matrix=jnp.ones(num_dim),
trajectory_length_adjusted=adaptation_state.trajectory_length
Expand All @@ -422,7 +420,7 @@ def one_step(carry, rng_key):
)

batch_init = jax.vmap(
lambda p: hmc.init_dynamic(p, logprob_fn, init_random_arg)
lambda p: hmc.init_dynamic(p, logdensity_fn, init_random_arg)
)
init_states = batch_init(positions)
init_adaptation_state = init(init_random_arg, step_size)
Expand Down
12 changes: 7 additions & 5 deletions blackjax/mcmc/ghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,16 @@ def kernel(
integrator_state = integrators.IntegratorState(
position, momentum, logdensity, logdensity_grad
)
# Note that ghmc use nonreversible_slice_sampling, which overloads the pattern
# of SampleProposal and do not actually return the acceptance rate.
proposal, info = proposal_generator(slice, integrator_state)
proposal = hmc.flip_momentum(proposal)
state = GHMCState(
proposal.position,
proposal.momentum,
proposal.logdensity,
proposal.logdensity_grad,
info.acceptance_rate,
position=proposal.position,
momentum=proposal.momentum,
logdensity=proposal.logdensity,
logdensity_grad=proposal.logdensity_grad,
slice=info.acceptance_rate,
)

return state, info
Expand Down
27 changes: 22 additions & 5 deletions blackjax/mcmc/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, NamedTuple
from typing import Callable, NamedTuple, Protocol

import jax
import jax.numpy as jnp

from blackjax.types import Array, PRNGKey

TrajectoryState = NamedTuple


Expand Down Expand Up @@ -160,12 +162,21 @@ def update(
return new, update


class SampleProposal(Protocol):
def __call__(
self, rng_key: Array, proposal: Proposal, new_proposal: Proposal
) -> Proposal:
...


# --------------------------------------------------------------------
# STATIC SAMPLING
# --------------------------------------------------------------------


def static_binomial_sampling(rng_key, proposal, new_proposal):
def static_binomial_sampling(
rng_key: PRNGKey, proposal: Proposal, new_proposal: Proposal
) -> Proposal:
"""Accept or reject a proposal.
In the static setting, the probability with which the new proposal is
Expand Down Expand Up @@ -195,7 +206,9 @@ def static_binomial_sampling(rng_key, proposal, new_proposal):
# --------------------------------------------------------------------


def progressive_uniform_sampling(rng_key, proposal, new_proposal):
def progressive_uniform_sampling(
rng_key: PRNGKey, proposal: Proposal, new_proposal: Proposal
) -> Proposal:
# Using expit to compute exp(w1) / (exp(w0) + exp(w1))
p_accept = jax.scipy.special.expit(new_proposal.weight - proposal.weight)
do_accept = jax.random.bernoulli(rng_key, p_accept)
Expand All @@ -222,7 +235,9 @@ def progressive_uniform_sampling(rng_key, proposal, new_proposal):
)


def progressive_biased_sampling(rng_key, proposal, new_proposal):
def progressive_biased_sampling(
rng_key: PRNGKey, proposal: Proposal, new_proposal: Proposal
) -> Proposal:
"""Baised proposal sampling :cite:p:`betancourt2017conceptual`.
Unlike uniform sampling, biased sampling favors new proposals. It thus
Expand Down Expand Up @@ -259,7 +274,9 @@ def progressive_biased_sampling(rng_key, proposal, new_proposal):
# --------------------------------------------------------------------


def nonreversible_slice_sampling(slice, proposal, new_proposal):
def nonreversible_slice_sampling(
slice: Array, proposal: Proposal, new_proposal: Proposal
) -> Proposal:
"""Slice sampling for non-reversible Metropolis-Hasting update.
Performs a non-reversible update of a uniform [0, 1] value
Expand Down

0 comments on commit d5bd76c

Please sign in to comment.