
Commit

Add back reference_submissions
runame committed Aug 24, 2022
1 parent 25770a8 commit f4b151f
Showing 48 changed files with 1,968 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/linting.yml
@@ -20,6 +20,7 @@ jobs:
pylint algorithmic_efficiency
pylint baselines
pylint target_setting_runs
pylint reference_submissions
pylint submission_runner.py
pylint tests
Empty file.
Empty file.
Empty file.
151 changes: 151 additions & 0 deletions reference_submissions/cifar/cifar_jax/submission.py
@@ -0,0 +1,151 @@
"""Training algorithm track submission functions for CIFAR10."""

import functools
from typing import Dict, Iterator, List, Tuple

from flax import jax_utils
import jax
from jax import lax
import jax.numpy as jnp
import optax

from algorithmic_efficiency import spec


def get_batch_size(workload_name):
# Return the global batch size.
del workload_name
return 128


def cosine_decay(lr, step, total_steps):
ratio = jnp.maximum(0., step / total_steps)
mult = 0.5 * (1. + jnp.cos(jnp.pi * ratio))
return mult * lr


def create_learning_rate_fn(hparams: spec.Hyperparameters,
steps_per_epoch: int):
"""Create learning rate schedule."""
base_learning_rate = hparams.learning_rate * get_batch_size('cifar') / 256.
warmup_fn = optax.linear_schedule(
init_value=0.,
end_value=base_learning_rate,
transition_steps=hparams.warmup_epochs * steps_per_epoch)
cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1)
cosine_fn = optax.cosine_decay_schedule(
init_value=base_learning_rate,
decay_steps=cosine_epochs * steps_per_epoch)
schedule_fn = optax.join_schedules(
schedules=[warmup_fn, cosine_fn],
boundaries=[hparams.warmup_epochs * steps_per_epoch])
return schedule_fn


def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int):
steps_per_epoch = num_train_examples // get_batch_size('cifar')
learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch)
opt_init_fn, opt_update_fn = optax.sgd(
nesterov=True,
momentum=hyperparameters.momentum,
learning_rate=learning_rate_fn)
return opt_init_fn, opt_update_fn


def init_optimizer_state(workload: spec.Workload,
model_params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
rng: spec.RandomState) -> spec.OptimizerState:
del model_params
del model_state
del rng
params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple),
workload.param_shapes)
opt_init_fn, opt_update_fn = optimizer(hyperparameters,
workload.num_train_examples)
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn


# pmap over local devices: workload and opt_update_fn are static arguments,
# hyperparameters are broadcast (in_axes=None), and model_state,
# optimizer_state, params, batch, and rng are sharded along axis 0.
@functools.partial(
jax.pmap,
axis_name='batch',
in_axes=(None, None, 0, 0, 0, None, 0, 0),
static_broadcasted_argnums=(0, 1))
def pmapped_train_step(workload,
opt_update_fn,
model_state,
optimizer_state,
current_param_container,
hyperparameters,
batch,
rng):

def _loss_fn(params):
"""loss function used for training."""
logits, new_model_state = workload.model_fn(
params,
batch,
model_state,
spec.ForwardPassMode.TRAIN,
rng,
update_batch_norm=True)
loss = jnp.mean(workload.loss_fn(batch['targets'], logits))
weight_penalty_params = jax.tree_leaves(params)
weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1)
weight_penalty = hyperparameters.l2 * 0.5 * weight_l2
loss = loss + weight_penalty
return loss, new_model_state

grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(_, new_model_state), grad = grad_fn(current_param_container)
grad = lax.pmean(grad, axis_name='batch')
updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
current_param_container)
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state


def update_params(workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
del global_step
del eval_results
optimizer_state, opt_update_fn = optimizer_state
per_device_rngs = jax.random.split(rng, jax.local_device_count())
new_optimizer_state, new_params, new_model_state = pmapped_train_step(
workload, opt_update_fn, model_state, optimizer_state,
current_param_container, hyperparameters, batch, per_device_rngs)
return (new_optimizer_state, opt_update_fn), new_params, new_model_state


def data_selection(workload: spec.Workload,
input_queue: Iterator[Dict[str, spec.Tensor]],
optimizer_state: spec.OptimizerState,
current_param_container: spec.ParameterContainer,
hyperparameters: spec.Hyperparameters,
global_step: int,
rng: spec.RandomState) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
del workload
del optimizer_state
del current_param_container
del hyperparameters
del global_step
del rng
return next(input_queue)
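
A quick way to sanity-check the JAX submission's learning-rate schedule is to call create_learning_rate_fn outside the training harness. The sketch below is not part of the commit: it assumes the repository root is on the Python path (the empty __init__.py files above suggest reference_submissions is importable as a package), substitutes a plain namedtuple for spec.Hyperparameters (the function only reads learning_rate, warmup_epochs, and num_epochs), and picks an illustrative steps_per_epoch; the hyperparameter values mirror the tuning_search_space.json added below.

import collections

from reference_submissions.cifar.cifar_jax import submission

# Stand-in for spec.Hyperparameters with just the fields the schedule reads.
Hparams = collections.namedtuple(
    'Hparams', ['learning_rate', 'warmup_epochs', 'num_epochs'])
hparams = Hparams(learning_rate=0.1, warmup_epochs=5, num_epochs=200)

# Illustrative; the submission computes this from workload.num_train_examples.
steps_per_epoch = 45000 // submission.get_batch_size('cifar')

schedule_fn = submission.create_learning_rate_fn(hparams, steps_per_epoch)

# Linear warmup for the first warmup_epochs epochs, cosine decay afterwards.
for epoch in (0, 5, 100, 199):
  step = epoch * steps_per_epoch
  print(f'epoch {epoch:3d}: lr = {float(schedule_fn(step)):.5f}')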
118 changes: 118 additions & 0 deletions reference_submissions/cifar/cifar_pytorch/submission.py
@@ -0,0 +1,118 @@
"""Training algorithm track submission functions for CIFAR10."""
from typing import Dict, Iterator, List, Tuple

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import LinearLR
from torch.optim.lr_scheduler import SequentialLR

from algorithmic_efficiency import spec


def get_batch_size(workload_name):
# Return the global batch size.
batch_sizes = {'cifar': 128}
return batch_sizes[workload_name]


def init_optimizer_state(workload: spec.Workload,
model_params: spec.ParameterContainer,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
rng: spec.RandomState) -> spec.OptimizerState:
del workload
del model_state
del rng

base_lr = hyperparameters.learning_rate * get_batch_size('cifar') / 256.
optimizer_state = {
'optimizer':
torch.optim.SGD(
model_params.parameters(),
lr=base_lr,
momentum=hyperparameters.momentum,
weight_decay=hyperparameters.l2)
}

scheduler1 = LinearLR(
optimizer_state['optimizer'],
start_factor=1e-5,
end_factor=1.,
total_iters=hyperparameters.warmup_epochs)
cosine_epochs = max(
hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1)
scheduler2 = CosineAnnealingLR(
optimizer_state['optimizer'], T_max=cosine_epochs)

optimizer_state['scheduler'] = SequentialLR(
optimizer_state['optimizer'],
schedulers=[scheduler1, scheduler2],
milestones=[hyperparameters.warmup_epochs])

return optimizer_state


def update_params(
workload: spec.Workload,
current_param_container: spec.ParameterContainer,
current_params_types: spec.ParameterTypeTree,
model_state: spec.ModelAuxiliaryState,
hyperparameters: spec.Hyperparameters,
batch: Dict[str, spec.Tensor],
# This will define the output activation via `output_activation_fn`.
loss_type: spec.LossType,
optimizer_state: spec.OptimizerState,
eval_results: List[Tuple[int, float]],
global_step: int,
rng: spec.RandomState) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del current_params_types
del hyperparameters
del loss_type
del eval_results

current_model = current_param_container
current_param_container.train()
optimizer_state['optimizer'].zero_grad()

logits_batch, new_model_state = workload.model_fn(
params=current_model,
augmented_and_preprocessed_input_batch=batch,
model_state=model_state,
mode=spec.ForwardPassMode.TRAIN,
rng=rng,
update_batch_norm=True)

loss = workload.loss_fn(
label_batch=batch['targets'], logits_batch=logits_batch).mean()

loss.backward()
optimizer_state['optimizer'].step()

steps_per_epoch = workload.num_train_examples // get_batch_size('cifar')
if (global_step + 1) % steps_per_epoch == 0:
optimizer_state['scheduler'].step()

return (optimizer_state, current_param_container, new_model_state)


# Not allowed to update the model parameters, hyperparameters, global step, or
# optimizer state.
def data_selection(workload: spec.Workload,
input_queue: Iterator[Dict[str, spec.Tensor]],
optimizer_state: spec.OptimizerState,
current_param_container: spec.ParameterContainer,
hyperparameters: spec.Hyperparameters,
global_step: int,
rng: spec.RandomState) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
del workload
del optimizer_state
del current_param_container
del hyperparameters
del global_step
del rng
return next(input_queue)
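
The PyTorch submission builds the equivalent warmup-plus-cosine schedule out of torch.optim.lr_scheduler building blocks and steps it once per epoch (the (global_step + 1) % steps_per_epoch check in update_params). The sketch below is not part of the commit: it drives the same scheduler construction on a dummy parameter, with hyperparameter values taken from tuning_search_space.json, so the per-epoch learning rates can be printed without the workload.

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

# A dummy parameter so the optimizer/scheduler pair can be driven standalone.
param = torch.nn.Parameter(torch.zeros(1))

base_lr = 0.1 * 128 / 256.  # learning_rate scaled by the global batch size
warmup_epochs, num_epochs = 5, 200

optimizer = torch.optim.SGD(
    [param], lr=base_lr, momentum=0.9, weight_decay=5e-4)
scheduler = SequentialLR(
    optimizer,
    schedulers=[
        LinearLR(optimizer, start_factor=1e-5, end_factor=1.,
                 total_iters=warmup_epochs),
        CosineAnnealingLR(optimizer, T_max=num_epochs - warmup_epochs),
    ],
    milestones=[warmup_epochs])

for epoch in range(num_epochs):
  optimizer.step()  # stands in for one epoch of training steps
  scheduler.step()  # the submission calls this once per epoch
  if epoch in (0, warmup_epochs - 1, num_epochs - 1):
    print(f'after epoch {epoch:3d}: lr = {scheduler.get_last_lr()[0]:.6f}')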
7 changes: 7 additions & 0 deletions reference_submissions/cifar/tuning_search_space.json
@@ -0,0 +1,7 @@
{
"learning_rate": {"feasible_points": [0.1]},
"warmup_epochs": {"feasible_points": [5]},
"num_epochs": {"feasible_points": [200]},
"l2": {"feasible_points": [5e-4]},
"momentum": {"feasible_points": [0.9]}
}
Empty file.
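
The tuning search space for this workload pins every hyperparameter to a single feasible point. The sketch below is not part of the commit, and the real sampling logic lives in the harness (submission_runner.py); it only illustrates the structure of the JSON by drawing one point and packing it into a namedtuple like the hyperparameters objects the submission functions receive.

import collections
import json
import random

with open('reference_submissions/cifar/tuning_search_space.json') as f:
  search_space = json.load(f)

# One value per hyperparameter; here every feasible_points list has one entry.
point = {
    name: random.choice(options['feasible_points'])
    for name, options in search_space.items()
}
Hyperparameters = collections.namedtuple('Hyperparameters', sorted(point))
hparams = Hyperparameters(**point)
print(hparams)
# Hyperparameters(l2=0.0005, learning_rate=0.1, momentum=0.9, num_epochs=200,
#                 warmup_epochs=5)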