Add WMT target setting runs
runame committed Aug 23, 2022
1 parent 2643718 commit f70243f
Showing 3 changed files with 156 additions and 6 deletions.
87 changes: 87 additions & 0 deletions target_setting_runs/wmt/jax_submission.py
@@ -0,0 +1,87 @@
"""Jax submission for the target-setting run on WMT with AdamW."""

import functools
from typing import Dict, List, Tuple

import jax
import jax.numpy as jnp
import optax

from algorithmic_efficiency import spec
from target_setting_runs.data_selection import \
data_selection # pylint: disable=unused-import
from target_setting_runs.jax_adamw import \
init_optimizer_state # pylint: disable=unused-import


def get_batch_size(workload_name):
# Return the global batch size.
del workload_name
return 256


@functools.partial(
jax.pmap,
in_axes=(None, None, 0, 0, 0, 0),
axis_name='batch',
static_broadcasted_argnums=(0, 1))
def pmapped_train_step(workload,
opt_update_fn,
optimizer_state,
current_param_container,
batch,
dropout_rng):
"""Perform a single training step."""

def _loss_fn(params):
"""Loss function used for training."""
logits, _ = workload.model_fn(
params,
batch,
model_state=None,
mode=spec.ForwardPassMode.TRAIN,
rng=dropout_rng,
update_batch_norm=False)
targets = batch['targets']
weights = jnp.where(targets > 0, 1.0, 0.0)
loss = (workload.loss_fn(targets, logits) * weights).sum() / weights.sum()
return loss

grad_fn = jax.value_and_grad(_loss_fn)
_, grad = grad_fn(current_param_container)
grad = jax.lax.pmean(grad, axis_name='batch')
updates, new_optimizer_state = opt_update_fn(
grad, optimizer_state, current_param_container)
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del eval_results
del global_step
del model_state
del hyperparameters
del loss_type

optimizer_state, opt_update_fn = optimizer_state
dropout_rngs = jax.random.split(rng, jax.local_device_count())
new_optimizer_state, updated_params = pmapped_train_step(
workload,
opt_update_fn,
optimizer_state,
current_param_container,
batch,
dropout_rngs)
return (new_optimizer_state, opt_update_fn), updated_params, None
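
update_params unpacks the incoming state as (optimizer_state, opt_update_fn), so the init_optimizer_state imported from target_setting_runs.jax_adamw has to produce that pair, with the optimizer state replicated per device because pmapped_train_step maps it over axis 0. A minimal sketch of such a factory, assuming an optax AdamW with a linear-warmup plus cosine-decay schedule driven by the tuning points at the end of this commit; the schedule choice, signature, and hyperparameter names are assumptions, not part of this diff:

# Hypothetical sketch of the init_optimizer_state contract assumed above;
# the real target_setting_runs/jax_adamw.py may differ.
import optax
from flax import jax_utils


def init_optimizer_state(workload, model_params, model_state, hyperparameters,
                         rng):
  """Build AdamW with warmup + cosine decay; return (state, opt_update_fn)."""
  del workload, model_state, rng
  lr_schedule = optax.warmup_cosine_decay_schedule(
      init_value=0.0,
      peak_value=hyperparameters.learning_rate,
      warmup_steps=hyperparameters.warmup_steps,
      decay_steps=hyperparameters.num_steps)
  opt_init_fn, opt_update_fn = optax.adamw(
      learning_rate=lr_schedule,
      b1=hyperparameters.beta1,
      b2=hyperparameters.beta2,
      weight_decay=hyperparameters.l2)  # assuming "l2" maps to weight decay
  optimizer_state = opt_init_fn(model_params)
  # pmapped_train_step maps optimizer_state over axis 0, so replicate it
  # across local devices.
  return jax_utils.replicate(optimizer_state), opt_update_fn

Since pmapped_train_step averages gradients across devices with jax.lax.pmean before opt_update_fn is applied, the replicated optimizer states stay in sync.
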
58 changes: 58 additions & 0 deletions target_setting_runs/wmt/pytorch_submission.py
@@ -0,0 +1,58 @@
"""PyTorch submission for the target-setting run on WMT with AdamW."""

from typing import Dict, List, Tuple

import torch

from algorithmic_efficiency import spec
from target_setting_runs.data_selection import \
data_selection # pylint: disable=unused-import
from target_setting_runs.pytorch_adamw import \
init_optimizer_state # pylint: disable=unused-import


def get_batch_size(workload_name):
# Return the global batch size.
del workload_name
return 256


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del eval_results
del loss_type
del hyperparameters
del global_step

current_model = current_param_container
current_model.train()
optimizer_state['optimizer'].zero_grad()

logits, _ = workload.model_fn(
params=current_model,
augmented_and_preprocessed_input_batch=batch,
model_state=model_state,
mode=spec.ForwardPassMode.TRAIN,
rng=rng,
update_batch_norm=False)

targets = batch['targets']
weights = torch.where(targets > 0, 1.0, 0.0)
loss = (workload.loss_fn(targets, logits) * weights).sum() / weights.sum()
loss.backward()

optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()

return (optimizer_state, current_param_container, None)
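
Here update_params indexes optimizer_state['optimizer'] and optimizer_state['scheduler'], so the init_optimizer_state imported from target_setting_runs.pytorch_adamw presumably returns a dict holding a torch.optim.AdamW instance and a per-step LR scheduler. A minimal sketch under that assumption; the warmup-then-cosine schedule and the hyperparameter names are illustrative, not taken from this diff:

# Hypothetical sketch of the {'optimizer', 'scheduler'} state consumed above;
# the real target_setting_runs/pytorch_adamw.py may differ.
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR


def init_optimizer_state(workload, model_params, model_state, hyperparameters,
                         rng):
  """Build AdamW plus a warmup/cosine scheduler keyed as a dict."""
  del workload, model_state, rng
  optimizer = torch.optim.AdamW(
      model_params.parameters(),
      lr=hyperparameters.learning_rate,
      betas=(hyperparameters.beta1, hyperparameters.beta2),
      weight_decay=hyperparameters.l2)  # assuming "l2" maps to weight decay
  # Linear warmup for warmup_steps, then cosine decay over the remaining
  # steps; the scheduler is stepped once per update, as in update_params.
  warmup = LinearLR(
      optimizer, start_factor=1e-3, end_factor=1.0,
      total_iters=hyperparameters.warmup_steps)
  cosine = CosineAnnealingLR(
      optimizer,
      T_max=hyperparameters.num_steps - hyperparameters.warmup_steps)
  scheduler = SequentialLR(
      optimizer, schedulers=[warmup, cosine],
      milestones=[hyperparameters.warmup_steps])
  return {'optimizer': optimizer, 'scheduler': scheduler}
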
17 changes: 11 additions & 6 deletions target_setting_runs/wmt/tuning_search_space.json
@@ -1,27 +1,32 @@
 {
   "learning_rate": {
     "feasible_points": [
-      0.000487
+      0.000844
     ]
   },
   "beta1": {
     "feasible_points": [
-      0.8194
+      0.8895
     ]
   },
   "beta2": {
     "feasible_points": [
-      0.9803
+      0.9978
     ]
   },
-  "warmup_epochs": {
+  "warmup_steps": {
     "feasible_points": [
-      0.02
+      1000
     ]
   },
+  "num_steps": {
+    "feasible_points": [
+      50000
+    ]
+  },
   "l2": {
     "feasible_points": [
-      0.407336
+      0.081354
     ]
   }
 }
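
Each hyperparameter in the updated search space is pinned to a single feasible point, so the "space" fixes one AdamW configuration, and warmup is now expressed in absolute steps (1000 of the 50000 total) rather than as a fraction of epochs. A small illustrative loader, assuming the submissions receive these points as attributes on the hyperparameters object (the file path and namedtuple are only for illustration):

# Illustrative: collapse the single feasible point of each hyperparameter into
# an attribute-style object like the one the submissions above expect.
import collections
import json

with open('target_setting_runs/wmt/tuning_search_space.json') as f:
  search_space = json.load(f)

Hyperparameters = collections.namedtuple('Hyperparameters', search_space.keys())
hyperparameters = Hyperparameters(
    **{name: cfg['feasible_points'][0] for name, cfg in search_space.items()})

print(hyperparameters.learning_rate)  # 0.000844
print(hyperparameters.warmup_steps)   # 1000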
