Skip to content

Commit

Permalink
clean up tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Dec 1, 2023
1 parent 72d70c6 commit c121beb
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 80 deletions.
175 changes: 108 additions & 67 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@

import jax
import jax.numpy as jnp

Check warning on line 21 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L20-L21

Added lines #L20 - L21 were not covered by tests
from blackjax.diagnostics import effective_sample_size #type: ignore
import jax
import jax.numpy as jnp
from typing import NamedTuple
from typing import NamedTuple

from blackjax.diagnostics import effective_sample_size # type: ignore
from blackjax.util import pytree_size

Check warning on line 24 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L23-L24

Added lines #L23 - L24 were not covered by tests


Expand All @@ -35,13 +31,21 @@ class MCLMCAdaptationState(NamedTuple):
L (float): The momentum decoherent rate for the MCLMC algorithm.
step_size (float): The step size used for the MCLMC algorithm.
"""

L: float
step_size: float

Check warning on line 36 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L35-L36

Added lines #L35 - L36 were not covered by tests

def mclmc_find_L_and_step_size(kernel, num_steps, state, part1_key, part2_key, frac_tune1=0.1,

def mclmc_find_L_and_step_size(

Check warning on line 39 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L39

Added line #L39 was not covered by tests
kernel,
num_steps,
state,
part1_key,
part2_key,
frac_tune1=0.1,
frac_tune2=0.1,
frac_tune3=0.1, ):
frac_tune3=0.1,
):
"""
Finds the optimal value of L (step size) for the MCLMC algorithm.
Expand All @@ -61,123 +65,160 @@ def mclmc_find_L_and_step_size(kernel, num_steps, state, part1_key, part2_key, f
params = MCLMCAdaptationState(jnp.sqrt(dim), jnp.sqrt(dim) * 0.25)
varEwanted = 5e-4

Check warning on line 66 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L64-L66

Added lines #L64 - L66 were not covered by tests

state, params = make_L_step_size_adaptation(kernel=kernel, dim=dim, frac_tune1=frac_tune1, frac_tune2=frac_tune2, varEwanted=varEwanted, sigma_xi=1.5, num_effective_samples=150)(state, params, num_steps, part1_key)
state, params = make_L_step_size_adaptation(

Check warning on line 68 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L68

Added line #L68 was not covered by tests
kernel=kernel,
dim=dim,
frac_tune1=frac_tune1,
frac_tune2=frac_tune2,
varEwanted=varEwanted,
sigma_xi=1.5,
num_effective_samples=150,
)(state, params, num_steps, part1_key)

if frac_tune3 != 0:
state, params = make_adaptation_L(kernel, frac=frac_tune3, Lfactor=0.4)(state,params, num_steps, part2_key)

return state, params
state, params = make_adaptation_L(kernel, frac=frac_tune3, Lfactor=0.4)(

Check warning on line 79 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L78-L79

Added lines #L78 - L79 were not covered by tests
state, params, num_steps, part2_key
)

return state, params

Check warning on line 83 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L83

Added line #L83 was not covered by tests


def make_L_step_size_adaptation(kernel, dim, frac_tune1, frac_tune2,
varEwanted = 1e-3, sigma_xi = 1.5, num_effective_samples = 150):
def make_L_step_size_adaptation(

Check warning on line 86 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L86

Added line #L86 was not covered by tests
kernel,
dim,
frac_tune1,
frac_tune2,
varEwanted=1e-3,
sigma_xi=1.5,
num_effective_samples=150,
):
"""Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC"""

gamma_forget = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)

Check warning on line 97 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L97

Added line #L97 was not covered by tests



def predictor(state_old, params, adaptive_state, rng_key):

Check warning on line 99 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L99

Added line #L99 was not covered by tests
"""does one step with the dynamics and updates the prediction for the optimal stepsize
Designed for the unadjusted MCHMC"""
Designed for the unadjusted MCHMC"""

W, F, step_size_max = adaptive_state

Check warning on line 103 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L103

Added line #L103 was not covered by tests

# dynamics
state_new, info = kernel(rng_key = rng_key, state=state_old, L=params.L, step_size=params.step_size)
state_new, info = kernel(

Check warning on line 106 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L106

Added line #L106 was not covered by tests
rng_key=rng_key, state=state_old, L=params.L, step_size=params.step_size
)
energy_change = info.dE

Check warning on line 109 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L109

Added line #L109 was not covered by tests
# step updating
success, state, step_size_max, energy_change = handle_nans(state_old,state_new,
params.step_size, step_size_max, energy_change)
success, state, step_size_max, energy_change = handle_nans(

Check warning on line 111 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L111

Added line #L111 was not covered by tests
state_old, state_new, params.step_size, step_size_max, energy_change
)


# Warning: var = 0 if there were nans, but we will give it a very small weight
xi = (jnp.square(energy_change) / (dim * varEwanted)) + 1e-8 # 1e-8 is added to avoid divergences in log xi
w = jnp.exp(-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one.

F = gamma_forget * F + w * (xi/jnp.power(params.step_size, 6.0))
xi = (

Check warning on line 116 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L116

Added line #L116 was not covered by tests
jnp.square(energy_change) / (dim * varEwanted)
) + 1e-8 # 1e-8 is added to avoid divergences in log xi
w = jnp.exp(

Check warning on line 119 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L119

Added line #L119 was not covered by tests
-0.5 * jnp.square(jnp.log(xi) / (6.0 * sigma_xi))
) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one.

F = gamma_forget * F + w * (xi / jnp.power(params.step_size, 6.0))
W = gamma_forget * W + w
step_size = jnp.power(F/W, -1.0/6.0) #We use the Var[E] = O(eps^6) relation here.
step_size = (step_size < step_size_max) * step_size + (step_size > step_size_max) * step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences
step_size = jnp.power(

Check warning on line 125 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L123-L125

Added lines #L123 - L125 were not covered by tests
F / W, -1.0 / 6.0
) # We use the Var[E] = O(eps^6) relation here.
step_size = (step_size < step_size_max) * step_size + (

Check warning on line 128 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L128

Added line #L128 was not covered by tests
step_size > step_size_max
) * step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences
params_new = params._replace(step_size=step_size)

Check warning on line 131 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L131

Added line #L131 was not covered by tests

return state, params_new, params_new, (W, F, step_size_max), success

return state, params_new, params_new, (W, F, step_size_max), success

Check warning on line 133 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L133

Added line #L133 was not covered by tests

def update_kalman(x, state, outer_weight, success, step_size):

Check warning on line 135 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L135

Added line #L135 was not covered by tests
"""kalman filter to estimate the size of the posterior"""
W, F1, F2 = state
w = outer_weight * step_size * success
zero_prevention = 1 - outer_weight
F1 = (W*F1 + w*x) / (W + w + zero_prevention) # Update <f(x)> with a Kalman filter
F2 = (W*F2 + w*jnp.square(x)) / (W + w + zero_prevention) # Update <f(x)> with a Kalman filter
F1 = (W * F1 + w * x) / (

Check warning on line 140 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L137-L140

Added lines #L137 - L140 were not covered by tests
W + w + zero_prevention
) # Update <f(x)> with a Kalman filter
F2 = (W * F2 + w * jnp.square(x)) / (

Check warning on line 143 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L143

Added line #L143 was not covered by tests
W + w + zero_prevention
) # Update <f(x)> with a Kalman filter
W += w
return (W, F1, F2)

Check warning on line 147 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L146-L147

Added lines #L146 - L147 were not covered by tests

adap0 = (0.0, 0.0, jnp.inf)

Check warning on line 149 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L149

Added line #L149 was not covered by tests

adap0 = (0., 0., jnp.inf)


def step(iteration_state, weight_and_key):
outer_weight, rng_key = weight_and_key
"""does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
state, params, adaptive_state, kalman_state = iteration_state
state, params, params_final, adaptive_state, success = predictor(state, params, adaptive_state, rng_key)
kalman_state = update_kalman(state.position, kalman_state, outer_weight, success, params.step_size)
state, params, params_final, adaptive_state, success = predictor(

Check warning on line 155 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L151-L155

Added lines #L151 - L155 were not covered by tests
state, params, adaptive_state, rng_key
)
kalman_state = update_kalman(

Check warning on line 158 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L158

Added line #L158 was not covered by tests
state.position, kalman_state, outer_weight, success, params.step_size
)

return (state, params_final, adaptive_state, kalman_state), None

Check warning on line 162 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L162

Added line #L162 was not covered by tests


def L_step_size_adaptation(state, params, num_steps, rng_key):

num_steps1, num_steps2 = int(num_steps * frac_tune1), int(num_steps*frac_tune2)
num_steps1, num_steps2 = int(num_steps * frac_tune1), int(

Check warning on line 165 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L164-L165

Added lines #L164 - L165 were not covered by tests
num_steps * frac_tune2
)
# TODO: change below to use jax.random.split
L_step_size_adaptation_keys = jnp.array([rng_key] * (num_steps1 + num_steps2) )
L_step_size_adaptation_keys = jnp.array([rng_key] * (num_steps1 + num_steps2))

Check warning on line 169 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L169

Added line #L169 was not covered by tests

# we use the last num_steps2 to compute the diagonal preconditioner
outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

Check warning on line 172 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L172

Added line #L172 was not covered by tests

#initial state of the kalman filter
kalman_state = (0., jnp.zeros(dim), jnp.zeros(dim))
# initial state of the kalman filter
kalman_state = (0.0, jnp.zeros(dim), jnp.zeros(dim))

Check warning on line 175 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L175

Added line #L175 was not covered by tests

# run the steps
kalman_state = jax.lax.scan(

Check warning on line 178 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L178

Added line #L178 was not covered by tests
step,
init= (state, params, adap0, kalman_state),
xs=(outer_weights, L_step_size_adaptation_keys), length= num_steps1 + num_steps2)[0]
step,
init=(state, params, adap0, kalman_state),
xs=(outer_weights, L_step_size_adaptation_keys),
length=num_steps1 + num_steps2,
)[0]
state, params, _, kalman_state_output = kalman_state

Check warning on line 184 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L184

Added line #L184 was not covered by tests

L = params.L

Check warning on line 186 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L186

Added line #L186 was not covered by tests
# determine L
if num_steps2 != 0.:
if num_steps2 != 0.0:
_, F1, F2 = kalman_state_output
variances = F2 - jnp.square(F1)
L = jnp.sqrt(jnp.sum(variances))

Check warning on line 191 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L188-L191

Added lines #L188 - L191 were not covered by tests


return state, MCLMCAdaptationState(L, params.step_size)

Check warning on line 193 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L193

Added line #L193 was not covered by tests

return L_step_size_adaptation

Check warning on line 195 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L195

Added line #L195 was not covered by tests


def make_adaptation_L(kernel, frac, Lfactor):

Check warning on line 198 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L198

Added line #L198 was not covered by tests
"""determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)"""

def adaptation_L(state, params, num_steps, key):

def adaptation_L(state, params, num_steps, key):
num_steps = int(num_steps * frac)

Check warning on line 202 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L201-L202

Added lines #L201 - L202 were not covered by tests
# TODO: change below to use jax.random.split
adaptation_L_keys = jnp.array([key]*num_steps)
adaptation_L_keys = jnp.array([key] * num_steps)

Check warning on line 204 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L204

Added line #L204 was not covered by tests

# run kernel in the normal way
state, info = jax.lax.scan(

Check warning on line 207 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L207

Added line #L207 was not covered by tests
f=lambda s, k: (kernel(rng_key=k, state=s, L=params.L, step_size=params.step_size)),
init=state,
xs=adaptation_L_keys)
samples = info.transformed_x # tranform is the identity here
ESS = 0.5 * effective_sample_size(jnp.array([samples, samples])) # TODO: should only use a single chain here

return state, params._replace(L=Lfactor * params.step_size * jnp.average(num_steps / ESS))
f=lambda s, k: (
kernel(rng_key=k, state=s, L=params.L, step_size=params.step_size)
),
init=state,
xs=adaptation_L_keys,
)
samples = info.transformed_x # tranform is the identity here
ESS = 0.5 * effective_sample_size(

Check warning on line 215 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L214-L215

Added lines #L214 - L215 were not covered by tests
jnp.array([samples, samples])
) # TODO: should only use a single chain here

return state, params._replace(

Check warning on line 219 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L219

Added line #L219 was not covered by tests
L=Lfactor * params.step_size * jnp.average(num_steps / ESS)
)

return adaptation_L

Check warning on line 223 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L223

Added line #L223 was not covered by tests

Expand All @@ -187,10 +228,10 @@ def handle_nans(state_old, state_new, step_size, step_size_max, kinetic_change):

reduced_step_size = 0.8
nonans = jnp.all(jnp.isfinite(state_new.position))
state, step_size, kinetic_change = jax.tree_util.tree_map(lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
(state_new, step_size_max, kinetic_change),
(state_old, step_size * reduced_step_size, 0.))

state, step_size, kinetic_change = jax.tree_util.tree_map(

Check warning on line 231 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L229-L231

Added lines #L229 - L231 were not covered by tests
lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old),
(state_new, step_size_max, kinetic_change),
(state_old, step_size * reduced_step_size, 0.0),
)

return nonans, state, step_size, kinetic_change

Check warning on line 237 in blackjax/adaptation/mclmc_adaptation.py

View check run for this annotation

Codecov / codecov/patch

blackjax/adaptation/mclmc_adaptation.py#L237

Added line #L237 was not covered by tests


5 changes: 0 additions & 5 deletions blackjax/adaptation/step_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Step size adaptation"""
import warnings
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
from scipy.fft import next_fast_len

from blackjax.mcmc.hmc import HMCState
from blackjax.mcmc.integrators import noneuclidean_mclachlan
from blackjax.mcmc.mclmc import IntegratorState, build_kernel, init
from blackjax.optimizers.dual_averaging import dual_averaging
from blackjax.types import PRNGKey

Expand Down Expand Up @@ -261,4 +257,3 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState:
rss_state = jax.lax.while_loop(do_continue, update, rss_state)

return rss_state.step_size

1 change: 0 additions & 1 deletion blackjax/mcmc/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,4 +368,3 @@ def noneuclidean_integrator(
noneuclidean_leapfrog = generate_noneuclidean_integrator(velocity_verlet_cofficients)
noneuclidean_yoshida = generate_noneuclidean_integrator(yoshida_cofficients)
noneuclidean_mclachlan = generate_noneuclidean_integrator(mclachlan_cofficients)

15 changes: 9 additions & 6 deletions blackjax/mcmc/mclmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,15 @@ class MCLMCInfo(NamedTuple):
def init(x_initial: ArrayLike, logdensity_fn, rng_key):
l, g = jax.value_and_grad(logdensity_fn)(x_initial)

Check warning on line 48 in blackjax/mcmc/mclmc.py

View check run for this annotation

Codecov / codecov/patch

blackjax/mcmc/mclmc.py#L48

Added line #L48 was not covered by tests

jax.debug.print("thing blackjax {x}", x=IntegratorState(
position=x_initial,
momentum=generate_unit_vector(rng_key, x_initial),
logdensity=l,
logdensity_grad=g,
))
jax.debug.print(

Check warning on line 50 in blackjax/mcmc/mclmc.py

View check run for this annotation

Codecov / codecov/patch

blackjax/mcmc/mclmc.py#L50

Added line #L50 was not covered by tests
"thing blackjax {x}",
x=IntegratorState(
position=x_initial,
momentum=generate_unit_vector(rng_key, x_initial),
logdensity=l,
logdensity_grad=g,
),
)

return IntegratorState(

Check warning on line 60 in blackjax/mcmc/mclmc.py

View check run for this annotation

Codecov / codecov/patch

blackjax/mcmc/mclmc.py#L60

Added line #L60 was not covered by tests
position=x_initial,
Expand Down
2 changes: 1 addition & 1 deletion blackjax/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Utility functions for BlackJax."""
from functools import partial
from typing import Union
import jax

import jax.numpy as jnp
from jax import jit, lax
Expand Down Expand Up @@ -104,6 +103,7 @@ def generate_unit_vector(
sample = normal(rng_key, shape=p.shape, dtype=p.dtype)
return unravel_fn(sample / jnp.linalg.norm(sample))


def partially_refresh_momentum(momentum, rng_key, step_size, L):
"""Adds a small noise to momentum and normalizes.
Expand Down

0 comments on commit c121beb

Please sign in to comment.