Merge branch 'main' into ciguaran_pretuning_adaptive_test
ciguaran authored Jan 22, 2025
2 parents 11088a4 + 4d4eae0 commit 92a7c05
Showing 13 changed files with 522 additions and 158 deletions.
18 changes: 0 additions & 18 deletions .github/workflows/schedule-meeting.yml

This file was deleted.

2 changes: 2 additions & 0 deletions blackjax/__init__.py
@@ -13,6 +13,7 @@
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .mcmc import adjusted_mclmc as _adjusted_mclmc
from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic
from .mcmc import barker
from .mcmc import dynamic_hmc as _dynamic_hmc
from .mcmc import elliptical_slice as _elliptical_slice
@@ -113,6 +114,7 @@ def generate_top_level_api_from(module):
additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk)

mclmc = generate_top_level_api_from(_mclmc)
adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic)
adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc)
elliptical_slice = generate_top_level_api_from(_elliptical_slice)
ghmc = generate_top_level_api_from(_ghmc)
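The net effect of this hunk for users is that both the existing adjusted MCLMC API and the new dynamic variant are registered at the package root. A minimal check of what is now exposed (it asserts only that the namespaces exist and assumes nothing about their call signatures):

    import blackjax

    # `adjusted_mclmc` was already exported; `adjusted_mclmc_dynamic` is newly
    # registered by the line added above.
    assert hasattr(blackjax, "adjusted_mclmc")
    assert hasattr(blackjax, "adjusted_mclmc_dynamic")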
73 changes: 52 additions & 21 deletions blackjax/adaptation/adjusted_mclmc_adaptation.py
@@ -74,14 +74,20 @@ def adjusted_mclmc_find_L_and_step_size(
dim = pytree_size(state.position)
if params is None:
params = MCLMCAdaptationState(
jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,))
jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, inverse_mass_matrix=jnp.ones((dim,))
)

part1_key, part2_key = jax.random.split(rng_key, 2)

total_num_tuning_integrator_steps = 0
for i in range(num_windows):
window_key = jax.random.fold_in(part1_key, i)
(state, params, eigenvector) = adjusted_mclmc_make_L_step_size_adaptation(
(
state,
params,
eigenvector,
num_tuning_integrator_steps,
) = adjusted_mclmc_make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
@@ -90,22 +96,38 @@
diagonal_preconditioning=diagonal_preconditioning,
max=max,
tuning_factor=tuning_factor,
)(state, params, num_steps, window_key)
)(
state, params, num_steps, window_key
)
total_num_tuning_integrator_steps += num_tuning_integrator_steps

if frac_tune3 != 0:
for i in range(num_windows):
part2_key = jax.random.fold_in(part2_key, i)
part2_key1, part2_key2 = jax.random.split(part2_key, 2)

state, params = adjusted_mclmc_make_adaptation_L(
(
state,
params,
num_tuning_integrator_steps,
) = adjusted_mclmc_make_adaptation_L(
mclmc_kernel,
frac=frac_tune3,
Lfactor=0.5,
max=max,
eigenvector=eigenvector,
)(state, params, num_steps, part2_key1)
)(
state, params, num_steps, part2_key1
)

total_num_tuning_integrator_steps += num_tuning_integrator_steps

(state, params, _) = adjusted_mclmc_make_L_step_size_adaptation(
(
state,
params,
_,
num_tuning_integrator_steps,
) = adjusted_mclmc_make_L_step_size_adaptation(
kernel=mclmc_kernel,
dim=dim,
frac_tune1=frac_tune1,
@@ -115,9 +137,13 @@
diagonal_preconditioning=diagonal_preconditioning,
max=max,
tuning_factor=tuning_factor,
)(state, params, num_steps, part2_key2)
)(
state, params, num_steps, part2_key2
)

total_num_tuning_integrator_steps += num_tuning_integrator_steps

return state, params
return state, params, total_num_tuning_integrator_steps


def adjusted_mclmc_make_L_step_size_adaptation(
@@ -152,7 +178,7 @@ def step(iteration_state, weight_and_key):
state=previous_state,
avg_num_integration_steps=avg_num_integration_steps,
step_size=params.step_size,
sqrt_diag_cov=params.sqrt_diag_cov,
inverse_mass_matrix=params.inverse_mass_matrix,
)

# step updating
@@ -256,6 +282,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
update_da=update_da,
)

num_tuning_integrator_steps = info.num_integration_steps.sum()
final_stepsize = final_da(dual_avg_state)
params = params._replace(step_size=final_stepsize)

@@ -283,9 +310,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
L=params.L * change, step_size=params.step_size * change
)
if diagonal_preconditioning:
params = params._replace(
sqrt_diag_cov=jnp.sqrt(variances), L=jnp.sqrt(dim)
)
params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim))

initial_da, update_da, final_da = dual_averaging_adaptation(target=target)
(
@@ -301,9 +326,11 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
initial_da=initial_da,
)

num_tuning_integrator_steps += info.num_integration_steps.sum()

params = params._replace(step_size=final_da(dual_avg_state))

return state, params, eigenvector
return state, params, eigenvector, num_tuning_integrator_steps

return L_step_size_adaptation

@@ -318,16 +345,16 @@ def adaptation_L(state, params, num_steps, key):
adaptation_L_keys = jax.random.split(key, num_steps)

def step(state, key):
next_state, _ = kernel(
next_state, info = kernel(
rng_key=key,
state=state,
step_size=params.step_size,
avg_num_integration_steps=params.L / params.step_size,
sqrt_diag_cov=params.sqrt_diag_cov,
inverse_mass_matrix=params.inverse_mass_matrix,
)
return next_state, next_state.position
return next_state, (next_state.position, info)

state, samples = jax.lax.scan(
state, (samples, info) = jax.lax.scan(
f=step,
init=state,
xs=adaptation_L_keys,
@@ -348,10 +375,14 @@ def step(state, key):
# number of effective samples per 1 actual sample
ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps

return state, params._replace(
L=jnp.clip(
Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound
)
return (
state,
params._replace(
L=jnp.clip(
Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound
)
),
info.num_integration_steps.sum(),
)

return adaptation_L
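For callers, the changes in this file mean that `adjusted_mclmc_find_L_and_step_size` now returns a third value, the total number of integrator steps spent during tuning, and that the preconditioner is stored under `params.inverse_mass_matrix` instead of `params.sqrt_diag_cov`. A minimal caller-side sketch; the wrapper and its argument names are illustrative, and the keyword names passed to the tuner are taken from the function body shown above rather than its full signature:

    from blackjax.adaptation.adjusted_mclmc_adaptation import (
        adjusted_mclmc_find_L_and_step_size,
    )

    def run_adjusted_mclmc_tuning(kernel, initial_state, num_steps, rng_key):
        # The tuner now also reports how many integrator steps adaptation consumed.
        (
            state,
            params,
            total_num_tuning_integrator_steps,
        ) = adjusted_mclmc_find_L_and_step_size(
            mclmc_kernel=kernel,
            state=initial_state,
            num_steps=num_steps,
            rng_key=rng_key,
        )
        # Renamed field: previously params.sqrt_diag_cov.
        return state, params.inverse_mass_matrix, total_num_tuning_integrator_steps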
30 changes: 15 additions & 15 deletions blackjax/adaptation/mclmc_adaptation.py
@@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple):
The momentum decoherent rate for the MCLMC algorithm.
step_size
The step size used for the MCLMC algorithm.
sqrt_diag_cov
inverse_mass_matrix
A matrix used for preconditioning.
"""

L: float
step_size: float
sqrt_diag_cov: float
inverse_mass_matrix: float


def mclmc_find_L_and_step_size(
@@ -87,10 +87,10 @@ def mclmc_find_L_and_step_size(
Example
-------
.. code::
kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel(
kernel = lambda inverse_mass_matrix : blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=integrator,
sqrt_diag_cov=sqrt_diag_cov,
inverse_mass_matrix=inverse_mass_matrix,
)
(
@@ -106,7 +106,7 @@
"""
dim = pytree_size(state.position)
params = MCLMCAdaptationState(
jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,))
jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,))
)
part1_key, part2_key = jax.random.split(rng_key, 2)

@@ -123,10 +123,10 @@

if frac_tune3 != 0:
state, params = make_adaptation_L(
mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4
mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4
)(state, params, num_steps, part2_key)

return state, params
return state, params, num_steps * (frac_tune1 + frac_tune2 + frac_tune3)


def make_L_step_size_adaptation(
@@ -152,7 +152,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
rng_key, nan_key = jax.random.split(rng_key)

# dynamics
next_state, info = kernel(params.sqrt_diag_cov)(
next_state, info = kernel(params.inverse_mass_matrix)(
rng_key=rng_key,
state=previous_state,
L=params.L,
@@ -247,15 +247,15 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):

L = params.L
# determine L
sqrt_diag_cov = params.sqrt_diag_cov
inverse_mass_matrix = params.inverse_mass_matrix
if num_steps2 > 1:
x_average, x_squared_average = average[0], average[1]
variances = x_squared_average - jnp.square(x_average)
L = jnp.sqrt(jnp.sum(variances))

if diagonal_preconditioning:
sqrt_diag_cov = jnp.sqrt(variances)
params = params._replace(sqrt_diag_cov=sqrt_diag_cov)
inverse_mass_matrix = variances
params = params._replace(inverse_mass_matrix=inverse_mass_matrix)
L = jnp.sqrt(dim)

# readjust the stepsize
@@ -265,7 +265,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
xs=(jnp.ones(steps), keys), state=state, params=params
)

return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov)
return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix)

return L_step_size_adaptation

@@ -274,8 +274,8 @@ def make_adaptation_L(kernel, frac, Lfactor):
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""

def adaptation_L(state, params, num_steps, key):
num_steps = int(num_steps * frac)
adaptation_L_keys = jax.random.split(key, num_steps)
num_steps_3 = int(num_steps * frac)
adaptation_L_keys = jax.random.split(key, num_steps_3)

def step(state, key):
next_state, _ = kernel(
@@ -297,7 +297,7 @@ def step(state, key):
ess = effective_sample_size(flat_samples[None, ...])

return state, params._replace(
L=Lfactor * params.step_size * jnp.mean(num_steps / ess)
L=Lfactor * params.step_size * jnp.mean(num_steps_3 / ess)
)

return adaptation_L
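The unadjusted MCLMC tuner changes in the same way: the kernel factory is parameterised by `inverse_mass_matrix` rather than `sqrt_diag_cov`, and `mclmc_find_L_and_step_size` now returns an estimate of the tuning cost, `num_steps * (frac_tune1 + frac_tune2 + frac_tune3)`, as a third value. A minimal sketch combining the docstring example above with the new return signature; `logdensity_fn`, `integrator`, `initial_state`, `num_steps`, and `rng_key` are assumed to be supplied by the caller:

    import blackjax
    from blackjax.adaptation.mclmc_adaptation import mclmc_find_L_and_step_size

    def tune_mclmc(logdensity_fn, integrator, initial_state, num_steps, rng_key):
        # The kernel factory is now keyed on the inverse mass matrix.
        kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel(
            logdensity_fn=logdensity_fn,
            integrator=integrator,
            inverse_mass_matrix=inverse_mass_matrix,
        )
        # Third return value: estimated number of integrator steps spent tuning.
        state, params, num_tuning_steps = mclmc_find_L_and_step_size(
            mclmc_kernel=kernel,
            num_steps=num_steps,
            state=initial_state,
            rng_key=rng_key,
        )
        return state, params, num_tuning_steps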
2 changes: 2 additions & 0 deletions blackjax/mcmc/__init__.py
@@ -1,5 +1,6 @@
from . import (
adjusted_mclmc,
adjusted_mclmc_dynamic,
barker,
elliptical_slice,
ghmc,
@@ -25,5 +26,6 @@
"marginal_latent_gaussian",
"random_walk",
"mclmc",
"adjusted_mclmc_dynamic",
"adjusted_mclmc",
]
(The remaining 8 changed files are not shown.)
