Skip to content

Commit

Permalink
Merge pull request #707 from tfaod/patch-1
Browse files Browse the repository at this point in the history
Fix typo in `jax_nadamw_target_setting.py`
priyakasimbeg authored Mar 18, 2024

Verified

This commit was signed with the committer’s verified signature.
gastaldi George Gastaldi
2 parents 39f34fa + 2f5d961 commit 0b514ef
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
@@ -180,7 +180,7 @@ def init_optimizer_state(workload: spec.Workload,

def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters['warmup_factor * step_hint'])
warmup_steps = int(hyperparameters['warmup_factor'] * step_hint)
warmup_fn = optax.linear_schedule(
init_value=0.,
end_value=hyperparameters['learning_rate'],

0 comments on commit 0b514ef

Please sign in to comment.