Skip to content

Commit

Permalink
test CI: add static tests
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Jan 21, 2025
1 parent 2db919d commit cddd134
Showing 1 changed file with 105 additions and 6 deletions.
111 changes: 105 additions & 6 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def run_mclmc(
(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
num_tuning_integrator_steps,
) = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
Expand Down Expand Up @@ -178,11 +179,12 @@ def run_adjusted_mclmc_dynamic(
logdensity_fn=logdensity_fn,
)

target_acc_rate = 0.65
target_acc_rate = 0.9

(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
num_tuning_integrator_steps,
) = blackjax.adjusted_mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
Expand Down Expand Up @@ -219,6 +221,74 @@ def run_adjusted_mclmc_dynamic(

return out

def run_adjusted_mclmc(
self,
logdensity_fn,
num_steps,
initial_position,
key,
diagonal_preconditioning=False,
):
integrator = isokinetic_mclachlan

init_key, tune_key, run_key = jax.random.split(key, 3)

initial_state = blackjax.mcmc.adjusted_mclmc.init(
position=initial_position,
logdensity_fn=logdensity_fn,
)

kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel(
integrator=integrator,
inverse_mass_matrix=inverse_mass_matrix,
logdensity_fn=logdensity_fn,
)(
rng_key=rng_key,
state=state,
step_size=step_size,
num_integration_steps=avg_num_integration_steps,
)

target_acc_rate = 0.9

(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
num_tuning_integrator_steps,
) = blackjax.adjusted_mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
state=initial_state,
rng_key=tune_key,
target=target_acc_rate,
frac_tune1=0.1,
frac_tune2=0.1,
frac_tune3=0.1,
diagonal_preconditioning=diagonal_preconditioning,
)

step_size = blackjax_mclmc_sampler_params.step_size
L = blackjax_mclmc_sampler_params.L

alg = blackjax.adjusted_mclmc(
logdensity_fn=logdensity_fn,
step_size=step_size,
num_integration_steps=L / step_size,
integrator=integrator,
inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix,
)

_, out = run_inference_algorithm(
rng_key=run_key,
initial_state=blackjax_state_after_tuning,
inference_algorithm=alg,
num_steps=num_steps,
transform=lambda state, _: state.position,
progress_bar=False,
)

return out

@parameterized.parameters(
itertools.product(
regression_test_cases, [True, False], window_adaptation_filters
Expand Down Expand Up @@ -335,7 +405,11 @@ def test_mclmc(self):
np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1)

def test_adjusted_mclmc(self):
@parameterized.parameters([True, False])
def test_adjusted_mclmc_dynamic(
self,
diagonal_preconditioning,
):
"""Test the MCLMC kernel."""

init_key0, init_key1, inference_key = jax.random.split(self.key, 3)
Expand All @@ -352,6 +426,34 @@ def test_adjusted_mclmc(self):
logdensity_fn=logdensity_fn,
key=inference_key,
num_steps=10000,
diagonal_preconditioning=diagonal_preconditioning,
)

coefs_samples = states["coefs"][3000:]
scale_samples = np.exp(states["log_scale"][3000:])

np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2)
np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2)

@parameterized.parameters([True, False])
def test_adjusted_mclmc(self, diagonal_preconditioning):
"""Test the MCLMC kernel."""

init_key0, init_key1, inference_key = jax.random.split(self.key, 3)
x_data = jax.random.normal(init_key0, shape=(1000, 1))
y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape)

logposterior_fn_ = functools.partial(
self.regression_logprob, x=x_data, preds=y_data
)
logdensity_fn = lambda x: logposterior_fn_(**x)

states = self.run_adjusted_mclmc(
initial_position={"coefs": 1.0, "log_scale": 1.0},
logdensity_fn=logdensity_fn,
key=inference_key,
num_steps=10000,
diagonal_preconditioning=diagonal_preconditioning,
)

coefs_samples = states["coefs"][3000:]
Expand Down Expand Up @@ -417,10 +519,7 @@ def get_inverse_mass_matrix():
inverse_mass_matrix=inverse_mass_matrix,
)

(
_,
blackjax_mclmc_sampler_params,
) = blackjax.mclmc_find_L_and_step_size(
(_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
state=initial_state,
Expand Down

0 comments on commit cddd134

Please sign in to comment.