diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 83823a6b9..d5aa59c95 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -20,6 +20,7 @@ jobs: pylint algorithmic_efficiency pylint baselines pylint target_setting_runs + pylint reference_submissions pylint submission_runner.py pylint tests diff --git a/reference_submissions/__init__.py b/reference_submissions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/cifar/__init__.py b/reference_submissions/cifar/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/cifar/cifar_jax/__init__.py b/reference_submissions/cifar/cifar_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/cifar/cifar_jax/submission.py b/reference_submissions/cifar/cifar_jax/submission.py new file mode 100644 index 000000000..f45566092 --- /dev/null +++ b/reference_submissions/cifar/cifar_jax/submission.py @@ -0,0 +1,151 @@ +"""Training algorithm track submission functions for CIFAR10.""" + +import functools +from typing import Dict, Iterator, List, Tuple + +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + + +def get_batch_size(workload_name): + # Return the global batch size. + del workload_name + return 128 + + +def cosine_decay(lr, step, total_steps): + ratio = jnp.maximum(0., step / total_steps) + mult = 0.5 * (1. + jnp.cos(jnp.pi * ratio)) + return mult * lr + + +def create_learning_rate_fn(hparams: spec.Hyperparameters, + steps_per_epoch: int): + """Create learning rate schedule.""" + base_learning_rate = hparams.learning_rate * get_batch_size('cifar') / 256. 
+ warmup_fn = optax.linear_schedule( + init_value=0., + end_value=base_learning_rate, + transition_steps=hparams.warmup_epochs * steps_per_epoch) + cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=base_learning_rate, + decay_steps=cosine_epochs * steps_per_epoch) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], + boundaries=[hparams.warmup_epochs * steps_per_epoch]) + return schedule_fn + + +def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): + steps_per_epoch = num_train_examples // get_batch_size('cifar') + learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) + opt_init_fn, opt_update_fn = optax.sgd( + nesterov=True, + momentum=hyperparameters.momentum, + learning_rate=learning_rate_fn) + return opt_init_fn, opt_update_fn + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + del model_params + del model_state + del rng + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + opt_init_fn, opt_update_fn = optimizer(hyperparameters, + workload.num_train_examples) + optimizer_state = opt_init_fn(params_zeros_like) + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, None, 0, 0), + static_broadcasted_argnums=(0, 1)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + hyperparameters, + batch, + rng): + + def _loss_fn(params): + """loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss = 
jnp.mean(workload.loss_fn(batch['targets'], logits)) + weight_penalty_params = jax.tree_leaves(params) + weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) + weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 + loss = loss + weight_penalty + return loss, new_model_state + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (_, new_model_state), grad = grad_fn(current_param_container) + grad = lax.pmean(grad, axis_name='batch') + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del global_step + del eval_results + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + new_optimizer_state, new_params, new_model_state = pmapped_train_step( + workload, opt_update_fn, model_state, optimizer_state, + current_param_container, hyperparameters, batch, per_device_rngs) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + 
"""Select data from the infinitely repeating, pre-shuffled input queue. + + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del hyperparameters + del global_step + del rng + return next(input_queue) diff --git a/reference_submissions/cifar/cifar_pytorch/submission.py b/reference_submissions/cifar/cifar_pytorch/submission.py new file mode 100644 index 000000000..096d89a37 --- /dev/null +++ b/reference_submissions/cifar/cifar_pytorch/submission.py @@ -0,0 +1,118 @@ +"""Training algorithm track submission functions for CIFAR10.""" +from typing import Dict, Iterator, List, Tuple + +import torch +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec + + +def get_batch_size(workload_name): + # Return the global batch size. + batch_sizes = {'cifar': 128} + return batch_sizes[workload_name] + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + del workload + del model_state + del rng + + base_lr = hyperparameters.learning_rate * get_batch_size('cifar') / 256. 
+ optimizer_state = { + 'optimizer': + torch.optim.SGD( + model_params.parameters(), + lr=base_lr, + momentum=hyperparameters.momentum, + weight_decay=hyperparameters.l2) + } + + scheduler1 = LinearLR( + optimizer_state['optimizer'], + start_factor=1e-5, + end_factor=1., + total_iters=hyperparameters.warmup_epochs) + cosine_epochs = max( + hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1) + scheduler2 = CosineAnnealingLR( + optimizer_state['optimizer'], T_max=cosine_epochs) + + optimizer_state['scheduler'] = SequentialLR( + optimizer_state['optimizer'], + schedulers=[scheduler1, scheduler2], + milestones=[hyperparameters.warmup_epochs]) + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + # This will define the output activation via `output_activation_fn`. 
+ loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del current_params_types + del hyperparameters + del loss_type + del eval_results + + current_model = current_param_container + current_param_container.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + loss = workload.loss_fn( + label_batch=batch['targets'], logits_batch=logits_batch).mean() + + loss.backward() + optimizer_state['optimizer'].step() + + steps_per_epoch = workload.num_train_examples // get_batch_size('cifar') + if (global_step + 1) % steps_per_epoch == 0: + optimizer_state['scheduler'].step() + + return (optimizer_state, current_param_container, new_model_state) + + +# Not allowed to update the model parameters, hyperparameters, global step, or +# optimizer state. +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del hyperparameters + del global_step + del rng + return next(input_queue) diff --git a/reference_submissions/cifar/tuning_search_space.json b/reference_submissions/cifar/tuning_search_space.json new file mode 100644 index 000000000..283341705 --- /dev/null +++ b/reference_submissions/cifar/tuning_search_space.json @@ -0,0 +1,7 @@ +{ + "learning_rate": {"feasible_points": [0.1]}, + "warmup_epochs": {"feasible_points": [5]}, + "num_epochs": {"feasible_points": [200]}, + "l2": {"feasible_points": [5e-4]}, + "momentum": {"feasible_points": [0.9]} +} diff --git a/reference_submissions/criteo1tb/criteo1tb_jax/__init__.py b/reference_submissions/criteo1tb/criteo1tb_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/criteo1tb/criteo1tb_jax/submission.py b/reference_submissions/criteo1tb/criteo1tb_jax/submission.py new file mode 100644 index 000000000..ea65bd4f6 --- /dev/null +++ b/reference_submissions/criteo1tb/criteo1tb_jax/submission.py @@ -0,0 +1,131 @@ +"""Training algorithm track submission functions for Criteo1TB DLRM-Small.""" + +import functools +from typing import Dict, Iterator, List, Tuple + +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + + +def get_batch_size(workload_name): + # Return the global batch size. 
+ del workload_name + return 131072 + + +def create_learning_rate_fn(workload: spec.Workload, + hparams: spec.Hyperparameters): + """Create learning rate schedule.""" + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hparams.learning_rate, + transition_steps=hparams.warmup_steps) + cosine_fn = optax.cosine_decay_schedule( + init_value=hparams.learning_rate, + decay_steps=(workload.step_hint - hparams.warmup_steps)) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[hparams.warmup_steps]) + return schedule_fn + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + del model_params + del model_state + del rng + learning_rate_fn = create_learning_rate_fn(workload, hyperparameters) + opt_init_fn, opt_update_fn = optax.adamw( + b1=hyperparameters.beta1, + learning_rate=learning_rate_fn, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0), + static_broadcasted_argnums=(0, 1)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng): + + def _loss_fn(params): + """loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss = jnp.mean(workload.loss_fn(batch['targets'], logits)) + return loss, (new_model_state, logits) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (_, (new_model_state, _)), grad = grad_fn(current_param_container)  # has_aux=True returns ((loss, (new_model_state, logits)), grad) + grad = lax.pmean(grad, 
axis_name='batch') + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_model_state, new_optimizer_state, updated_params + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del global_step + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + new_model_state, new_optimizer_state, new_params = pmapped_train_step( + workload, opt_update_fn, model_state, optimizer_state, + current_param_container, batch, per_device_rngs) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del hyperparameters + del global_step + del rng + return next(input_queue) diff --git a/reference_submissions/criteo1tb/tuning_search_space.json b/reference_submissions/criteo1tb/tuning_search_space.json new file mode 100644 index 000000000..1da5349f5 --- /dev/null +++ b/reference_submissions/criteo1tb/tuning_search_space.json @@ -0,0 +1,6 @@ +{ + "learning_rate": {"feasible_points": [7e-4]}, + "warmup_steps": {"feasible_points": [3200]}, + "weight_decay": {"feasible_points": [1e-4]}, + "beta1": {"feasible_points": [0.9]} +} diff --git a/reference_submissions/fastmri/__init__.py b/reference_submissions/fastmri/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/fastmri/fastmri_pytorch/__init__.py b/reference_submissions/fastmri/fastmri_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/fastmri/fastmri_pytorch/submission.py b/reference_submissions/fastmri/fastmri_pytorch/submission.py new file mode 100644 index 000000000..d3562621d --- /dev/null +++ b/reference_submissions/fastmri/fastmri_pytorch/submission.py @@ -0,0 +1,104 @@ +"""Training algorithm track submission functions for FastMRI.""" + +from typing import Dict, Iterator, List, Tuple + +import torch +from torch.optim.lr_scheduler import StepLR + +from algorithmic_efficiency import spec + + +def get_batch_size(workload_name): + # Return the global batch size. 
+ batch_sizes = {'fastmri': 8} + return batch_sizes[workload_name] + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + del workload + del model_state + del rng + + base_lr = hyperparameters.learning_rate * get_batch_size('fastmri') + optimizer_state = { + 'optimizer': + torch.optim.RMSprop( + model_params.parameters(), + lr=base_lr, + weight_decay=hyperparameters.l2) + } + + optimizer_state['scheduler'] = StepLR( + optimizer_state['optimizer'], + step_size=hyperparameters.lr_step_size, + gamma=hyperparameters.lr_gamma) + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + # This will define the output activation via `output_activation_fn`. 
+ loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del current_params_types + del hyperparameters + del loss_type + del eval_results + + current_model = current_param_container + current_param_container.train() + optimizer_state['optimizer'].zero_grad() + + outputs_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + loss = workload.loss_fn( + targets_batch=batch['targets'], outputs_batch=outputs_batch).mean() + + loss.backward() + optimizer_state['optimizer'].step() + steps_per_epoch = workload.num_train_examples // get_batch_size('fastmri') + if (global_step + 1) % steps_per_epoch == 0: + optimizer_state['scheduler'].step() + + return (optimizer_state, current_param_container, new_model_state) + + +# Not allowed to update the model parameters, hyperparameters, global step, or +# optimizer state. +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del hyperparameters + del global_step + del rng + return next(input_queue) diff --git a/reference_submissions/fastmri/tuning_search_space.json b/reference_submissions/fastmri/tuning_search_space.json new file mode 100644 index 000000000..01e4e00c2 --- /dev/null +++ b/reference_submissions/fastmri/tuning_search_space.json @@ -0,0 +1,7 @@ +{ + "learning_rate": {"feasible_points": [0.001]}, + "num_epochs": {"feasible_points": [50]}, + "l2": {"feasible_points": [0.0]}, + "lr_step_size": {"feasible_points": [40]}, + "lr_gamma": {"feasible_points": [0.1]} +} \ No newline at end of file diff --git a/reference_submissions/imagenet_resnet/__init__.py b/reference_submissions/imagenet_resnet/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/imagenet_resnet/imagenet_jax/__init__.py b/reference_submissions/imagenet_resnet/imagenet_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/imagenet_resnet/imagenet_jax/submission.py b/reference_submissions/imagenet_resnet/imagenet_jax/submission.py new file mode 100644 index 000000000..7f805bdde --- /dev/null +++ b/reference_submissions/imagenet_resnet/imagenet_jax/submission.py @@ -0,0 +1,150 @@ +"""Training algorithm track submission functions for ImageNet.""" + +import functools +from typing import Dict, Iterator, List, Tuple + +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + + +def get_batch_size(workload_name): + # Return the global batch size. + del workload_name + return 512 + + +def create_learning_rate_fn(hparams: spec.Hyperparameters, + steps_per_epoch: int): + """Create learning rate schedule.""" + base_learning_rate = hparams.learning_rate * \ + get_batch_size('imagenet_resnet') / 256. 
+ warmup_fn = optax.linear_schedule( + init_value=0., + end_value=base_learning_rate, + transition_steps=hparams.warmup_epochs * steps_per_epoch) + cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=base_learning_rate, + decay_steps=cosine_epochs * steps_per_epoch) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], + boundaries=[hparams.warmup_epochs * steps_per_epoch]) + return schedule_fn + + +def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): + steps_per_epoch = num_train_examples // get_batch_size('imagenet_resnet') + learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) + opt_init_fn, opt_update_fn = optax.sgd( + nesterov=True, + momentum=hyperparameters.momentum, + learning_rate=learning_rate_fn) + return opt_init_fn, opt_update_fn + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + del model_params + del model_state + del rng + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + opt_init_fn, opt_update_fn = optimizer(hyperparameters, + workload.num_train_examples) + optimizer_state = opt_init_fn(params_zeros_like) + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, None, 0, 0), + static_broadcasted_argnums=(0, 1)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + hyperparameters, + batch, + rng): + + def _loss_fn(params): + """loss function used for training.""" + variables = {'params': params, **model_state} + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + 
update_batch_norm=True) + loss = jnp.mean(workload.loss_fn(batch['targets'], logits)) + weight_penalty_params = jax.tree_leaves(variables['params']) + weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) + weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 + loss = loss + weight_penalty + return loss, (new_model_state, logits) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + aux, grad = grad_fn(current_param_container) + grad = lax.pmean(grad, axis_name='batch') + new_model_state, _ = aux[1] + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + + return new_model_state, new_optimizer_state, updated_params + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del global_step + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + new_model_state, new_optimizer_state, new_params = pmapped_train_step( + workload, opt_update_fn, model_state, optimizer_state, + current_param_container, hyperparameters, batch, per_device_rngs) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + hyperparameters: spec.Hyperparameters, + 
global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del hyperparameters + del global_step + del rng + return next(input_queue) diff --git a/reference_submissions/imagenet_resnet/imagenet_pytorch/__init__.py b/reference_submissions/imagenet_resnet/imagenet_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/imagenet_resnet/imagenet_pytorch/submission.py b/reference_submissions/imagenet_resnet/imagenet_pytorch/submission.py new file mode 100644 index 000000000..3f2eecba1 --- /dev/null +++ b/reference_submissions/imagenet_resnet/imagenet_pytorch/submission.py @@ -0,0 +1,118 @@ +"""Training algorithm track submission functions for ImageNet.""" +from typing import Dict, Iterator, List, Tuple + +import torch +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec + + +def get_batch_size(workload_name): + # Return the global batch size. + del workload_name + return 512 + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + del model_state + del rng + + batch_size = get_batch_size('imagenet_resnet') + base_lr = hyperparameters.learning_rate * batch_size / 256. 
+ optimizer_state = { + 'optimizer': + torch.optim.SGD( + model_params.parameters(), + lr=base_lr, + momentum=hyperparameters.momentum, + weight_decay=hyperparameters.l2, + nesterov=True) + } + + steps_per_epoch = workload.num_train_examples // batch_size + scheduler1 = LinearLR( + optimizer_state['optimizer'], + start_factor=1e-10, + end_factor=1., + total_iters=hyperparameters.warmup_epochs * steps_per_epoch) + cosine_epochs = max( + hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1) + scheduler2 = CosineAnnealingLR( + optimizer_state['optimizer'], T_max=cosine_epochs * steps_per_epoch) + + optimizer_state['scheduler'] = SequentialLR( + optimizer_state['optimizer'], + schedulers=[scheduler1, scheduler2], + milestones=[hyperparameters.warmup_epochs * steps_per_epoch]) + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + # This will define the output activation via `output_activation_fn`. 
+ loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del current_params_types + del hyperparameters + del loss_type + del eval_results + del global_step + + current_model = current_param_container + current_param_container.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + loss = workload.loss_fn( + label_batch=batch['targets'], logits_batch=logits_batch).mean() + + loss.backward() + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + return (optimizer_state, current_param_container, new_model_state) + + +# Not allowed to update the model parameters, hyperparameters, global step, or +# optimizer state. +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del hyperparameters + del global_step + del rng + return next(input_queue) diff --git a/reference_submissions/imagenet_resnet/tuning_search_space.json b/reference_submissions/imagenet_resnet/tuning_search_space.json new file mode 100644 index 000000000..da969416b --- /dev/null +++ b/reference_submissions/imagenet_resnet/tuning_search_space.json @@ -0,0 +1,7 @@ +{ + "learning_rate": {"feasible_points": [0.1]}, + "warmup_epochs": {"feasible_points": [5]}, + "num_epochs": {"feasible_points": [100]}, + "l2": {"feasible_points": [1e-4]}, + "momentum": {"feasible_points": [0.9]} +} \ No newline at end of file diff --git a/reference_submissions/imagenet_vit/__init__.py b/reference_submissions/imagenet_vit/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/imagenet_vit/imagenet_jax/__init__.py b/reference_submissions/imagenet_vit/imagenet_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/imagenet_vit/imagenet_jax/submission.py b/reference_submissions/imagenet_vit/imagenet_jax/submission.py new file mode 100644 index 000000000..3a71f7b51 --- /dev/null +++ b/reference_submissions/imagenet_vit/imagenet_jax/submission.py @@ -0,0 +1,150 @@ +"""Training algorithm track submission functions for ImageNet.""" + +import functools +from typing import Dict, Iterator, List, Tuple + +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + + +def get_batch_size(workload_name): + # Return the global batch size. + del workload_name + return 2048 + + +def create_learning_rate_fn(hparams: spec.Hyperparameters, + steps_per_epoch: int): + """Create learning rate schedule.""" + base_learning_rate = hparams.learning_rate * \ + get_batch_size('imagenet_vit') / 1024. 
+ warmup_fn = optax.linear_schedule( + init_value=0., + end_value=base_learning_rate, + transition_steps=hparams.warmup_epochs * steps_per_epoch) + cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=base_learning_rate, + decay_steps=cosine_epochs * steps_per_epoch) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], + boundaries=[hparams.warmup_epochs * steps_per_epoch]) + return schedule_fn + + +def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): + steps_per_epoch = num_train_examples // get_batch_size('imagenet_vit') + learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) + opt_init_fn, opt_update_fn = optax.adam( + b1=hyperparameters.beta1, + b2=hyperparameters.beta2, + eps=hyperparameters.epsilon, + learning_rate=learning_rate_fn) + return opt_init_fn, opt_update_fn + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + del model_params + del model_state + del rng + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + opt_init_fn, opt_update_fn = optimizer(hyperparameters, + workload.num_train_examples) + optimizer_state = opt_init_fn(params_zeros_like) + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, None, 0, 0), + static_broadcasted_argnums=(0, 1)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + hyperparameters, + batch, + rng): + + def _loss_fn(params): + """loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + 
def update_params(workload: spec.Workload,
                  current_param_container: spec.ParameterContainer,
                  current_params_types: spec.ParameterTypeTree,
                  model_state: spec.ModelAuxiliaryState,
                  hyperparameters: spec.Hyperparameters,
                  batch: Dict[str, spec.Tensor],
                  loss_type: spec.LossType,
                  optimizer_state: spec.OptimizerState,
                  eval_results: List[Tuple[int, float]],
                  global_step: int,
                  rng: spec.RandomState) -> spec.UpdateReturn:
  """Run one pmapped training step.

  Returns (updated_optimizer_state, updated_params, updated_model_state).
  """
  del current_params_types, loss_type, eval_results, global_step

  # The optimizer state is a pair: the replicated optax state plus the static
  # update function the pmapped step was traced with.
  opt_state, opt_update_fn = optimizer_state
  device_rngs = jax.random.split(rng, jax.local_device_count())
  new_model_state, new_opt_state, new_params = pmapped_train_step(
      workload,
      opt_update_fn,
      model_state,
      opt_state,
      current_param_container,
      hyperparameters,
      batch,
      device_rngs)
  return (new_opt_state, opt_update_fn), new_params, new_model_state
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  """Create an Adam optimizer with linear warmup followed by cosine decay.

  Returns a dict with the optimizer under 'optimizer' and the chained LR
  scheduler under 'scheduler'.
  """
  del model_state
  del rng

  global_batch_size = get_batch_size('imagenet_vit')
  # Linear batch-size scaling of the base LR against a reference of 1024.
  scaled_lr = hyperparameters.learning_rate * global_batch_size / 1024.
  adam = torch.optim.Adam(
      model_params.parameters(),
      lr=scaled_lr,
      betas=(hyperparameters.beta1, hyperparameters.beta2),
      eps=hyperparameters.epsilon)

  steps_per_epoch = workload.num_train_examples // global_batch_size
  warmup_iters = hyperparameters.warmup_epochs * steps_per_epoch
  # LinearLR requires start_factor > 0, hence the tiny epsilon-like factor.
  warmup = LinearLR(
      adam, start_factor=1e-10, end_factor=1., total_iters=warmup_iters)
  cosine_epochs = max(
      hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1)
  cosine = CosineAnnealingLR(adam, T_max=cosine_epochs * steps_per_epoch)

  return {
      'optimizer':
          adam,
      'scheduler':
          SequentialLR(
              adam, schedulers=[warmup, cosine], milestones=[warmup_iters]),
  }
def data_selection(workload: spec.Workload,
                   input_queue: Iterator[Dict[str, spec.Tensor]],
                   optimizer_state: spec.OptimizerState,
                   current_param_container: spec.ParameterContainer,
                   hyperparameters: spec.Hyperparameters,
                   global_step: int,
                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
  """Pull the next pre-shuffled batch off the infinite input queue.

  Every other argument exists only to satisfy the submission API; none of
  them influences which batch is returned.
  """
  del (workload, optimizer_state, current_param_container, hyperparameters,
       global_step, rng)
  return next(input_queue)
def update_params(
    workload: spec.Workload,
    current_param_container: spec.ParameterContainer,
    current_params_types: spec.ParameterTypeTree,
    model_state: spec.ModelAuxiliaryState,
    hyperparameters: spec.Hyperparameters,
    batch: Dict[str, spec.Tensor],
    # This will define the output activation via `output_activation_fn`.
    loss_type: spec.LossType,
    optimizer_state: spec.OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: spec.RandomState) -> spec.UpdateReturn:
  """Take one Adam step on the mean CTC loss over `batch`.

  Returns (updated_optimizer_state, updated_params, updated_model_state);
  this submission keeps no auxiliary model state, so the last item is None.
  """
  del (current_params_types, eval_results, global_step, model_state,
       loss_type, hyperparameters)

  optimizer_state.zero_grad()

  # Forward pass: the model emits log-probabilities together with the
  # per-example output lengths that the CTC loss needs.
  (log_probs, output_lengths), _ = workload.model_fn(
      current_param_container, batch, None,
      spec.ForwardPassMode.TRAIN, rng, False)

  mean_ctc_loss = workload.loss_fn(batch, (log_probs, output_lengths)).mean()
  mean_ctc_loss.backward()
  optimizer_state.step()

  return optimizer_state, current_param_container, None
def get_batch_size(workload_name):
  """Return the global batch size for the named workload."""
  return {'mnist': 1024}[workload_name]
def update_params(
    workload: spec.Workload,
    current_param_container: spec.ParameterContainer,
    current_params_types: spec.ParameterTypeTree,
    model_state: spec.ModelAuxiliaryState,
    hyperparameters: spec.Hyperparameters,
    batch: Dict[str, spec.Tensor],
    # This will define the output activation via `output_activation_fn`.
    loss_type: spec.LossType,
    optimizer_state: spec.OptimizerState,
    eval_results: List[Tuple[int, float]],
    global_step: int,
    rng: spec.RandomState) -> spec.UpdateReturn:
  """Run one pmapped update step.

  Returns (updated_optimizer_state, updated_params, updated_model_state).
  """
  del current_params_types, loss_type, eval_results, global_step

  # Split the optimizer state into the replicated optax state and the static
  # update function that pmapped_update_params was traced with.
  opt_state, opt_update_fn = optimizer_state
  device_rngs = jax.random.split(rng, jax.local_device_count())
  new_opt_state, new_params, new_model_state = pmapped_update_params(
      workload,
      opt_update_fn,
      current_param_container,
      model_state,
      hyperparameters,
      batch,
      opt_state,
      device_rngs)
  return (new_opt_state, opt_update_fn), new_params, new_model_state
+ """ + del workload + del optimizer_state + del current_param_container + del hyperparameters + del global_step + del rng + return next(input_queue) diff --git a/reference_submissions/mnist/mnist_pytorch/__init__.py b/reference_submissions/mnist/mnist_pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/mnist/mnist_pytorch/submission.py b/reference_submissions/mnist/mnist_pytorch/submission.py new file mode 100644 index 000000000..1b4b86fb1 --- /dev/null +++ b/reference_submissions/mnist/mnist_pytorch/submission.py @@ -0,0 +1,90 @@ +"""Training algorithm track submission functions for MNIST.""" +from typing import Dict, Iterator, List, Tuple + +import torch + +from algorithmic_efficiency import spec + + +def get_batch_size(workload_name): + # Return the global batch size. + batch_sizes = {'mnist': 1024} + return batch_sizes[workload_name] + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + del rng + del model_state + del workload + + optimizer_state = { + 'optimizer': + torch.optim.Adam( + model_params.parameters(), lr=hyperparameters.learning_rate) + } + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + # This will define the output activation via `output_activation_fn`. 
def data_selection(workload: spec.Workload,
                   input_queue: Iterator[Dict[str, spec.Tensor]],
                   optimizer_state: spec.OptimizerState,
                   current_param_container: spec.ParameterContainer,
                   hyperparameters: spec.Hyperparameters,
                   global_step: int,
                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
  """Select data from the infinitely repeating, pre-shuffled input queue.

  Each element of the queue is a batch of training examples and labels.
  """
  # Delete *all* unused arguments (previously `workload` and
  # `hyperparameters` were kept alive, which is inconsistent with the other
  # reference submissions and trips pylint's unused-argument check).
  del workload
  del optimizer_state
  del current_param_container
  del hyperparameters
  del global_step
  del rng
  return next(input_queue)
def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  """Creates an Adam optimizer, replicated across local devices.

  Returns (replicated_optax_state, opt_update_fn); the update function is
  kept outside the replicated state so pmapped steps can treat it as static.
  """
  del model_params
  del model_state
  del rng
  # Initialize against zero-valued parameters of the right shapes; only the
  # tree structure and shapes matter for optax state creation.
  params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple),
                                   workload.param_shapes)
  # Single assignment (the original had a duplicated
  # `opt_init_fn, opt_update_fn = opt_init_fn, opt_update_fn = ...` typo).
  opt_init_fn, opt_update_fn = optax.adam(
      learning_rate=hyperparameters.learning_rate)
  optimizer_state = opt_init_fn(params_zeros_like)
  return jax_utils.replicate(optimizer_state), opt_update_fn
def data_selection(
    workload: spec.Workload,
    input_queue: Iterator[Dict[str, spec.Tensor]],
    optimizer_state: spec.OptimizerState,
    current_param_container: spec.ParameterContainer,
    hyperparameters: spec.Hyperparameters,
    global_step: int,
    rng: spec.RandomState) -> Dict[str, spec.Tensor]:
  """Select data from the infinitely repeating, pre-shuffled input queue.

  Each element of the queue is a batch of training examples and labels.
  """
  # Return annotation fixed to Dict[str, spec.Tensor]: the function returns a
  # batch dict from the queue, not a 3-tuple of tensors, matching the
  # docstring and every sibling data_selection.
  del workload
  del optimizer_state
  del current_param_container
  del hyperparameters
  del global_step
  del rng
  return next(input_queue)
+ optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del hyperparameters + del loss_type + del eval_results + del global_step + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + mask = batch['weights'] + per_example_losses = workload.loss_fn(batch['targets'], logits, mask) + loss = torch.where(mask, per_example_losses, 0).sum() / mask.sum() + + loss.backward() + optimizer_state['optimizer'].step() + + return optimizer_state, current_param_container, new_model_state + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del hyperparameters + del global_step + del rng + return next(input_queue) diff --git a/reference_submissions/ogbg/tuning_search_space.json b/reference_submissions/ogbg/tuning_search_space.json new file mode 100644 index 000000000..d50cc00c5 --- /dev/null +++ b/reference_submissions/ogbg/tuning_search_space.json @@ -0,0 +1 @@ +{"learning_rate": {"feasible_points": [1e-3]}} diff --git a/reference_submissions/wmt/__init__.py b/reference_submissions/wmt/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/wmt/tuning_search_space.json b/reference_submissions/wmt/tuning_search_space.json new file mode 100644 index 000000000..ac12923ee --- /dev/null +++ b/reference_submissions/wmt/tuning_search_space.json @@ -0,0 +1,8 @@ +{ + "learning_rate": {"feasible_points": [0.0625]}, + "one_minus_beta_1": {"feasible_points": [0.1]}, + "dropout_rate": {"feasible_points": [0.1]}, + "attention_dropout_rate": {"feasible_points": [0.1]}, + "epsilon": {"feasible_points": [1e-9]} +} + diff --git a/reference_submissions/wmt/wmt_jax/__init__.py b/reference_submissions/wmt/wmt_jax/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/reference_submissions/wmt/wmt_jax/submission.py b/reference_submissions/wmt/wmt_jax/submission.py new file mode 100644 index 000000000..f334d9b47 --- /dev/null +++ b/reference_submissions/wmt/wmt_jax/submission.py @@ -0,0 +1,205 @@ +"""Training algorithm track submission functions for WMT.""" + +import functools +from typing import Dict, Iterator, List, Tuple + +from flax import jax_utils +from flax import linen as nn +from flax.training import common_utils +import jax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + + +def get_batch_size(workload_name): + batch_sizes = {"wmt": 128} + return batch_sizes[workload_name] + + +def create_learning_rate_scheduler( + factors="constant * linear_warmup 
* rsqrt_decay", + base_learning_rate=0.5, + warmup_steps=1000, + decay_factor=0.5, + steps_per_decay=20000, + steps_per_cycle=100000): + """Creates learning rate schedule. + + Interprets factors in the factors string which can consist of: + * constant: interpreted as the constant value, + * linear_warmup: interpreted as linear warmup until warmup_steps, + * rsqrt_decay: divide by square root of max(step, warmup_steps) + * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1) + * decay_every: Every k steps decay the learning rate by decay_factor. + * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter. + + Args: + factors: string, factors separated by "*" that defines the schedule. + base_learning_rate: float, the starting constant for the lr schedule. + warmup_steps: int, how many steps to warm up for in the warmup schedule. + decay_factor: float, the amount to decay the learning rate by. + steps_per_decay: int, how often to decay the learning rate. + steps_per_cycle: int, steps per cycle when using cosine decay. + + Returns: + a function learning_rate(step): float -> {"learning_rate": float}, the + step-dependent lr. 
+ """ + factors = [n.strip() for n in factors.split("*")] + + def step_fn(step): + """Step to learning rate function.""" + ret = 1.0 + for name in factors: + if name == "constant": + ret *= base_learning_rate + elif name == "linear_warmup": + ret *= jnp.minimum(1.0, step / warmup_steps) + elif name == "rsqrt_decay": + ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) + elif name == "rsqrt_normalized_decay": + ret *= jnp.sqrt(warmup_steps) + ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) + elif name == "decay_every": + ret *= (decay_factor**(step // steps_per_decay)) + elif name == "cosine_decay": + progress = jnp.maximum(0.0, + (step - warmup_steps) / float(steps_per_cycle)) + ret *= jnp.maximum(0.0, + 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))) + else: + raise ValueError(f"Unknown factor {name}.") + return jnp.asarray(ret, dtype=jnp.float32) + + return step_fn + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + del model_state + del rng + learning_rate_fn = create_learning_rate_scheduler( + base_learning_rate=hyperparameters.learning_rate, warmup_steps=1000) + opt_init_fn, opt_update_fn = optax.adam( + b1=1.0 - hyperparameters.one_minus_beta_1, + b2=0.98, + eps=hyperparameters.epsilon, + learning_rate=learning_rate_fn) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + in_axes=(None, None, 0, 0, 0, None, 0), + axis_name='batch', + static_broadcasted_argnums=(0, 1, 5)) +def pmapped_train_step(workload, + opt_update_fn, + optimizer_state, + current_param_container, + batch, + hyperparameters, + dropout_rng): + """Perform a single training step.""" + + def loss_fn(params): + """Loss function 
used for training.""" + logits, _ = workload.model_fn( + params, + batch, + model_state=None, + mode=spec.ForwardPassMode.TRAIN, + rng=dropout_rng, + update_batch_norm=False) + vocab_size = logits.shape[-1] + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.1 + confidence = 1.0 - label_smoothing + low_confidence = (1.0 - confidence) / (vocab_size - 1) + normalizing_constant = -( + confidence * jnp.log(confidence) + + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)) + targets = batch['targets'] + weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32) + soft_targets = common_utils.onehot( + targets, vocab_size, on_value=confidence, off_value=low_confidence) + loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1) + loss = loss - normalizing_constant + loss = loss * weights + normalizing_factor = weights.sum() + mean_loss = loss.sum() / normalizing_factor + return mean_loss + + grad_fn = jax.value_and_grad(loss_fn) + _, grad = grad_fn(current_param_container) + grad = jax.lax.pmean(grad, axis_name='batch') + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + # This will define the output activation via `output_activation_fn`. 
+    loss_type: spec.LossType,
+    optimizer_state: spec.OptimizerState,
+    eval_results: List[Tuple[int, float]],
+    global_step: int,
+    rng: spec.RandomState) -> spec.UpdateReturn:
+  """Return (updated_optimizer_state, updated_params)."""
+  del current_params_types
+  del eval_results
+  del global_step
+  del model_state
+  del loss_type
+
+  optimizer_state, opt_update_fn = optimizer_state
+  dropout_rngs = jax.random.split(rng, jax.local_device_count())  # one dropout rng per local device
+  new_optimizer_state, updated_params = pmapped_train_step(
+      workload,
+      opt_update_fn,
+      optimizer_state,
+      current_param_container,
+      batch,
+      hyperparameters,
+      dropout_rngs)
+  return (new_optimizer_state, opt_update_fn), updated_params, None  # trailing None: no new model_state
+
+
+# Not allowed to update the model parameters, hyperparameters, global step, or
+# optimizer state.
+def data_selection(workload: spec.Workload,
+                   input_queue: Iterator[Dict[str, spec.Tensor]],
+                   optimizer_state: spec.OptimizerState,
+                   current_param_container: spec.ParameterContainer,
+                   hyperparameters: spec.Hyperparameters,
+                   global_step: int,
+                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+  """Select data from the infinitely repeating, pre-shuffled input queue.
+
+  Each element of the queue is a batch of training examples and labels.
+  """
+  del optimizer_state
+  del current_param_container
+  del global_step
+  del rng
+  del hyperparameters
+  del workload
+
+  return next(input_queue)
diff --git a/reference_submissions/wmt/wmt_pytorch/__init__.py b/reference_submissions/wmt/wmt_pytorch/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/reference_submissions/wmt/wmt_pytorch/submission.py b/reference_submissions/wmt/wmt_pytorch/submission.py
new file mode 100644
index 000000000..bae32fe17
--- /dev/null
+++ b/reference_submissions/wmt/wmt_pytorch/submission.py
@@ -0,0 +1,153 @@
+from typing import Dict, Iterator, List, Tuple
+
+import numpy as np
+import torch
+
+from algorithmic_efficiency import spec
+
+
+def get_batch_size(workload_name):
+  batch_sizes = {'wmt': 128}  # global batch size per workload name
+  return batch_sizes[workload_name]
+
+
+def create_learning_rate_scheduler(
+    factors="constant * linear_warmup * rsqrt_decay",
+    base_learning_rate=0.5,
+    warmup_steps=1000,
+    decay_factor=0.5,
+    steps_per_decay=20000,
+    steps_per_cycle=100000):
+  """Creates learning rate schedule.
+  Interprets factors in the factors string which can consist of:
+  * constant: interpreted as the constant value,
+  * linear_warmup: interpreted as linear warmup until warmup_steps,
+  * rsqrt_decay: divide by square root of max(step, warmup_steps)
+  * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1)
+  * decay_every: Every k steps decay the learning rate by decay_factor.
+  * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter.
+  Args:
+    factors: string, factors separated by "*" that defines the schedule.
+    base_learning_rate: float, the starting constant for the lr schedule.
+    warmup_steps: int, how many steps to warm up for in the warmup schedule.
+    decay_factor: float, the amount to decay the learning rate by.
+    steps_per_decay: int, how often to decay the learning rate.
+    steps_per_cycle: int, steps per cycle when using cosine decay.
+  Returns:
+    a function learning_rate(step): float -> {"learning_rate": float}, the
+    step-dependent lr.
+  """
+  factors = [n.strip() for n in factors.split("*")]
+
+  def step_fn(step):
+    """Step to learning rate function."""
+    ret = 1.0
+    for name in factors:
+      if name == "constant":
+        ret *= base_learning_rate
+      elif name == "linear_warmup":
+        ret *= np.minimum(1.0, step / warmup_steps)
+      elif name == "rsqrt_decay":
+        ret /= np.sqrt(np.maximum(step, warmup_steps))
+      elif name == "rsqrt_normalized_decay":
+        ret *= np.sqrt(warmup_steps)
+        ret /= np.sqrt(np.maximum(step, warmup_steps))
+      elif name == "decay_every":
+        ret *= (decay_factor**(step // steps_per_decay))
+      elif name == "cosine_decay":
+        progress = np.maximum(0.0,
+                              (step - warmup_steps) / float(steps_per_cycle))
+        ret *= np.maximum(0.0, 0.5 * (1.0 + np.cos(np.pi * (progress % 1.0))))
+      else:
+        raise ValueError("Unknown factor %s." % name)
+    return ret  # NOTE(review): np factors yield a numpy scalar, but "constant" alone yields a Python float — see .item() caveat in update_params
+
+  return step_fn
+
+
+def init_optimizer_state(workload: spec.Workload,
+                         model_params: spec.ParameterContainer,
+                         model_state: spec.ModelAuxiliaryState,
+                         hyperparameters: spec.Hyperparameters,
+                         rng: spec.RandomState) -> spec.OptimizerState:
+  del workload
+  del model_state
+  del rng
+
+  optimizer_state = {
+      'optimizer':
+          torch.optim.Adam(
+              model_params.parameters(),
+              lr=hyperparameters.learning_rate,
+              betas=(1.0 - hyperparameters.one_minus_beta_1, 0.98),  # b2 fixed at 0.98, mirroring the Jax reference submission
+              eps=hyperparameters.epsilon)
+  }
+
+  optimizer_state['scheduler'] = create_learning_rate_scheduler(
+      base_learning_rate=hyperparameters.learning_rate)  # stepped manually each update_params call
+  return optimizer_state
+
+
+def update_params(workload: spec.Workload,
+                  current_param_container: spec.ParameterContainer,
+                  current_params_types: spec.ParameterTypeTree,
+                  model_state: spec.ModelAuxiliaryState,
+                  hyperparameters: spec.Hyperparameters,
+                  batch: Dict[str, spec.Tensor],
+                  loss_type: spec.LossType,
+                  optimizer_state: spec.OptimizerState,
+                  eval_results: List[Tuple[int, float]],
+                  global_step: int,
+                  rng: spec.RandomState) -> 
spec.UpdateReturn:
+  """Return (updated_optimizer_state, updated_params)."""
+  del current_params_types
+  del eval_results
+  del loss_type
+  del hyperparameters
+
+  current_model = current_param_container
+  current_param_container.train()  # enable train-mode layers (e.g. dropout) for the forward pass
+  optimizer = optimizer_state['optimizer']
+  optimizer.zero_grad()
+
+  logits, _ = workload.model_fn(
+      params=current_model,
+      augmented_and_preprocessed_input_batch=batch,
+      model_state=model_state,
+      mode=spec.ForwardPassMode.TRAIN,
+      rng=rng,
+      update_batch_norm=False)
+
+  targets = batch['targets']
+  weights = torch.where(targets > 0, 1.0, 0.0)  # NOTE(review): presumably 0 is the padding id — confirm against the tokenizer
+  loss = (workload.loss_fn(targets, logits) * weights).sum() / weights.sum()
+  loss.backward()
+
+  lr = optimizer_state['scheduler'](global_step).item()  # NOTE(review): .item() assumes a numpy scalar; a plain-"constant" factors string returns a Python float with no .item() — confirm
+  for g in optimizer.param_groups:
+    g['lr'] = lr  # manual schedule: write the current lr into every param group
+  optimizer.step()
+
+  return (optimizer_state, current_param_container, None)
+
+
+# Not allowed to update the model parameters, hyperparameters, global step, or
+# optimizer state.
+def data_selection(workload: spec.Workload,
+                   input_queue: Iterator[Dict[str, spec.Tensor]],
+                   optimizer_state: spec.OptimizerState,
+                   current_param_container: spec.ParameterContainer,
+                   hyperparameters: spec.Hyperparameters,
+                   global_step: int,
+                   rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+  """Select data from the infinitely repeating, pre-shuffled input queue.
+
+  Each element of the queue is a batch of training examples and labels.
+  """
+  del workload
+  del optimizer_state
+  del current_param_container
+  del hyperparameters
+  del global_step
+  del rng
+  return next(input_queue)