From 7a3081879bfeae45e34e927c057c0ba872150363 Mon Sep 17 00:00:00 2001
From: Colin Carroll
Date: Wed, 29 Nov 2023 10:06:56 -0500
Subject: [PATCH 1/4] Name arguments consistently across classes. (#593)

---
 blackjax/adaptation/chees_adaptation.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/blackjax/adaptation/chees_adaptation.py b/blackjax/adaptation/chees_adaptation.py
index faceaba7e..29dac5475 100644
--- a/blackjax/adaptation/chees_adaptation.py
+++ b/blackjax/adaptation/chees_adaptation.py
@@ -273,7 +273,7 @@ def update(
 
 
 def chees_adaptation(
-    logprob_fn: Callable,
+    logdensity_fn: Callable,
     num_chains: int,
     *,
     jitter_generator: Optional[Callable] = None,
@@ -308,7 +308,7 @@
 
     .. code::
 
-        warmup = blackjax.chees_adaptation(logprob_fn, num_chains)
+        warmup = blackjax.chees_adaptation(logdensity_fn, num_chains)
         key_warmup, key_sample = jax.random.split(rng_key)
         optim = optax.adam(learning_rate)
         (last_states, parameters), _ = warmup.run(
@@ -318,12 +318,12 @@
             optim,
             num_warmup_steps,
         )
-        kernel = blackjax.dynamic_hmc(logprob_fn, **parameters).step
+        kernel = blackjax.dynamic_hmc(logdensity_fn, **parameters).step
         new_states, info = jax.vmap(kernel)(key_sample, last_states)
 
     Parameters
     ----------
-    logprob_fn
+    logdensity_fn
        The log probability density function from which we wish to sample.
     num_chains
        Number of chains used for cross-chain warm-up training.
@@ -399,7 +399,7 @@ def one_step(carry, rng_key):
             keys = jax.random.split(rng_key, num_chains)
             _step_fn = partial(
                 step_fn,
-                logdensity_fn=logprob_fn,
+                logdensity_fn=logdensity_fn,
                 step_size=adaptation_state.step_size,
                 inverse_mass_matrix=jnp.ones(num_dim),
                 trajectory_length_adjusted=adaptation_state.trajectory_length
@@ -422,7 +422,7 @@ def one_step(carry, rng_key):
         )
 
         batch_init = jax.vmap(
-            lambda p: hmc.init_dynamic(p, logprob_fn, init_random_arg)
+            lambda p: hmc.init_dynamic(p, logdensity_fn, init_random_arg)
         )
         init_states = batch_init(positions)
         init_adaptation_state = init(init_random_arg, step_size)

From 5569fafb478faa8f300f38478a0a37ce7fa4f1a2 Mon Sep 17 00:00:00 2001
From: Junpeng Lao
Date: Wed, 29 Nov 2023 16:38:34 +0100
Subject: [PATCH 2/4] Small clean up

---
 blackjax/adaptation/chees_adaptation.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/blackjax/adaptation/chees_adaptation.py b/blackjax/adaptation/chees_adaptation.py
index 29dac5475..fce61be77 100644
--- a/blackjax/adaptation/chees_adaptation.py
+++ b/blackjax/adaptation/chees_adaptation.py
@@ -178,11 +178,9 @@ def compute_parameters(
         trajectory_gradients = (
             jitter_generator(random_generator_arg)
             * trajectory_length
-            * (
-                jax.vmap(lambda p: jnp.dot(p, p))(proposals_matrix)
-                - jax.vmap(lambda p: jnp.dot(p, p))(initials_matrix)
-            )
-            * jax.vmap(lambda p, m: jnp.dot(p, m))(proposals_matrix, momentums_matrix)
+            * jax.vmap(
+                lambda pm, im, mm: (jnp.dot(pm, pm) - jnp.dot(im, im)) * jnp.dot(pm, mm)
+            )(proposals_matrix, initials_matrix, momentums_matrix)
         )
         trajectory_gradient = jnp.sum(
             acceptance_probabilities * trajectory_gradients, where=~is_divergent
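
The clean-up in PATCH 2/4 fuses three separate `jax.vmap` calls into a single vmapped lambda; both forms compute the same per-chain quantity. A quick self-contained sanity check of that equivalence follows (the shapes and random inputs are illustrative stand-ins for the `(num_chains, num_dim)` matrices in `compute_parameters`, not BlackJAX code):

    import jax
    import jax.numpy as jnp

    # Stand-in matrices with an assumed (num_chains, num_dim) shape.
    key_p, key_i, key_m = jax.random.split(jax.random.PRNGKey(0), 3)
    proposals_matrix = jax.random.normal(key_p, (4, 3))
    initials_matrix = jax.random.normal(key_i, (4, 3))
    momentums_matrix = jax.random.normal(key_m, (4, 3))

    # Before the patch: three separate vmaps combined elementwise.
    old = (
        jax.vmap(lambda p: jnp.dot(p, p))(proposals_matrix)
        - jax.vmap(lambda p: jnp.dot(p, p))(initials_matrix)
    ) * jax.vmap(lambda p, m: jnp.dot(p, m))(proposals_matrix, momentums_matrix)

    # After the patch: one fused vmap over the three matrices.
    new = jax.vmap(
        lambda pm, im, mm: (jnp.dot(pm, pm) - jnp.dot(im, im)) * jnp.dot(pm, mm)
    )(proposals_matrix, initials_matrix, momentums_matrix)

    assert jnp.allclose(old, new)

The fused form traces a single vmapped function instead of three, which keeps the computation graph a little smaller without changing the result.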

From 8c2232cbd549ebf41ef473f9eee8060e5c80f90b Mon Sep 17 00:00:00 2001
From: Junpeng Lao
Date: Thu, 30 Nov 2023 08:14:52 +0100
Subject: [PATCH 3/4] Improve typing and doc of proposal generation (#594)

---
 blackjax/mcmc/ghmc.py     | 12 +++++++-----
 blackjax/mcmc/proposal.py | 27 ++++++++++++++++++++++-----
 2 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py
index 5c71df451..818edd9de 100644
--- a/blackjax/mcmc/ghmc.py
+++ b/blackjax/mcmc/ghmc.py
@@ -156,14 +156,16 @@ def kernel(
         integrator_state = integrators.IntegratorState(
             position, momentum, logdensity, logdensity_grad
         )
+        # Note that ghmc uses nonreversible_slice_sampling, which overloads the pattern
+        # of SampleProposal and does not actually return the acceptance rate.
         proposal, info = proposal_generator(slice, integrator_state)
         proposal = hmc.flip_momentum(proposal)
         state = GHMCState(
-            proposal.position,
-            proposal.momentum,
-            proposal.logdensity,
-            proposal.logdensity_grad,
-            info.acceptance_rate,
+            position=proposal.position,
+            momentum=proposal.momentum,
+            logdensity=proposal.logdensity,
+            logdensity_grad=proposal.logdensity_grad,
+            slice=info.acceptance_rate,
         )
 
         return state, info
diff --git a/blackjax/mcmc/proposal.py b/blackjax/mcmc/proposal.py
index 9415438b0..bcb124b7b 100644
--- a/blackjax/mcmc/proposal.py
+++ b/blackjax/mcmc/proposal.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Callable, NamedTuple
+from typing import Callable, NamedTuple, Protocol
 
 import jax
 import jax.numpy as jnp
 
+from blackjax.types import Array, PRNGKey
+
 TrajectoryState = NamedTuple
 
 
@@ -160,12 +162,21 @@ def update(
     return new, update
 
 
+class SampleProposal(Protocol):
+    def __call__(
+        self, rng_key: Array, proposal: Proposal, new_proposal: Proposal
+    ) -> Proposal:
+        ...
+
+
 # --------------------------------------------------------------------
 # STATIC SAMPLING
 # --------------------------------------------------------------------
 
 
-def static_binomial_sampling(rng_key, proposal, new_proposal):
+def static_binomial_sampling(
+    rng_key: PRNGKey, proposal: Proposal, new_proposal: Proposal
+) -> Proposal:
     """Accept or reject a proposal.
 
     In the static setting, the probability with which the new proposal is
@@ -195,7 +206,9 @@ def static_binomial_sampling(rng_key, proposal, new_proposal):
 # --------------------------------------------------------------------
 
 
-def progressive_uniform_sampling(rng_key, proposal, new_proposal):
+def progressive_uniform_sampling(
+    rng_key: PRNGKey, proposal: Proposal, new_proposal: Proposal
+) -> Proposal:
     # Using expit to compute exp(w1) / (exp(w0) + exp(w1))
     p_accept = jax.scipy.special.expit(new_proposal.weight - proposal.weight)
     do_accept = jax.random.bernoulli(rng_key, p_accept)
@@ -222,7 +235,9 @@ def progressive_uniform_sampling(rng_key, proposal, new_proposal):
     )
 
 
-def progressive_biased_sampling(rng_key, proposal, new_proposal):
+def progressive_biased_sampling(
+    rng_key: PRNGKey, proposal: Proposal, new_proposal: Proposal
+) -> Proposal:
     """Biased proposal sampling :cite:p:`betancourt2017conceptual`.
 
     Unlike uniform sampling, biased sampling favors new proposals. It thus
@@ -259,7 +274,9 @@ def progressive_biased_sampling(rng_key, proposal, new_proposal):
 # --------------------------------------------------------------------
 
 
-def nonreversible_slice_sampling(slice, proposal, new_proposal):
+def nonreversible_slice_sampling(
+    slice: Array, proposal: Proposal, new_proposal: Proposal
+) -> Proposal:
     """Slice sampling for non-reversible Metropolis-Hastings update.
 
     Performs a non-reversible update of a uniform [0, 1] value

From 9713452ff92767aec327c246f5be1695267fb492 Mon Sep 17 00:00:00 2001
From: Junpeng Lao
Date: Thu, 30 Nov 2023 14:49:14 +0100
Subject: [PATCH 4/4] Update badges in README.md

---
 README.md | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 44b9c62c7..aac0e6626 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,7 @@
 # BlackJAX
-![CI](https://github.com/blackjax-devs/blackjax/workflows/Run%20tests/badge.svg?branch=main)
-[![codecov](https://codecov.io/gh/blackjax-devs/blackjax/branch/main/graph/badge.svg)](https://codecov.io/gh/blackjax-devs/blackjax)
+![Continuous integration](https://github.com/blackjax-devs/blackjax/actions/workflows/test.yml/badge.svg)
+![codecov](https://codecov.io/gh/blackjax-devs/blackjax/branch/main/graph/badge.svg)
+![PyPI version](https://img.shields.io/pypi/v/blackjax)
 
 ## What is BlackJAX?
 
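
PATCH 3/4 gives every proposal sampler the same call signature through the `SampleProposal` protocol, and the new ghmc.py comment notes that `nonreversible_slice_sampling` reuses that signature with a slice variable in place of the PRNG key, which is why the protocol types its first argument as `Array` rather than `PRNGKey`. A minimal, self-contained sketch of what conforming to the protocol means follows (the `Proposal` fields are restated locally only to make the example runnable, and `always_accept` is a hypothetical sampler, not part of BlackJAX):

    from typing import NamedTuple, Protocol


    class Proposal(NamedTuple):
        # Restated here so the sketch stands alone.
        state: tuple
        energy: float
        weight: float
        sum_log_p_accept: float


    class SampleProposal(Protocol):
        def __call__(self, rng_key, proposal, new_proposal) -> Proposal:
            ...


    def always_accept(rng_key, proposal, new_proposal) -> Proposal:
        # Ignores the key (or slice variable) and always keeps the new proposal.
        return new_proposal


    # Any function with this signature type-checks as a SampleProposal.
    sampler: SampleProposal = always_accept
    current = Proposal(state=(), energy=0.0, weight=0.0, sum_log_p_accept=0.0)
    candidate = Proposal(state=(), energy=-1.0, weight=0.5, sum_log_p_accept=0.0)
    assert sampler(None, current, candidate) is candidate

Under this reading, the concrete samplers in proposal.py (binomial, progressive uniform, progressive biased, and non-reversible slice) are simply different implementations of the same protocol, which is what the added type annotations make explicit.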