diff --git a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py
index b35750086..c54202e56 100644
--- a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py
+++ b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py
@@ -27,13 +27,14 @@
 _GRAD_CLIP_EPS = 1e-6
 
 HPARAMS = {
-    "dropout_rate": 0.1,
-    "learning_rate": 0.0017486387539278373,
-    "one_minus_beta1": 0.06733926164,
-    "beta2": 0.9955159689799007,
-    "weight_decay": 0.08121616522670176,
-    "warmup_factor": 0.02
-    }
+  "dropout_rate": 0.1,
+  "learning_rate": 0.0017486387539278373,
+  "one_minus_beta1": 0.06733926164,
+  "beta2": 0.9955159689799007,
+  "weight_decay": 0.08121616522670176,
+  "warmup_factor": 0.02
+}
+
 
 # Forked from
 # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
@@ -175,8 +176,8 @@ def init_optimizer_state(workload: spec.Workload,
   del rng
   del hyperparameters
 
-  hyperparameters=HPARAMS
-  
+  hyperparameters = HPARAMS
+
   def jax_cosine_warmup(step_hint: int, hyperparameters):
     # Create learning rate schedule.
     warmup_steps = int(hyperparameters.warmup_factor * step_hint)
diff --git a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py
index 190720213..dd42743e2 100644
--- a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py
+++ b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py
@@ -27,13 +27,14 @@
 _GRAD_CLIP_EPS = 1e-6
 
 HPARAMS = {
-    "dropout_rate": 0.1,
-    "learning_rate": 0.0017486387539278373,
-    "one_minus_beta1": 0.06733926164,
-    "beta2": 0.9955159689799007,
-    "weight_decay": 0.08121616522670176,
-    "warmup_factor": 0.02
-    }
+  "dropout_rate": 0.1,
+  "learning_rate": 0.0017486387539278373,
+  "one_minus_beta1": 0.06733926164,
+  "beta2": 0.9955159689799007,
+  "weight_decay": 0.08121616522670176,
+  "warmup_factor": 0.02
+}
+
 
 # Forked from
 # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
@@ -175,8 +176,8 @@ def init_optimizer_state(workload: spec.Workload,
   del rng
   del hyperparameters
 
-  hyperparameters=HPARAMS
-  
+  hyperparameters = HPARAMS
+
   def jax_cosine_warmup(step_hint: int, hyperparameters):
     # Create learning rate schedule.
     warmup_steps = int(hyperparameters.warmup_factor * step_hint)
@@ -192,7 +193,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
     return schedule_fn
 
   # Create optimizer + LR schedule.
-  lr_schedule_fn = jax_cosine_warmup(workload.step_hint*0.75, hyperparameters)
+  lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters)
   opt_init_fn, opt_update_fn = nadamw(
       learning_rate=lr_schedule_fn,
       b1=1.0 - hyperparameters.one_minus_beta1,
diff --git a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py
index a1cf612f2..57da48167 100644
--- a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py
+++ b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py
@@ -16,14 +16,15 @@
 
 USE_PYTORCH_DDP = pytorch_setup()[0]
 
-HPARAMS = {
-    "dropout_rate": 0.1,
-    "learning_rate": 0.0017486387539278373,
-    "one_minus_beta1": 0.06733926164,
-    "beta2": 0.9955159689799007,
-    "weight_decay": 0.08121616522670176,
-    "warmup_factor": 0.02
-    }
+HPARAMS = {
+  "dropout_rate": 0.1,
+  "learning_rate": 0.0017486387539278373,
+  "one_minus_beta1": 0.06733926164,
+  "beta2": 0.9955159689799007,
+  "weight_decay": 0.08121616522670176,
+  "warmup_factor": 0.02
+}
+
 
 # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.
 class NAdamW(torch.optim.Optimizer):
@@ -253,7 +254,7 @@ def update_params(workload: spec.Workload,
   del hyperparameters
 
   hyperparameters = HPARAMS
-  
+
   current_model = current_param_container
   current_model.train()
   optimizer_state['optimizer'].zero_grad()
diff --git a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py
index 1209abadc..ef6e84c94 100644
--- a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py
+++ b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py
@@ -16,14 +16,15 @@
 
 USE_PYTORCH_DDP = pytorch_setup()[0]
 
-HPARAMS = {
-    "dropout_rate": 0.1,
-    "learning_rate": 0.0017486387539278373,
-    "one_minus_beta1": 0.06733926164,
-    "beta2": 0.9955159689799007,
-    "weight_decay": 0.08121616522670176,
-    "warmup_factor": 0.02
-    }
+HPARAMS = {
+  "dropout_rate": 0.1,
+  "learning_rate": 0.0017486387539278373,
+  "one_minus_beta1": 0.06733926164,
+  "beta2": 0.9955159689799007,
+  "weight_decay": 0.08121616522670176,
+  "warmup_factor": 0.02
+}
+
 
 # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.
 class NAdamW(torch.optim.Optimizer):
@@ -230,7 +231,7 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
       optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
 
   optimizer_state['scheduler'] = pytorch_cosine_warmup(
-      workload.step_hint*0.75, hyperparameters, optimizer_state['optimizer'])
+      workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer'])
 
   return optimizer_state
 
@@ -253,7 +254,7 @@ def update_params(workload: spec.Workload,
   del hyperparameters
 
   hyperparameters = HPARAMS
-  
+
   current_model = current_param_container
   current_model.train()
   optimizer_state['optimizer'].zero_grad()