From cddd134c57ad328fedd0c01c74c4b57fb48be62a Mon Sep 17 00:00:00 2001 From: = Date: Tue, 21 Jan 2025 12:54:36 +0100 Subject: [PATCH] test CI: add static tests --- tests/mcmc/test_sampling.py | 111 ++++++++++++++++++++++++++++++++++-- 1 file changed, 105 insertions(+), 6 deletions(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index d081462da..7540de767 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -122,6 +122,7 @@ def run_mclmc( ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + num_tuning_integrator_steps, ) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -178,11 +179,12 @@ def run_adjusted_mclmc_dynamic( logdensity_fn=logdensity_fn, ) - target_acc_rate = 0.65 + target_acc_rate = 0.9 ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + num_tuning_integrator_steps, ) = blackjax.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -219,6 +221,74 @@ def run_adjusted_mclmc_dynamic( return out + def run_adjusted_mclmc( + self, + logdensity_fn, + num_steps, + initial_position, + key, + diagonal_preconditioning=False, + ): + integrator = isokinetic_mclachlan + + init_key, tune_key, run_key = jax.random.split(key, 3) + + initial_state = blackjax.mcmc.adjusted_mclmc.init( + position=initial_position, + logdensity_fn=logdensity_fn, + ) + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel( + integrator=integrator, + inverse_mass_matrix=inverse_mass_matrix, + logdensity_fn=logdensity_fn, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + num_integration_steps=avg_num_integration_steps, + ) + + target_acc_rate = 0.9 + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + num_tuning_integrator_steps, + ) = blackjax.adjusted_mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + target=target_acc_rate, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + diagonal_preconditioning=diagonal_preconditioning, + ) + + step_size = blackjax_mclmc_sampler_params.step_size + L = blackjax_mclmc_sampler_params.L + + alg = blackjax.adjusted_mclmc( + logdensity_fn=logdensity_fn, + step_size=step_size, + num_integration_steps=L / step_size, + integrator=integrator, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, + ) + + _, out = run_inference_algorithm( + rng_key=run_key, + initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda state, _: state.position, + progress_bar=False, + ) + + return out + @parameterized.parameters( itertools.product( regression_test_cases, [True, False], window_adaptation_filters @@ -335,7 +405,11 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1) - def test_adjusted_mclmc(self): + @parameterized.parameters([True, False]) + def test_adjusted_mclmc_dynamic( + self, + diagonal_preconditioning, + ): """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) @@ -352,6 +426,34 @@ def test_adjusted_mclmc(self): logdensity_fn=logdensity_fn, key=inference_key, num_steps=10000, + diagonal_preconditioning=diagonal_preconditioning, + ) + + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + + @parameterized.parameters([True, False]) + def test_adjusted_mclmc(self, diagonal_preconditioning): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + states = self.run_adjusted_mclmc( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=logdensity_fn, + key=inference_key, + num_steps=10000, + diagonal_preconditioning=diagonal_preconditioning, ) coefs_samples = states["coefs"][3000:] @@ -417,10 +519,7 @@ def get_inverse_mass_matrix(): inverse_mass_matrix=inverse_mass_matrix, ) - ( - _, - blackjax_mclmc_sampler_params, - ) = blackjax.mclmc_find_L_and_step_size( + (_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, state=initial_state,