Showing 3 changed files with 156 additions and 6 deletions.
@@ -0,0 +1,87 @@
"""Jax submission for the target-setting run on WMT with AdamW."""

import functools
from typing import Dict, List, Tuple

import jax
import jax.numpy as jnp
import optax

from algorithmic_efficiency import spec
from target_setting_runs.data_selection import \
    data_selection  # pylint: disable=unused-import
from target_setting_runs.jax_adamw import \
    init_optimizer_state  # pylint: disable=unused-import


def get_batch_size(workload_name):
  # Return the global batch size.
  del workload_name
  return 256


@functools.partial(
    jax.pmap,
    in_axes=(None, None, 0, 0, 0, 0),
    axis_name='batch',
    static_broadcasted_argnums=(0, 1))
def pmapped_train_step(workload,
                       opt_update_fn,
                       optimizer_state,
                       current_param_container,
                       batch,
                       dropout_rng):
  """Perform a single training step."""

  def _loss_fn(params):
    """Loss function used for training."""
    logits, _ = workload.model_fn(
        params,
        batch,
        model_state=None,
        mode=spec.ForwardPassMode.TRAIN,
        rng=dropout_rng,
        update_batch_norm=False)
    targets = batch['targets']
    # Mask out padding tokens (target id 0) when averaging the loss.
    weights = jnp.where(targets > 0, 1.0, 0.0)
    loss = (workload.loss_fn(targets, logits) * weights).sum() / weights.sum()
    return loss

  grad_fn = jax.value_and_grad(_loss_fn)
  _, grad = grad_fn(current_param_container)
  # Average gradients across devices before applying the update.
  grad = jax.lax.pmean(grad, axis_name='batch')
  updates, new_optimizer_state = opt_update_fn(
      grad, optimizer_state, current_param_container)
  updated_params = optax.apply_updates(current_param_container, updates)
  return new_optimizer_state, updated_params


def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del eval_results
  del global_step
  del model_state
  del hyperparameters
  del loss_type

  optimizer_state, opt_update_fn = optimizer_state
  # One dropout RNG per local device for the pmapped step.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  new_optimizer_state, updated_params = pmapped_train_step(
      workload,
      opt_update_fn,
      optimizer_state,
      current_param_container,
      batch,
      dropout_rngs)
  return (new_optimizer_state, opt_update_fn), updated_params, None
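
The AdamW initializer imported above from target_setting_runs.jax_adamw is not part of this diff. As a rough sketch of what it plausibly looks like, using only the tuned hyperparameters in the JSON file below, the unpacking done by update_params, and the pmap in_axes (the warmup-plus-cosine schedule shape, the replication detail, and the exact signature are assumptions, not confirmed by this commit):

"""Hypothetical sketch of target_setting_runs/jax_adamw.py (not in this diff)."""

import optax
from flax import jax_utils

from algorithmic_efficiency import spec


def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  """AdamW with linear warmup then cosine decay (schedule shape assumed)."""
  del workload, model_state, rng
  lr_schedule = optax.warmup_cosine_decay_schedule(
      init_value=0.0,
      peak_value=hyperparameters.learning_rate,
      warmup_steps=hyperparameters.warmup_steps,
      decay_steps=hyperparameters.num_steps)
  opt_init_fn, opt_update_fn = optax.adamw(
      learning_rate=lr_schedule,
      b1=hyperparameters.beta1,
      b2=hyperparameters.beta2,
      weight_decay=hyperparameters.l2)
  optimizer_state = opt_init_fn(model_params)
  # Replicate the state across devices to match in_axes=0 in the pmap above;
  # update_params then unpacks this tuple into (optimizer_state, opt_update_fn).
  return jax_utils.replicate(optimizer_state), opt_update_fn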
@@ -0,0 +1,58 @@
"""PyTorch submission for the target-setting run on WMT with AdamW."""

from typing import Dict, List, Tuple

import torch

from algorithmic_efficiency import spec
from target_setting_runs.data_selection import \
    data_selection  # pylint: disable=unused-import
from target_setting_runs.pytorch_adamw import \
    init_optimizer_state  # pylint: disable=unused-import


def get_batch_size(workload_name):
  # Return the global batch size.
  del workload_name
  return 256


def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
  """Return (updated_optimizer_state, updated_params, updated_model_state)."""
  del current_params_types
  del eval_results
  del loss_type
  del hyperparameters
  del global_step

  current_model = current_param_container
  current_model.train()
  optimizer_state['optimizer'].zero_grad()

  logits, _ = workload.model_fn(
      params=current_model,
      augmented_and_preprocessed_input_batch=batch,
      model_state=model_state,
      mode=spec.ForwardPassMode.TRAIN,
      rng=rng,
      update_batch_norm=False)

  targets = batch['targets']
  # Mask out padding tokens (target id 0) when averaging the loss.
  weights = torch.where(targets > 0, 1.0, 0.0)
  loss = (workload.loss_fn(targets, logits) * weights).sum() / weights.sum()
  loss.backward()

  optimizer_state['optimizer'].step()
  optimizer_state['scheduler'].step()

  return (optimizer_state, current_param_container, None)
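
Likewise, target_setting_runs.pytorch_adamw is imported but not shown here. A minimal sketch of what it might provide, consistent with the 'optimizer'/'scheduler' keys used above and the tuned hyperparameters below (the warmup/cosine schedule shape and the signature are assumptions):

"""Hypothetical sketch of target_setting_runs/pytorch_adamw.py (not in this diff)."""

import math

import torch
from torch.optim.lr_scheduler import LambdaLR

from algorithmic_efficiency import spec


def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  """AdamW plus a per-step LR schedule (schedule shape assumed)."""
  del workload, model_state, rng
  optimizer = torch.optim.AdamW(
      model_params.parameters(),
      lr=hyperparameters.learning_rate,
      betas=(hyperparameters.beta1, hyperparameters.beta2),
      weight_decay=hyperparameters.l2)

  def lr_lambda(step):
    # Linear warmup for warmup_steps, then cosine decay to zero at num_steps.
    if step < hyperparameters.warmup_steps:
      return step / max(1, hyperparameters.warmup_steps)
    progress = (step - hyperparameters.warmup_steps) / max(
        1, hyperparameters.num_steps - hyperparameters.warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))

  # update_params above calls optimizer_state['scheduler'].step() once per
  # training step, so the scheduler advances per step, not per epoch.
  return {
      'optimizer': optimizer,
      'scheduler': LambdaLR(optimizer, lr_lambda),
  }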
@@ -1,27 +1,32 @@
 {
   "learning_rate": {
     "feasible_points": [
-      0.000487
+      0.000844
     ]
   },
   "beta1": {
     "feasible_points": [
-      0.8194
+      0.8895
     ]
   },
   "beta2": {
     "feasible_points": [
-      0.9803
+      0.9978
     ]
   },
-  "warmup_epochs": {
+  "warmup_steps": {
     "feasible_points": [
-      0.02
+      1000
     ]
   },
+  "num_steps": {
+    "feasible_points": [
+      50000
+    ]
+  },
   "l2": {
     "feasible_points": [
-      0.407336
+      0.081354
     ]
   }
 }
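
Each hyperparameter in this tuning file lists exactly one feasible point, so a harness sampling from it always yields the same configuration. A hypothetical illustration of consuming the file (the filename and loader are assumptions, not part of this commit):

import json

# Filename assumed for illustration; this commit only shows the file contents.
with open('tuning_search_space.json') as f:
  search_space = json.load(f)

# One feasible point per hyperparameter, so the "search" collapses to:
config = {name: points['feasible_points'][0]
          for name, points in search_space.items()}
# {'learning_rate': 0.000844, 'beta1': 0.8895, 'beta2': 0.9978,
#  'warmup_steps': 1000, 'num_steps': 50000, 'l2': 0.081354}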