diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 1846eb1b3..b19bad9ff 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -166,8 +166,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): num_steps1, num_steps2 = int(num_steps * frac_tune1), int( num_steps * frac_tune2 ) - # L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2) - L_step_size_adaptation_keys = jnp.array([rng_key] * (num_steps1 + num_steps2)) + L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2) # we use the last num_steps2 to compute the diagonal preconditioner outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) @@ -201,8 +200,7 @@ def make_adaptation_L(kernel, frac, Lfactor): def adaptation_L(state, params, num_steps, key): num_steps = int(num_steps * frac) - # adaptation_L_keys = jax.random.split(key, num_steps) - adaptation_L_keys = jnp.array([key] * (num_steps)) + adaptation_L_keys = jax.random.split(key, num_steps) # run kernel in the normal way state, info = jax.lax.scan(