From 2f5d9619a1d07984652d099738d75f7b71ddf1ac Mon Sep 17 00:00:00 2001 From: Alice <8447104+tfaod@users.noreply.github.com> Date: Sat, 16 Mar 2024 17:38:51 -0400 Subject: [PATCH] Fix typo in `jax_nadamw_target_setting.py` Fix dict indexing typo in `jax_nadamw_target_setting.py` --- .../self_tuning/jax_nadamw_target_setting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 6c859b8dd..9ed09a615 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -180,7 +180,7 @@ def init_optimizer_state(workload: spec.Workload, def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. - warmup_steps = int(hyperparameters['warmup_factor * step_hint']) + warmup_steps = int(hyperparameters['warmup_factor'] * step_hint) warmup_fn = optax.linear_schedule( init_value=0., end_value=hyperparameters['learning_rate'],