Commit: formatting

priyakasimbeg committed Nov 21, 2023
1 parent 0338f8f commit 5f83404
Showing 4 changed files with 42 additions and 38 deletions.

File 1 of 4

@@ -27,13 +27,14 @@
 _GRAD_CLIP_EPS = 1e-6
 
 HPARAMS = {
-"dropout_rate": 0.1,
-"learning_rate": 0.0017486387539278373,
-"one_minus_beta1": 0.06733926164,
-"beta2": 0.9955159689799007,
-"weight_decay": 0.08121616522670176,
-"warmup_factor": 0.02
-}
+    "dropout_rate": 0.1,
+    "learning_rate": 0.0017486387539278373,
+    "one_minus_beta1": 0.06733926164,
+    "beta2": 0.9955159689799007,
+    "weight_decay": 0.08121616522670176,
+    "warmup_factor": 0.02
+}
+
 
 # Forked from
 # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
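
Note: the HPARAMS block above is the fixed hyperparameter point this target-setting file bakes in; the diff only re-indents it. As a rough illustration of how such a dict maps onto an Adam-family optax optimizer, the sketch below uses optax.adamw as a stand-in for the forked nadamw transform; the values come from the dict above, everything else is illustrative.

import optax

HPARAMS = {
    "dropout_rate": 0.1,
    "learning_rate": 0.0017486387539278373,
    "one_minus_beta1": 0.06733926164,
    "beta2": 0.9955159689799007,
    "weight_decay": 0.08121616522670176,
    "warmup_factor": 0.02,
}

# Sketch only: optax.adamw stands in for the forked nadamw transform.
optimizer = optax.adamw(
    learning_rate=HPARAMS["learning_rate"],
    b1=1.0 - HPARAMS["one_minus_beta1"],  # beta1 is stored as 1 - beta1.
    b2=HPARAMS["beta2"],
    weight_decay=HPARAMS["weight_decay"])
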
@@ -175,8 +176,8 @@ def init_optimizer_state(workload: spec.Workload,
   del rng
   del hyperparameters
 
-  hyperparameters=HPARAMS
+  hyperparameters = HPARAMS
 
   def jax_cosine_warmup(step_hint: int, hyperparameters):
     # Create learning rate schedule.
     warmup_steps = int(hyperparameters.warmup_factor * step_hint)
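
The hunk above ends just after warmup_steps is derived from warmup_factor and the step hint; the rest of jax_cosine_warmup is collapsed in the diff. Assuming it follows the usual optax pattern of a linear warmup joined to a cosine decay, a self-contained sketch looks like this (the names and exact decay length are assumptions, not lines from the repository):

import optax

def cosine_warmup_schedule(step_hint: int, learning_rate: float,
                           warmup_factor: float):
  # Linear warmup from 0 to the peak learning rate, then cosine decay.
  warmup_steps = int(warmup_factor * step_hint)
  warmup_fn = optax.linear_schedule(
      init_value=0., end_value=learning_rate, transition_steps=warmup_steps)
  cosine_fn = optax.cosine_decay_schedule(
      init_value=learning_rate, decay_steps=max(step_hint - warmup_steps, 1))
  return optax.join_schedules(
      schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
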

File 2 of 4

@@ -27,13 +27,14 @@
 _GRAD_CLIP_EPS = 1e-6
 
 HPARAMS = {
-"dropout_rate": 0.1,
-"learning_rate": 0.0017486387539278373,
-"one_minus_beta1": 0.06733926164,
-"beta2": 0.9955159689799007,
-"weight_decay": 0.08121616522670176,
-"warmup_factor": 0.02
-}
+    "dropout_rate": 0.1,
+    "learning_rate": 0.0017486387539278373,
+    "one_minus_beta1": 0.06733926164,
+    "beta2": 0.9955159689799007,
+    "weight_decay": 0.08121616522670176,
+    "warmup_factor": 0.02
+}
+
 
 # Forked from
 # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
@@ -175,8 +176,8 @@ def init_optimizer_state(workload: spec.Workload,
   del rng
   del hyperparameters
 
-  hyperparameters=HPARAMS
+  hyperparameters = HPARAMS
 
   def jax_cosine_warmup(step_hint: int, hyperparameters):
     # Create learning rate schedule.
     warmup_steps = int(hyperparameters.warmup_factor * step_hint)
@@ -192,7 +193,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
     return schedule_fn
 
   # Create optimizer + LR schedule.
-  lr_schedule_fn = jax_cosine_warmup(workload.step_hint*0.75, hyperparameters)
+  lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters)
   opt_init_fn, opt_update_fn = nadamw(
       learning_rate=lr_schedule_fn,
       b1=1.0 - hyperparameters.one_minus_beta1,
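
The only change in this hunk is spacing around the multiplication: the schedule is still built over 0.75 * workload.step_hint, and nadamw receives b1 = 1.0 - one_minus_beta1. A small worked example with a hypothetical step hint (not a value from the repository) makes the arithmetic concrete:

# Hypothetical step hint, chosen only to illustrate the scaling.
step_hint = 100_000
schedule_steps = step_hint * 0.75                  # 75_000.0 steps of schedule
warmup_steps = int(0.02 * schedule_steps)          # 1_500 linear-warmup steps
cosine_steps = int(schedule_steps) - warmup_steps  # 73_500 cosine-decay steps
b1 = 1.0 - 0.06733926164                           # ~0.9327 as Adam beta1
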

File 3 of 4

@@ -16,14 +16,15 @@

 USE_PYTORCH_DDP = pytorch_setup()[0]
 
-HPARAMS = {
-"dropout_rate": 0.1,
-"learning_rate": 0.0017486387539278373,
-"one_minus_beta1": 0.06733926164,
-"beta2": 0.9955159689799007,
-"weight_decay": 0.08121616522670176,
-"warmup_factor": 0.02
-}
+HPARAMS = {
+    "dropout_rate": 0.1,
+    "learning_rate": 0.0017486387539278373,
+    "one_minus_beta1": 0.06733926164,
+    "beta2": 0.9955159689799007,
+    "weight_decay": 0.08121616522670176,
+    "warmup_factor": 0.02
+}
+
 
 # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.
 class NAdamW(torch.optim.Optimizer):
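
The PyTorch files pin the same HPARAMS dict and feed it to the NAdamW class defined below (a modified copy of torch.optim.AdamW). As a hedged sketch of that wiring, the snippet below uses the stock torch.optim.AdamW in place of the custom class; only the hyperparameter names and values come from the diff.

import torch

def make_optimizer(model: torch.nn.Module, hparams: dict) -> torch.optim.Optimizer:
  # Sketch only: torch.optim.AdamW stands in for the file's custom NAdamW.
  return torch.optim.AdamW(
      model.parameters(),
      lr=hparams["learning_rate"],
      betas=(1.0 - hparams["one_minus_beta1"], hparams["beta2"]),
      weight_decay=hparams["weight_decay"])
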
@@ -253,7 +254,7 @@ def update_params(workload: spec.Workload,
   del hyperparameters
 
   hyperparameters = HPARAMS
-
+
   current_model = current_param_container
   current_model.train()
   optimizer_state['optimizer'].zero_grad()

File 4 of 4

@@ -16,14 +16,15 @@

 USE_PYTORCH_DDP = pytorch_setup()[0]
 
-HPARAMS = {
-"dropout_rate": 0.1,
-"learning_rate": 0.0017486387539278373,
-"one_minus_beta1": 0.06733926164,
-"beta2": 0.9955159689799007,
-"weight_decay": 0.08121616522670176,
-"warmup_factor": 0.02
-}
+HPARAMS = {
+    "dropout_rate": 0.1,
+    "learning_rate": 0.0017486387539278373,
+    "one_minus_beta1": 0.06733926164,
+    "beta2": 0.9955159689799007,
+    "weight_decay": 0.08121616522670176,
+    "warmup_factor": 0.02
+}
+
 
 # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py.
 class NAdamW(torch.optim.Optimizer):
@@ -230,7 +231,7 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
         optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
 
   optimizer_state['scheduler'] = pytorch_cosine_warmup(
-      workload.step_hint*0.75, hyperparameters, optimizer_state['optimizer'])
+      workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer'])
 
   return optimizer_state

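
The call above, with schedulers=[warmup, cosine_decay] and milestones=[warmup_steps], matches torch.optim.lr_scheduler.SequentialLR. The warmup scheduler itself is not visible in this hunk, so the LinearLR used below is an assumption; the sketch only illustrates how such a warmup-then-cosine schedule is commonly chained:

from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

def cosine_warmup_scheduler(step_hint: int, warmup_factor: float, optimizer):
  # Ramp the LR from ~0 up to its base value, then decay it with a cosine.
  warmup_steps = int(warmup_factor * step_hint)
  warmup = LinearLR(
      optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps)
  cosine_decay = CosineAnnealingLR(optimizer, T_max=step_hint - warmup_steps)
  return SequentialLR(
      optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
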
@@ -253,7 +254,7 @@ def update_params(workload: spec.Workload,
   del hyperparameters
 
   hyperparameters = HPARAMS
-
+
   current_model = current_param_container
   current_model.train()
   optimizer_state['optimizer'].zero_grad()
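
update_params is cut off right after zero_grad(); the body that follows is collapsed in the diff. For orientation only, here is a generic PyTorch update step of the kind such a function typically performs; apart from model.train() and zero_grad(), none of these lines are taken from the repository.

import torch

def train_step(model, batch, loss_fn, optimizer, scheduler, grad_clip=None):
  # Generic sketch: forward, backward, optional clipping, optimizer/scheduler step.
  model.train()
  optimizer.zero_grad()
  loss = loss_fn(model(batch["inputs"]), batch["targets"])
  loss.backward()
  if grad_clip is not None:
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
  optimizer.step()
  scheduler.step()
  return loss.detach()
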

0 comments on commit 5f83404
