diff --git a/reference_algorithms/prize_qualification_baselines/README.md b/reference_algorithms/prize_qualification_baselines/README.md
new file mode 100644
index 000000000..614f87b32
--- /dev/null
+++ b/reference_algorithms/prize_qualification_baselines/README.md
@@ -0,0 +1,83 @@
+# Prize Qualification Baselines
+This directory contains the baseline(s) that submissions must beat to qualify for prizes.
+
+TODO: link back to section in rules.
+
+## Externally Tuned Ruleset
+
+### JAX
+
+The prize qualification baseline submissions for JAX are:
+- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py`
+
+Example command:
+
+```bash
+python3 submission_runner.py \
+    --framework=jax \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py \
+    --tuning_search_space=reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json
+```
+
+### PyTorch
+
+The prize qualification baseline submissions for PyTorch are:
+- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py`
+
+Example command:
+
+```bash
+torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
+    --framework=pytorch \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py \
+    --tuning_search_space=reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json
+```
+
+## Self-tuning Ruleset
+
+### JAX
+
+The prize qualification baseline submissions for JAX are:
+- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py`
+
+Example command:
+
+```bash
+python3 submission_runner.py \
+    --framework=jax \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py \
+    --tuning_ruleset=self
+```
+
+### PyTorch
+
+The prize qualification baseline submissions for PyTorch are:
+- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py`
+
+Example command:
+
+```bash
+torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
+    --framework=pytorch \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py \
+    --tuning_ruleset=self
+```
\ No newline at end of file
diff --git
a/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py new file mode 100644 index 000000000..099613fcf --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -0,0 +1,345 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. 
+ + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. 
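+    # The schedule ramps linearly from 0 to hyperparameters.learning_rate over
+    # the first `warmup_factor * step_hint` steps, then follows a cosine decay
+    # back towards 0 over the remaining steps (optax.cosine_decay_schedule
+    # defaults to a final value of 0).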
+ warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. 
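+  # lax.psum over the 'batch' (device) axis sums the per-device summed losses,
+  # valid-example counts and gradients, so dividing by the global
+  # n_valid_examples below yields the true global mean loss/gradient rather
+  # than a mean of per-device means.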
+ (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..ef0c11c0d --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,345 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. 
+ """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. 
+ (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..01cffc52e --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,347 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. 
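+  # grad_norm below is the global L2 norm over all parameter gradients (the
+  # 2-norm of the stacked per-parameter norms); metrics are logged for the
+  # first 100 steps and every 500 steps thereafter.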
+ if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..530dd3acf --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,347 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. 
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. 
+ """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 
'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json b/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json new file mode 100644 index 000000000..65562905a --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -0,0 +1,50 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.2, + "learning_rate": 0.0008445074561975979, + "one_minus_beta1": 0.11042418465, + "beta2": 0.9978504782314613, + "weight_decay": 0.08135402759553023, + "warmup_factor": 0.05 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.004958460849689891, + "one_minus_beta1": 0.13625575743, + "beta2": 0.6291854735396584, + "weight_decay": 0.1147386261512052, + "warmup_factor": 0.02 + }, + { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } +] + + + + + + diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py new file mode 100644 index 000000000..c54202e56 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -0,0 +1,360 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. 
The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, 
grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..dd42743e2 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,360 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. 
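One thing to note about the self-tuning variants: `HPARAMS` is declared as a plain dict, yet the code reads it with attribute access (`hyperparameters.learning_rate`, `hyperparameters.warmup_factor`) after rebinding `hyperparameters = HPARAMS`. Unless the dict is converted elsewhere, those reads would not resolve, and `hasattr(hyperparameters, 'label_smoothing')` on a dict ignores keys, so the fallbacks always trigger. A minimal sketch of one way to keep the attribute syntax (an illustration, not the code the submission ships):

```python
from types import SimpleNamespace

# Illustrative only: expose the HPARAMS entries as attributes so that
# hyperparameters.learning_rate and hasattr(hyperparameters, 'grad_clip') behave as intended.
HPARAMS = SimpleNamespace(**{
    "dropout_rate": 0.1,
    "learning_rate": 0.0017486387539278373,
    "one_minus_beta1": 0.06733926164,
    "beta2": 0.9955159689799007,
    "weight_decay": 0.08121616522670176,
    "warmup_factor": 0.02,
})
```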
The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, 
loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..57da48167 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,362 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. 
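The "decoupled" weight decay referenced above is applied directly to the weights before the adaptive update (the `param.mul_(1 - lr * weight_decay)` line further down), rather than being folded into the gradient. A schematic scalar contrast, with illustrative names that are not part of the submission:

```python
def l2_regularized_step(param, grad, lr, wd):
    # L2 regularization: the decay term enters the gradient and hence the preconditioner.
    return param - lr * (grad + wd * param)

def decoupled_step(param, adam_update, lr, wd):
    # Decoupled (AdamW/NAdamW): decay the weights directly; only the
    # bias-corrected, preconditioned update is scaled by the learning rate.
    param = param * (1 - lr * wd)
    return param - lr * adam_update
```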
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. 
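A quick scalar check of the in-place trick used in the loop below: the momentum buffer is advanced a second time to form the Nesterov look-ahead, and the closing `sub_(...).div_(beta1)` restores the buffer exactly, so no extra copy of `exp_avg` is needed (sketch only; the variable names are illustrative):

```python
beta1, grad, exp_avg = 0.9, 0.123, 0.456  # exp_avg: first-moment buffer after its regular EMA update

lookahead = beta1 * exp_avg + (1 - beta1) * grad      # second momentum application (Nesterov look-ahead)
# ... lookahead is bias-corrected and used for the parameter update ...
restored = (lookahead - (1 - beta1) * grad) / beta1   # the sub_/div_ "undo" at the end of the loop
assert abs(restored - exp_avg) < 1e-9
```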
+ """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing 
if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
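When `grad_clip` is set, both code paths clip by the global L2 norm over all gradients: the PyTorch files call `torch.nn.utils.clip_grad_norm_`, and the JAX train step scales the gradient tree manually. The effective scaling is roughly the following framework-free sketch (names are illustrative; the two implementations differ only in small epsilon details):

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale a list of scalar gradients so their global L2 norm is at most ~max_norm."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (global_norm + eps))
    return [g * scale for g in grads]
```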
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..ef6e84c94 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,362 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. 
+ + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. 
+ summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch
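Purely as orientation, the entry points defined in each of these files compose roughly as follows. This schematic loop is an assumption about how a harness could drive them, written against the signatures above; it is not the AlgoPerf `submission_runner`, and it assumes the functions are importable from the submission module.

```python
def schematic_training_loop(workload, input_queue, model_params, model_state,
                            hyperparameters, rng, num_steps):
    """Illustrative only: how init_optimizer_state, data_selection and update_params fit together."""
    optimizer_state = init_optimizer_state(
        workload, model_params, model_state, hyperparameters, rng)
    params, eval_results = model_params, []
    for global_step in range(num_steps):
        batch = data_selection(workload, input_queue, optimizer_state, params,
                               model_state, hyperparameters, global_step, rng)
        # current_params_types and loss_type are unused by these submissions, so None is passed here.
        optimizer_state, params, model_state = update_params(
            workload, params, None, model_state, hyperparameters, batch,
            None, optimizer_state, eval_results, global_step, rng)
    return params
```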