From 441df6e881e56d189e2777bada3baea44ba437b1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 19:54:30 +0000 Subject: [PATCH 01/10] add treshold submissions --- .../threshold_submissions/README.md | 0 .../external_tuning/jax_nadamw_full_budget.py | 345 ++++++++++++++++++ .../jax_nadamw_target_setting.py | 171 +++++++++ .../pytorch_nadamw_full_budget.py | 212 +++++++++++ .../pytorch_nadamw_target_setting.py | 171 +++++++++ .../self_tuning/jax_nadamw_full_budget.py | 345 ++++++++++++++++++ .../self_tuning/jax_nadamw_target_setting.py | 171 +++++++++ .../self_tuning/pytorch_nadamw_full_budget.py | 212 +++++++++++ .../pytorch_nadamw_target_setting.py | 212 +++++++++++ 9 files changed, 1839 insertions(+) create mode 100644 reference_algorithms/threshold_submissions/README.md create mode 100644 reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py create mode 100644 reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py create mode 100644 reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py create mode 100644 reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py create mode 100644 reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py create mode 100644 reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py create mode 100644 reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py create mode 100644 reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py diff --git a/reference_algorithms/threshold_submissions/README.md b/reference_algorithms/threshold_submissions/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py new file mode 100644 index 000000000..099613fcf --- /dev/null +++ b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py @@ -0,0 +1,345 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). 
+ + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. 
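+  # loss and grad_norm are identical across devices after the psum inside
+  # pmapped_train_step, so logging the leading replica (index 0) suffices.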
+ if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..21f2a7b2b --- /dev/null +++ b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,171 @@ +"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import Any, Callable, NamedTuple, Optional, Union + +import chex +from flax import jax_utils +import jax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec +from reference_algorithms.target_setting_algorithms import cosine_warmup +from reference_algorithms.target_setting_algorithms.data_selection import \ + data_selection # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_submission_base import \ + update_params # pylint: disable=unused-import + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). 
+ Args: + learning_rate: this is a fixed global scaling factor. + b1: decay rate for the exponentially weighted average of grads. + b2: decay rate for the exponentially weighted average of squared grads. + eps: term added to the denominator to improve numerical stability. + eps_root: term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: whether to use bias correction. + weight_decay: strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: a tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this) + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + Args: + b1: decay rate for the exponentially weighted average of grads. + b2: decay rate for the exponentially weighted average of squared grads. + eps: term added to the denominator to improve numerical stability. + eps_root: term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: whether to use bias correction. + power: the power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + target_setting_step_hint = int(0.75 * workload.step_hint) + lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, + hyperparameters) + + # Create optimizer. 
+ params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=hyperparameters.beta1, + b2=hyperparameters.beta2, + eps=epsilon, + weight_decay=hyperparameters.weight_decay) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..71b819e66 --- /dev/null +++ b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,212 @@ +"""Submission file for a NAdamW optimizer in PyTorch.""" + +import math +from typing import List + +import torch +from torch import Tensor + +from algorithmic_efficiency import spec +from reference_algorithms.target_setting_algorithms import cosine_warmup +from reference_algorithms.target_setting_algorithms.data_selection import \ + data_selection # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ + update_params # pylint: disable=unused-import + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float): + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. 
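+    # The extra in-place first-moment update below turns `exp_avg` into the
+    # Nesterov lookahead beta1 * m_t + (1 - beta1) * g_t, which is what the
+    # parameter step then divides by the usual (1 - beta1**step) bias
+    # correction.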
+ # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.weight_decay), + } + + target_setting_step_hint = int(0.75 * workload.step_hint) + optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( + target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + return optimizer_state diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..21f2a7b2b --- /dev/null +++ b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,171 @@ +"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import Any, Callable, NamedTuple, Optional, Union + +import chex +from flax import jax_utils +import jax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec +from reference_algorithms.target_setting_algorithms import cosine_warmup +from reference_algorithms.target_setting_algorithms.data_selection import \ + data_selection # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_submission_base import \ + update_params # pylint: disable=unused-import + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + Args: + learning_rate: this is a fixed global scaling factor. + b1: decay rate for the exponentially weighted average of grads. 
+ b2: decay rate for the exponentially weighted average of squared grads. + eps: term added to the denominator to improve numerical stability. + eps_root: term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: whether to use bias correction. + weight_decay: strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: a tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this) + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + Args: + b1: decay rate for the exponentially weighted average of grads. + b2: decay rate for the exponentially weighted average of squared grads. + eps: term added to the denominator to improve numerical stability. + eps_root: term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: whether to use bias correction. + power: the power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + target_setting_step_hint = int(0.75 * workload.step_hint) + lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, + hyperparameters) + + # Create optimizer. + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=hyperparameters.beta1, + b2=hyperparameters.beta2, + eps=epsilon, + weight_decay=hyperparameters.weight_decay) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py new file mode 100644 index 000000000..099613fcf --- /dev/null +++ b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py @@ -0,0 +1,345 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. 
+from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. 
+ eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
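+  # Note (descriptive, not part of the upstream code): the schedule below
+  # warms up linearly for `warmup_factor * step_hint` steps and then follows
+  # cosine decay over the remaining steps. The optimizer state is built from
+  # parameter shapes only (zeros_like), since NAdamW state does not depend on
+  # parameter values, and is then replicated across local devices for the
+  # pmapped train step.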
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. 
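+  # loss and grad_norm are identical across devices after the psum inside
+  # pmapped_train_step, so logging the leading replica (index 0) suffices.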
+ if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..21f2a7b2b --- /dev/null +++ b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,171 @@ +"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax.""" + +from typing import Any, Callable, NamedTuple, Optional, Union + +import chex +from flax import jax_utils +import jax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec +from reference_algorithms.target_setting_algorithms import cosine_warmup +from reference_algorithms.target_setting_algorithms.data_selection import \ + data_selection # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.jax_submission_base import \ + update_params # pylint: disable=unused-import + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). 
+ Args: + learning_rate: this is a fixed global scaling factor. + b1: decay rate for the exponentially weighted average of grads. + b2: decay rate for the exponentially weighted average of squared grads. + eps: term added to the denominator to improve numerical stability. + eps_root: term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: whether to use bias correction. + weight_decay: strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: a tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this) + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + Args: + b1: decay rate for the exponentially weighted average of grads. + b2: decay rate for the exponentially weighted average of squared grads. + eps: term added to the denominator to improve numerical stability. + eps_root: term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: whether to use bias correction. + power: the power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + target_setting_step_hint = int(0.75 * workload.step_hint) + lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, + hyperparameters) + + # Create optimizer. 
+ params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=hyperparameters.beta1, + b2=hyperparameters.beta2, + eps=epsilon, + weight_decay=hyperparameters.weight_decay) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..71b819e66 --- /dev/null +++ b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,212 @@ +"""Submission file for a NAdamW optimizer in PyTorch.""" + +import math +from typing import List + +import torch +from torch import Tensor + +from algorithmic_efficiency import spec +from reference_algorithms.target_setting_algorithms import cosine_warmup +from reference_algorithms.target_setting_algorithms.data_selection import \ + data_selection # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ + update_params # pylint: disable=unused-import + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float): + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. 
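+    # Re-applying the beta1 EMA update below turns exp_avg into the Nesterov
+    # "look-ahead" first moment, beta1 * m_t + (1 - beta1) * g_t, which is the
+    # numerator used in the parameter update; it is reverted after the update.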
+ # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.weight_decay), + } + + target_setting_step_hint = int(0.75 * workload.step_hint) + optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( + target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + return optimizer_state diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..71b819e66 --- /dev/null +++ b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,212 @@ +"""Submission file for a NAdamW optimizer in PyTorch.""" + +import math +from typing import List + +import torch +from torch import Tensor + +from algorithmic_efficiency import spec +from reference_algorithms.target_setting_algorithms import cosine_warmup +from reference_algorithms.target_setting_algorithms.data_selection import \ + data_selection # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ + update_params # pylint: disable=unused-import + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float): + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. 
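+    # Re-applying the beta1 EMA update below turns exp_avg into the Nesterov
+    # "look-ahead" first moment, beta1 * m_t + (1 - beta1) * g_t, which is the
+    # numerator used in the parameter update; it is reverted after the update.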
+ # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.weight_decay), + } + + target_setting_step_hint = int(0.75 * workload.step_hint) + optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( + target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + return optimizer_state From 94b360fe85425025138e94825fd07c1d4b4384b3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 21:04:25 +0000 Subject: [PATCH 02/10] add tuning search space --- .../jax_nadamw_target_setting.py | 2 +- .../pytorch_nadamw_full_budget.py | 183 ++++++++-- .../pytorch_nadamw_target_setting.py | 333 ++++++++++-------- .../self_tuning/jax_nadamw_target_setting.py | 2 +- .../self_tuning/pytorch_nadamw_full_budget.py | 183 ++++++++-- .../pytorch_nadamw_target_setting.py | 2 +- 6 files changed, 508 insertions(+), 197 deletions(-) diff --git a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py index 21f2a7b2b..8f20bcbc6 100644 --- a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py @@ -162,7 +162,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) opt_init_fn, opt_update_fn = nadamw( learning_rate=lr_schedule_fn, - b1=hyperparameters.beta1, + b1=1 - hyperparameters.one_minus_beta1, b2=hyperparameters.beta2, eps=epsilon, weight_decay=hyperparameters.weight_decay) diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py index 71b819e66..01cffc52e 100644 --- a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py +++ b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py @@ -1,29 +1,32 @@ -"""Submission file for a NAdamW optimizer in PyTorch.""" +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import List +from typing import Dict, Iterator, List, Tuple +from absl import logging import torch from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import 
CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR from algorithmic_efficiency import spec -from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from algorithmic_efficiency.pytorch_utils import pytorch_setup +USE_PYTORCH_DDP = pytorch_setup()[0] -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of the NAdam algorithm (there is also a comment in the code which highlights the only difference of NAdamW and AdamW). For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups @@ -37,7 +40,7 @@ class NAdamW(torch.optim.Optimizer): https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ - """ + """ def __init__(self, params, @@ -72,10 +75,11 @@ def __setstate__(self, state): @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. - """ + """ self._cuda_graph_capture_health_check() loss = None @@ -139,10 +143,10 @@ def nadamw(params: List[Tensor], beta2: float, lr: float, weight_decay: float, - eps: float): + eps: float) -> None: r"""Functional API that performs NAdamW algorithm computation. See NAdamW class for details. - """ + """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( @@ -155,13 +159,13 @@ def nadamw(params: List[Tensor], exp_avg_sq = exp_avg_sqs[i] step_t = state_steps[i] - # update step + # Update step. step_t += 1 - # Perform stepweight decay + # Perform stepweight decay. param.mul_(1 - lr * weight_decay) - # Decay the first and second moment running average coefficient + # Decay the first and second moment running average coefficient. 
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) @@ -194,19 +198,150 @@ def init_optimizer_state(workload: spec.Workload, del model_state del rng - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) optimizer_state = { 'optimizer': NAdamW( model_params.parameters(), lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=epsilon, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, weight_decay=hyperparameters.weight_decay), } - target_setting_step_hint = int(0.75 * workload.step_hint) - optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( - target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. 
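+  # Only loss and grad_norm are appended below; logging is throttled to the
+  # first 100 steps and to every 500th step afterwards.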
+ if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py index 21f2a7b2b..7aa8160a4 100644 --- a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py @@ -1,12 +1,10 @@ -"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax.""" +"""Submission file for a NAdamW optimizer in PyTorch.""" -from typing import Any, Callable, NamedTuple, Optional, Union +import math +from typing import List -import chex -from flax import jax_utils -import jax -import jax.numpy as jnp -import optax +import torch +from torch import Tensor from algorithmic_efficiency import spec from reference_algorithms.target_setting_algorithms import cosine_warmup @@ -14,131 +12,177 @@ data_selection # pylint: disable=unused-import from reference_algorithms.target_setting_algorithms.get_batch_size import \ get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ +from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ update_params # pylint: disable=unused-import -# Forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py -def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, 
Callable[[optax.Params], - Any]]] = None, -) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch - implementation also follows this). - Current code implements a simpler version with no momentum decay and slightly - different bias correction terms. The exact description can be found here - https://arxiv.org/pdf/1910.05446.pdf (Table 1). - Args: - learning_rate: this is a fixed global scaling factor. - b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of squared grads. - eps: term added to the denominator to improve numerical stability. - eps_root: term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: whether to use bias correction. - weight_decay: strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent with - other frameworks such as PyTorch, but different from (Loshchilov et al, - 2019) where the weight decay is only multiplied with the "schedule - multiplier", but not the base learning rate. - weight_decay_mask: a tree with same structure as (or a prefix of) the params - PyTree, or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Nadam gradient transformations are applied to all parameters. - Returns: - An (init_fn, update_fn) tuple. - """ - return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) - - -# All functions below are forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this) - Current code implements a simpler version with no momentum decay and slightly - different (standard Adam) bias correction terms. The exact description can be - found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) - Args: - b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of squared grads. - eps: term added to the denominator to improve numerical stability. - eps_root: term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: whether to use bias correction. - power: the power to use in the preconditioner (0.5 in default adam). - Returns: - An (init_fn, update_fn) tuple. 
- """ - raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) - - def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = _update_moment(updates, state.mu, b1, 1) - nu = _update_moment(updates, state.nu, b2, 2) - count = state.count + jnp.array(1, dtype=jnp.int32) - mu_hat = _update_moment(updates, mu, b1, 1) - mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) - nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) - return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) - - return optax.GradientTransformation(init_fn, update_fn) - - -class ScaleByAdamState(NamedTuple): - """State for the NAdam algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: optax.Updates - nu: optax.Updates - - -def _update_moment(updates, moments, decay, order): - """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) - - -def _bias_correction(moment, decay, count): - """Perform bias correction. This becomes a no-op as count goes to infinity.""" - beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) - - -def scale_by_learning_rate(learning_rate, flip_sign=True): - m = -1 if flip_sign else 1 - if callable(learning_rate): - return optax.scale_by_schedule(lambda count: m * learning_rate(count)) - return optax.scale(m * learning_rate) +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. 
_On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float): + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # update step + step_t += 1 + + # Perform stepweight decay + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. 
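+    # Re-applying the beta1 EMA update below turns exp_avg into the Nesterov
+    # "look-ahead" first moment, beta1 * m_t + (1 - beta1) * g_t, which is the
+    # numerator used in the parameter update; it is reverted after the update.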
+ # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) def init_optimizer_state(workload: spec.Workload, @@ -147,25 +191,22 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters: spec.Hyperparameters, rng: spec.RandomState) -> spec.OptimizerState: """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_params del model_state del rng - target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, - hyperparameters) - - # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) epsilon = ( hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) - opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=epsilon, - weight_decay=hyperparameters.weight_decay) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1 - hyperparameters.one_minus_beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.weight_decay), + } + + target_setting_step_hint = int(0.75 * workload.step_hint) + optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( + target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + return optimizer_state diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py index 21f2a7b2b..8f20bcbc6 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py @@ -162,7 +162,7 @@ def init_optimizer_state(workload: spec.Workload, hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) opt_init_fn, opt_update_fn = nadamw( learning_rate=lr_schedule_fn, - b1=hyperparameters.beta1, + b1=1 - hyperparameters.one_minus_beta1, b2=hyperparameters.beta2, eps=epsilon, weight_decay=hyperparameters.weight_decay) diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py index 71b819e66..01cffc52e 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py +++ b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py @@ -1,29 +1,32 @@ -"""Submission file for a NAdamW optimizer in PyTorch.""" +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import List +from typing import Dict, Iterator, List, Tuple +from absl import logging import torch from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import 
CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR from algorithmic_efficiency import spec -from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from algorithmic_efficiency.pytorch_utils import pytorch_setup +USE_PYTORCH_DDP = pytorch_setup()[0] -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of the NAdam algorithm (there is also a comment in the code which highlights the only difference of NAdamW and AdamW). For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups @@ -37,7 +40,7 @@ class NAdamW(torch.optim.Optimizer): https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ - """ + """ def __init__(self, params, @@ -72,10 +75,11 @@ def __setstate__(self, state): @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. - """ + """ self._cuda_graph_capture_health_check() loss = None @@ -139,10 +143,10 @@ def nadamw(params: List[Tensor], beta2: float, lr: float, weight_decay: float, - eps: float): + eps: float) -> None: r"""Functional API that performs NAdamW algorithm computation. See NAdamW class for details. - """ + """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( @@ -155,13 +159,13 @@ def nadamw(params: List[Tensor], exp_avg_sq = exp_avg_sqs[i] step_t = state_steps[i] - # update step + # Update step. step_t += 1 - # Perform stepweight decay + # Perform stepweight decay. param.mul_(1 - lr * weight_decay) - # Decay the first and second moment running average coefficient + # Decay the first and second moment running average coefficient. 
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) @@ -194,19 +198,150 @@ def init_optimizer_state(workload: spec.Workload, del model_state del rng - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) optimizer_state = { 'optimizer': NAdamW( model_params.parameters(), lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=epsilon, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, weight_decay=hyperparameters.weight_decay), } - target_setting_step_hint = int(0.75 * workload.step_hint) - optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( - target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. 
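+  # Only loss and grad_norm are appended below; logging is throttled to the
+  # first 100 steps and to every 500th step afterwards.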
+ if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py index 71b819e66..7aa8160a4 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py @@ -201,7 +201,7 @@ def init_optimizer_state(workload: spec.Workload, NAdamW( model_params.parameters(), lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), + betas=(1 - hyperparameters.one_minus_beta1, hyperparameters.beta2), eps=epsilon, weight_decay=hyperparameters.weight_decay), } From 05427690c7b69c10b62a21c6b0cc43dd43dc0fb8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 22:16:14 +0000 Subject: [PATCH 03/10] add self tuning threshold submission --- .../threshold_submissions/README.md | 80 ++++++ .../jax_nadamw_target_setting.py | 242 ++++++++++++++--- .../pytorch_nadamw_target_setting.py | 183 +++++++++++-- .../self_tuning/jax_nadamw_full_budget.py | 14 + .../self_tuning/jax_nadamw_target_setting.py | 256 +++++++++++++++--- .../self_tuning/pytorch_nadamw_full_budget.py | 14 + .../pytorch_nadamw_target_setting.py | 197 ++++++++++++-- 7 files changed, 870 insertions(+), 116 deletions(-) diff --git a/reference_algorithms/threshold_submissions/README.md b/reference_algorithms/threshold_submissions/README.md index e69de29bb..eb8995408 100644 --- a/reference_algorithms/threshold_submissions/README.md +++ b/reference_algorithms/threshold_submissions/README.md @@ -0,0 +1,80 @@ +# 
Threshold Submissions
+
+## Externally Tuned Ruleset
+
+### JAX
+
+The threshold submissions for JAX are:
+- `reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py`
+- `reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py`
+
+Example command:
+
+```bash
+python3 submission_runner.py \
+    --framework=jax \
+    --data_dir= \
+    --experiment_dir= \
+    --experiment_name= \
+    --workload= \
+    --submission_path=reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py \
+    --tuning_search_space=reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json
+```
+
+### PyTorch
+
+The threshold submissions for PyTorch are:
+- `reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py`
+- `reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py`
+
+
+Example command:
+
+```bash
+torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
+    --framework=pytorch \
+    --data_dir= \
+    --experiment_dir= \
+    --experiment_name= \
+    --workload= \
+    --submission_path=reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py \
+    --tuning_search_space=reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json
+```
+
+## Self-tuning Ruleset
+
+### JAX
+
+The threshold submissions for JAX are:
+- `reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py`
+- `reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py`
+
+Example command:
+```bash
+python3 submission_runner.py \
+    --framework=jax \
+    --data_dir= \
+    --experiment_dir= \
+    --experiment_name= \
+    --workload= \
+    --submission_path=reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py \
+    --tuning_ruleset=self
+```
+
+### PyTorch
+
+The threshold submissions for PyTorch are:
+- `reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py`
+- `reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py`
+
+Example command:
+```bash
+torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
+    --framework=pytorch \
+    --data_dir= \
+    --experiment_dir= \
+    --experiment_name= \
+    --workload= \
+    --submission_path=reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py \
+    --tuning_ruleset=self
+```
\ No newline at end of file
diff --git a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py
index 8f20bcbc6..ef0c11c0d 100644
--- a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py
+++ b/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py
@@ -1,21 +1,30 @@
-"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax."""
+"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
 
-from typing import Any, Callable, NamedTuple, Optional, Union
+import functools
+
+# isort: off
+# We have to turn off isort here to resolve a conflict between isort and yapf.
+from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on import chex from flax import jax_utils import jax +from jax import lax import jax.numpy as jnp import optax from algorithmic_efficiency import spec -from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import + +_GRAD_CLIP_EPS = 1e-6 # Forked from @@ -32,6 +41,7 @@ def nadamw( Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. + References: There seem to be multiple versions of NAdam. The original version is here https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch @@ -39,24 +49,26 @@ def nadamw( Current code implements a simpler version with no momentum decay and slightly different bias correction terms. The exact description can be found here https://arxiv.org/pdf/1910.05446.pdf (Table 1). + Args: - learning_rate: this is a fixed global scaling factor. - b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of squared grads. - eps: term added to the denominator to improve numerical stability. - eps_root: term added to the denominator inside the square-root to improve + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - debias: whether to use bias correction. - weight_decay: strength of the weight decay regularization. Note that this + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the "schedule multiplier", but not the base learning rate. - weight_decay_mask: a tree with same structure as (or a prefix of) the params + weight_decay_mask: A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, `True` for leaves/subtrees you want to apply the weight decay to, and `False` for those you want to skip. Note that the Nadam gradient transformations are applied to all parameters. + Returns: An (init_fn, update_fn) tuple. """ @@ -75,21 +87,24 @@ def scale_by_nadam(b1: float = 0.9, debias: bool = True, power: float = 0.5) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. + References: There seem to be multiple versions of NAdam. The original version is here https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this) + follows this). + Current code implements a simpler version with no momentum decay and slightly different (standard Adam) bias correction terms. 
The exact description can be found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + Args: - b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of squared grads. - eps: term added to the denominator to improve numerical stability. - eps_root: term added to the denominator inside the square-root to improve + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - debias: whether to use bias correction. - power: the power to use in the preconditioner (0.5 in default adam). + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). Returns: An (init_fn, update_fn) tuple. """ @@ -151,21 +166,180 @@ def init_optimizer_state(workload: spec.Workload, del model_state del rng - target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, - hyperparameters) + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn - # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + # Create optimizer + LR schedule. 
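+  # The target-setting variant runs the warmup + cosine schedule over 75% of
+  # the workload step hint instead of the full budget.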
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) opt_init_fn, opt_update_fn = nadamw( learning_rate=lr_schedule_fn, - b1=1 - hyperparameters.one_minus_beta1, + b1=1.0 - hyperparameters.one_minus_beta1, b2=hyperparameters.beta2, - eps=epsilon, + eps=1e-8, weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, 
grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py index 7aa8160a4..530dd3acf 100644 --- a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py @@ -1,29 +1,32 @@ -"""Submission file for a NAdamW optimizer in PyTorch.""" +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import List +from typing import Dict, Iterator, List, Tuple +from absl import logging import torch from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR from algorithmic_efficiency import spec -from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from algorithmic_efficiency.pytorch_utils import pytorch_setup +USE_PYTORCH_DDP = pytorch_setup()[0] -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of the NAdam algorithm (there is also a comment in the code which highlights the only difference of NAdamW and AdamW). 
For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups @@ -37,7 +40,7 @@ class NAdamW(torch.optim.Optimizer): https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ - """ + """ def __init__(self, params, @@ -72,10 +75,11 @@ def __setstate__(self, state): @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. - """ + """ self._cuda_graph_capture_health_check() loss = None @@ -139,10 +143,10 @@ def nadamw(params: List[Tensor], beta2: float, lr: float, weight_decay: float, - eps: float): + eps: float) -> None: r"""Functional API that performs NAdamW algorithm computation. See NAdamW class for details. - """ + """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( @@ -155,13 +159,13 @@ def nadamw(params: List[Tensor], exp_avg_sq = exp_avg_sqs[i] step_t = state_steps[i] - # update step + # Update step. step_t += 1 - # Perform stepweight decay + # Perform stepweight decay. param.mul_(1 - lr * weight_decay) - # Decay the first and second moment running average coefficient + # Decay the first and second moment running average coefficient. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) @@ -194,19 +198,150 @@ def init_optimizer_state(workload: spec.Workload, del model_state del rng - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) optimizer_state = { 'optimizer': NAdamW( model_params.parameters(), lr=hyperparameters.learning_rate, - betas=(1 - hyperparameters.one_minus_beta1, hyperparameters.beta2), - eps=epsilon, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, weight_decay=hyperparameters.weight_decay), } - target_setting_step_hint = int(0.75 * workload.step_hint) - optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( - target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, 
new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
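+  This implementation simply returns the next batch from the queue
+  unchanged; every other argument is ignored (deleted) below.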
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py index 099613fcf..b35750086 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py +++ b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py @@ -26,6 +26,14 @@ _GRAD_CLIP_EPS = 1e-6 +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py @@ -165,7 +173,10 @@ def init_optimizer_state(workload: spec.Workload, del model_params del model_state del rng + del hyperparameters + hyperparameters=HPARAMS + def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) @@ -267,6 +278,9 @@ def update_params(workload: spec.Workload, del current_params_types del loss_type del eval_results + del hyperparameters + + hyperparameters = HPARAMS optimizer_state, opt_update_fn = optimizer_state per_device_rngs = jax.random.split(rng, jax.local_device_count()) diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py index 8f20bcbc6..190720213 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py @@ -1,22 +1,39 @@ -"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax.""" +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" -from typing import Any, Callable, NamedTuple, Optional, Union +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on import chex from flax import jax_utils import jax +from jax import lax import jax.numpy as jnp import optax from algorithmic_efficiency import spec -from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.jax_submission_base import \ - update_params # pylint: disable=unused-import +_GRAD_CLIP_EPS = 1e-6 + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py @@ -32,6 +49,7 @@ def nadamw( Any]]] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. + References: There seem to be multiple versions of NAdam. 
The original version is here https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch @@ -39,24 +57,26 @@ def nadamw( Current code implements a simpler version with no momentum decay and slightly different bias correction terms. The exact description can be found here https://arxiv.org/pdf/1910.05446.pdf (Table 1). + Args: - learning_rate: this is a fixed global scaling factor. - b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of squared grads. - eps: term added to the denominator to improve numerical stability. - eps_root: term added to the denominator inside the square-root to improve + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - debias: whether to use bias correction. - weight_decay: strength of the weight decay regularization. Note that this + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the "schedule multiplier", but not the base learning rate. - weight_decay_mask: a tree with same structure as (or a prefix of) the params + weight_decay_mask: A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, `True` for leaves/subtrees you want to apply the weight decay to, and `False` for those you want to skip. Note that the Nadam gradient transformations are applied to all parameters. + Returns: An (init_fn, update_fn) tuple. """ @@ -75,21 +95,24 @@ def scale_by_nadam(b1: float = 0.9, debias: bool = True, power: float = 0.5) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. + References: There seem to be multiple versions of NAdam. The original version is here https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this) + follows this). + Current code implements a simpler version with no momentum decay and slightly different (standard Adam) bias correction terms. The exact description can be found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + Args: - b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of squared grads. - eps: term added to the denominator to improve numerical stability. - eps_root: term added to the denominator inside the square-root to improve + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. - debias: whether to use bias correction. - power: the power to use in the preconditioner (0.5 in default adam). + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). 
Returns: An (init_fn, update_fn) tuple. """ @@ -150,22 +173,187 @@ def init_optimizer_state(workload: spec.Workload, del model_params del model_state del rng + del hyperparameters - target_setting_step_hint = int(0.75 * workload.step_hint) - lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, - hyperparameters) + hyperparameters=HPARAMS + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn - # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint*0.75, hyperparameters) opt_init_fn, opt_update_fn = nadamw( learning_rate=lr_schedule_fn, - b1=1 - hyperparameters.one_minus_beta1, + b1=1.0 - hyperparameters.one_minus_beta1, b2=hyperparameters.beta2, - eps=epsilon, + eps=1e-8, weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. 
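+  # `lax.psum` sums the per-device summed loss, valid-example count and
+  # gradient tree across the 'batch' pmap axis, so dividing by
+  # `n_valid_examples` below yields the true global mean over all devices.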
+ (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py index 01cffc52e..a1cf612f2 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py +++ b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py @@ -16,6 +16,14 @@ USE_PYTORCH_DDP = pytorch_setup()[0] +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): @@ -197,6 +205,9 @@ def init_optimizer_state(workload: spec.Workload, """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng + del hyperparameters + + hyperparameters = HPARAMS optimizer_state = { 'optimizer': @@ -239,7 +250,10 @@ def update_params(workload: spec.Workload, del current_params_types del loss_type del eval_results + del hyperparameters + hyperparameters = HPARAMS + current_model = current_param_container current_model.train() optimizer_state['optimizer'].zero_grad() diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py index 7aa8160a4..1209abadc 100644 --- a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py +++ b/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py @@ -1,29 +1,40 @@ -"""Submission file for a NAdamW optimizer in PyTorch.""" +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import List +from typing import Dict, Iterator, List, Tuple +from absl import logging import torch from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR from algorithmic_efficiency import spec -from reference_algorithms.target_setting_algorithms import cosine_warmup -from reference_algorithms.target_setting_algorithms.data_selection import \ - data_selection # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.get_batch_size import \ - get_batch_size # pylint: disable=unused-import -from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \ - update_params # pylint: disable=unused-import +from algorithmic_efficiency.pytorch_utils import pytorch_setup +USE_PYTORCH_DDP = pytorch_setup()[0] -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): r"""Implements NAdamW algorithm. 
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of the NAdam algorithm (there is also a comment in the code which highlights the only difference of NAdamW and AdamW). For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. + Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups @@ -37,7 +48,7 @@ class NAdamW(torch.optim.Optimizer): https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ - """ + """ def __init__(self, params, @@ -72,10 +83,11 @@ def __setstate__(self, state): @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. - """ + """ self._cuda_graph_capture_health_check() loss = None @@ -139,10 +151,10 @@ def nadamw(params: List[Tensor], beta2: float, lr: float, weight_decay: float, - eps: float): + eps: float) -> None: r"""Functional API that performs NAdamW algorithm computation. See NAdamW class for details. - """ + """ if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError( @@ -155,13 +167,13 @@ def nadamw(params: List[Tensor], exp_avg_sq = exp_avg_sqs[i] step_t = state_steps[i] - # update step + # Update step. step_t += 1 - # Perform stepweight decay + # Perform stepweight decay. param.mul_(1 - lr * weight_decay) - # Decay the first and second moment running average coefficient + # Decay the first and second moment running average coefficient. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) @@ -193,20 +205,157 @@ def init_optimizer_state(workload: spec.Workload, """Creates a NAdamW optimizer and a learning rate schedule.""" del model_state del rng + del hyperparameters + + hyperparameters = HPARAMS - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) optimizer_state = { 'optimizer': NAdamW( model_params.parameters(), lr=hyperparameters.learning_rate, - betas=(1 - hyperparameters.one_minus_beta1, hyperparameters.beta2), - eps=epsilon, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, weight_decay=hyperparameters.weight_decay), } - target_setting_step_hint = int(0.75 * workload.step_hint) - optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( - target_setting_step_hint, hyperparameters, optimizer_state['optimizer']) + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint*0.75, hyperparameters, optimizer_state['optimizer']) + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: 
spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch From 2439135829803cb4130e1b18e6766dc93782a0f7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 22:22:37 +0000 Subject: [PATCH 04/10] update readme --- reference_algorithms/threshold_submissions/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/reference_algorithms/threshold_submissions/README.md b/reference_algorithms/threshold_submissions/README.md index eb8995408..d73706ad7 100644 --- a/reference_algorithms/threshold_submissions/README.md +++ b/reference_algorithms/threshold_submissions/README.md @@ -1,4 +1,5 @@ # Threshold Submissions +TODO: link back to section in rules. ## Externally Tuned Ruleset From cc0c6ff817e29e5bf40f2ae7708463357846ed52 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 22:23:05 +0000 Subject: [PATCH 05/10] tuning search space --- .../external_tuning/tuning_search_space.json | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json diff --git a/reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json b/reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json new file mode 100644 index 000000000..65562905a --- /dev/null +++ b/reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json @@ -0,0 +1,50 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.2, + "learning_rate": 0.0008445074561975979, + "one_minus_beta1": 0.11042418465, + "beta2": 0.9978504782314613, + "weight_decay": 0.08135402759553023, + "warmup_factor": 0.05 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.004958460849689891, + "one_minus_beta1": 0.13625575743, + "beta2": 0.6291854735396584, + "weight_decay": 0.1147386261512052, + "warmup_factor": 0.02 + }, + { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } +] + + + + + + From 0338f8fcfbce807aac5bde4fbf278f13a9eeb1b1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 22:26:37 +0000 Subject: [PATCH 06/10] add reference algorithms --- .../threshold_baselines/README.md | 81 +++++++++++++++++++ .../external_tuning/jax_nadamw_full_budget.py | 0 .../jax_nadamw_target_setting.py | 0 .../pytorch_nadamw_full_budget.py | 0 .../pytorch_nadamw_target_setting.py | 0 .../external_tuning/tuning_search_space.json | 0 .../self_tuning/jax_nadamw_full_budget.py | 0 .../self_tuning/jax_nadamw_target_setting.py | 0 .../self_tuning/pytorch_nadamw_full_budget.py | 0 .../pytorch_nadamw_target_setting.py | 0 .../threshold_submissions/README.md | 81 ------------------- 11 files changed, 81 insertions(+), 81 deletions(-) create mode 100644 reference_algorithms/threshold_baselines/README.md rename reference_algorithms/{threshold_submissions => 
threshold_baselines}/external_tuning/jax_nadamw_full_budget.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/external_tuning/jax_nadamw_target_setting.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/external_tuning/pytorch_nadamw_full_budget.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/external_tuning/pytorch_nadamw_target_setting.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/external_tuning/tuning_search_space.json (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/self_tuning/jax_nadamw_full_budget.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/self_tuning/jax_nadamw_target_setting.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/self_tuning/pytorch_nadamw_full_budget.py (100%) rename reference_algorithms/{threshold_submissions => threshold_baselines}/self_tuning/pytorch_nadamw_target_setting.py (100%) delete mode 100644 reference_algorithms/threshold_submissions/README.md diff --git a/reference_algorithms/threshold_baselines/README.md b/reference_algorithms/threshold_baselines/README.md new file mode 100644 index 000000000..fa0971997 --- /dev/null +++ b/reference_algorithms/threshold_baselines/README.md @@ -0,0 +1,81 @@ +# Threshold Baselines +TODO: link back to section in rules. + +## Externally Tuned Ruleset + +### JAX + +The threshold submissions for jax are: +- `reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py` +- `feference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py` + +Example command: + +```bash +python3 submission_runner.py \ + --framework=jax \ + --data_dir= \ + --experiment_dir= \ + --experiment_name= \ + --workload= \ + --submission_path=reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py \ + --tuning_search_space=reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json +``` + +### PyTorch + +The threshold submissions for PyTorch are +- `reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py` +- `feference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py` + + +Example command: + +```bash +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ + --framework=pytorch \ + --data_dir= \ + --experiment_dir= \ + --experiment_name=t \ + --workload=\ + --submission_path=reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py \ + --tuning_search_space=reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json +``` + +## Self-tuning Ruleset + +### JAX + +The threshold submissions for jax are +- `reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py` +- `feference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py` + +Example command: +```bash +python3 submission_runner.py \ + --framework=jax \ + --data_dir= \ + --experiment_dir= \ + --experiment_name= \ + --workload= \ + --submission_path=reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py \ + --tuning_ruleset=self +``` + +### PyTorch + +The threshold submissions for PyTorch are +- `reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py` +- 
`feference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py` + +Example command: +```bash +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ + --framework=pytorch \ + --data_dir= \ + --experiment_dir= \ + --experiment_name=t \ + --workload=\ + --submission_path=reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py \ + --tuning_ruleset=self +``` \ No newline at end of file diff --git a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py similarity index 100% rename from reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py rename to reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py diff --git a/reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py similarity index 100% rename from reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py rename to reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py similarity index 100% rename from reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py rename to reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py diff --git a/reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py similarity index 100% rename from reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py rename to reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py diff --git a/reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json b/reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json similarity index 100% rename from reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json rename to reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py similarity index 100% rename from reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_full_budget.py rename to reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py diff --git a/reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py similarity index 100% rename from reference_algorithms/threshold_submissions/self_tuning/jax_nadamw_target_setting.py rename to reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py similarity index 100% rename from 
reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_full_budget.py rename to reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py diff --git a/reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py similarity index 100% rename from reference_algorithms/threshold_submissions/self_tuning/pytorch_nadamw_target_setting.py rename to reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py diff --git a/reference_algorithms/threshold_submissions/README.md b/reference_algorithms/threshold_submissions/README.md deleted file mode 100644 index d73706ad7..000000000 --- a/reference_algorithms/threshold_submissions/README.md +++ /dev/null @@ -1,81 +0,0 @@ -# Threshold Submissions -TODO: link back to section in rules. - -## Externally Tuned Ruleset - -### JAX - -The threshold submissions for jax are: -- `reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py` -- `feference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py` - -Example command: - -```bash -python3 submission_runner.py \ - --framework=jax \ - --data_dir= \ - --experiment_dir= \ - --experiment_name= \ - --workload= \ - --submission_path=reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py \ - --tuning_search_space=reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json -``` - -### PyTorch - -The threshold submissions for PyTorch are -- `reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py` -- `feference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py` - - -Example command: - -```bash -torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ - --framework=pytorch \ - --data_dir= \ - --experiment_dir= \ - --experiment_name=t \ - --workload=\ - --submission_path=reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py \ - --tuning_search_space=reference_algorithms/threshold_submissions/external_tuning/tuning_search_space.json -``` - -## Self-tuning Ruleset - -### JAX - -The threshold submissions for jax are -- `reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py` -- `feference_algorithms/threshold_submissions/external_tuning/jax_nadamw_full_budget.py` - -Example command: -```bash -python3 submission_runner.py \ - --framework=jax \ - --data_dir= \ - --experiment_dir= \ - --experiment_name= \ - --workload= \ - --submission_path=reference_algorithms/threshold_submissions/external_tuning/jax_nadamw_target_setting.py \ - --tuning_ruleset=self -``` - -### PyTorch - -The threshold submissions for PyTorch are -- `reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py` -- `feference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_full_budget.py` - -Example command: -```bash -torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ - --framework=pytorch \ - --data_dir= \ - --experiment_dir= \ - --experiment_name=t \ - --workload=\ - --submission_path=reference_algorithms/threshold_submissions/external_tuning/pytorch_nadamw_target_setting.py \ - --tuning_ruleset=self -``` \ No newline at end of file From 
5f834047e249b0419c0d2082b6737a1e3cd5c4a3 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 22:40:25 +0000 Subject: [PATCH 07/10] formatting --- .../self_tuning/jax_nadamw_full_budget.py | 19 +++++++++-------- .../self_tuning/jax_nadamw_target_setting.py | 21 ++++++++++--------- .../self_tuning/pytorch_nadamw_full_budget.py | 19 +++++++++-------- .../pytorch_nadamw_target_setting.py | 21 ++++++++++--------- 4 files changed, 42 insertions(+), 38 deletions(-) diff --git a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py index b35750086..c54202e56 100644 --- a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py @@ -27,13 +27,14 @@ _GRAD_CLIP_EPS = 1e-6 HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 - } + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py @@ -175,8 +176,8 @@ def init_optimizer_state(workload: spec.Workload, del rng del hyperparameters - hyperparameters=HPARAMS - + hyperparameters = HPARAMS + def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) diff --git a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py index 190720213..dd42743e2 100644 --- a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py @@ -27,13 +27,14 @@ _GRAD_CLIP_EPS = 1e-6 HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 - } + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + # Forked from # github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py @@ -175,8 +176,8 @@ def init_optimizer_state(workload: spec.Workload, del rng del hyperparameters - hyperparameters=HPARAMS - + hyperparameters = HPARAMS + def jax_cosine_warmup(step_hint: int, hyperparameters): # Create learning rate schedule. warmup_steps = int(hyperparameters.warmup_factor * step_hint) @@ -192,7 +193,7 @@ def jax_cosine_warmup(step_hint: int, hyperparameters): return schedule_fn # Create optimizer + LR schedule. 
- lr_schedule_fn = jax_cosine_warmup(workload.step_hint*0.75, hyperparameters) + lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) opt_init_fn, opt_update_fn = nadamw( learning_rate=lr_schedule_fn, b1=1.0 - hyperparameters.one_minus_beta1, diff --git a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py index a1cf612f2..57da48167 100644 --- a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -16,14 +16,15 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 - } +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): @@ -253,7 +254,7 @@ def update_params(workload: spec.Workload, del hyperparameters hyperparameters = HPARAMS - + current_model = current_param_container current_model.train() optimizer_state['optimizer'].zero_grad() diff --git a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py index 1209abadc..ef6e84c94 100644 --- a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -16,14 +16,15 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 - } +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. 
class NAdamW(torch.optim.Optimizer): @@ -230,7 +231,7 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint*0.75, hyperparameters, optimizer_state['optimizer']) + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) return optimizer_state @@ -253,7 +254,7 @@ def update_params(workload: spec.Workload, del hyperparameters hyperparameters = HPARAMS - + current_model = current_param_container current_model.train() optimizer_state['optimizer'].zero_grad() From 811b7c4dab634424fdfdebd28db48850f7eab1ec Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 22:42:42 +0000 Subject: [PATCH 08/10] baselines --- reference_algorithms/threshold_baselines/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/reference_algorithms/threshold_baselines/README.md b/reference_algorithms/threshold_baselines/README.md index fa0971997..09eed8f41 100644 --- a/reference_algorithms/threshold_baselines/README.md +++ b/reference_algorithms/threshold_baselines/README.md @@ -5,7 +5,7 @@ TODO: link back to section in rules. ### JAX -The threshold submissions for jax are: +The threshold baseline submissions for jax are: - `reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py` - `feference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py` @@ -24,7 +24,7 @@ python3 submission_runner.py \ ### PyTorch -The threshold submissions for PyTorch are +The threshold baseline submissionss for PyTorch are: - `reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py` - `feference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py` @@ -46,7 +46,7 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc ### JAX -The threshold submissions for jax are +The threshold baseline submissionss for jax are: - `reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py` - `feference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py` @@ -64,7 +64,7 @@ python3 submission_runner.py \ ### PyTorch -The threshold submissions for PyTorch are +The threshold baseline submissionss for PyTorch are: - `reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py` - `feference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py` From 30a0654963748b634f9f3a46b5559bfb3536ae77 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 21 Nov 2023 23:57:20 +0000 Subject: [PATCH 09/10] rename prize qualification baselines --- .../threshold_baselines/README.md | 81 ---- .../external_tuning/jax_nadamw_full_budget.py | 345 ----------------- .../jax_nadamw_target_setting.py | 345 ----------------- .../pytorch_nadamw_full_budget.py | 347 ----------------- .../pytorch_nadamw_target_setting.py | 347 ----------------- .../external_tuning/tuning_search_space.json | 50 --- .../self_tuning/jax_nadamw_full_budget.py | 360 ----------------- .../self_tuning/jax_nadamw_target_setting.py | 360 ----------------- .../self_tuning/pytorch_nadamw_full_budget.py | 362 ------------------ .../pytorch_nadamw_target_setting.py | 362 ------------------ 10 files changed, 2959 deletions(-) delete mode 100644 reference_algorithms/threshold_baselines/README.md delete mode 100644 
reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py delete mode 100644 reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py delete mode 100644 reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py delete mode 100644 reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py delete mode 100644 reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json delete mode 100644 reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py delete mode 100644 reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py delete mode 100644 reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py delete mode 100644 reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py diff --git a/reference_algorithms/threshold_baselines/README.md b/reference_algorithms/threshold_baselines/README.md deleted file mode 100644 index 09eed8f41..000000000 --- a/reference_algorithms/threshold_baselines/README.md +++ /dev/null @@ -1,81 +0,0 @@ -# Threshold Baselines -TODO: link back to section in rules. - -## Externally Tuned Ruleset - -### JAX - -The threshold baseline submissions for jax are: -- `reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py` -- `feference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py` - -Example command: - -```bash -python3 submission_runner.py \ - --framework=jax \ - --data_dir= \ - --experiment_dir= \ - --experiment_name= \ - --workload= \ - --submission_path=reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py \ - --tuning_search_space=reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json -``` - -### PyTorch - -The threshold baseline submissionss for PyTorch are: -- `reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py` -- `feference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py` - - -Example command: - -```bash -torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ - --framework=pytorch \ - --data_dir= \ - --experiment_dir= \ - --experiment_name=t \ - --workload=\ - --submission_path=reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py \ - --tuning_search_space=reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json -``` - -## Self-tuning Ruleset - -### JAX - -The threshold baseline submissionss for jax are: -- `reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py` -- `feference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py` - -Example command: -```bash -python3 submission_runner.py \ - --framework=jax \ - --data_dir= \ - --experiment_dir= \ - --experiment_name= \ - --workload= \ - --submission_path=reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py \ - --tuning_ruleset=self -``` - -### PyTorch - -The threshold baseline submissionss for PyTorch are: -- `reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py` -- `feference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py` - -Example command: -```bash -torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 
submission_runner.py \ - --framework=pytorch \ - --data_dir= \ - --experiment_dir= \ - --experiment_name=t \ - --workload=\ - --submission_path=reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py \ - --tuning_ruleset=self -``` \ No newline at end of file diff --git a/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py deleted file mode 100644 index 099613fcf..000000000 --- a/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_full_budget.py +++ /dev/null @@ -1,345 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" - -import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on - -import chex -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - - -# Forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py -def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, -) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch - implementation also follows this). - Current code implements a simpler version with no momentum decay and slightly - different bias correction terms. The exact description can be found here - https://arxiv.org/pdf/1910.05446.pdf (Table 1). - - Args: - learning_rate: A fixed global scaling factor. - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - weight_decay: Strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent with - other frameworks such as PyTorch, but different from (Loshchilov et al, - 2019) where the weight decay is only multiplied with the "schedule - multiplier", but not the base learning rate. - weight_decay_mask: A tree with same structure as (or a prefix of) the params - PyTree, or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Nadam gradient transformations are applied to all parameters. - - Returns: - An (init_fn, update_fn) tuple. 
- """ - return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) - - -# All functions below are forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this). - - Current code implements a simpler version with no momentum decay and slightly - different (standard Adam) bias correction terms. The exact description can be - found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - power: The power to use in the preconditioner (0.5 in default adam). - Returns: - An (init_fn, update_fn) tuple. - """ - raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) - - def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = _update_moment(updates, state.mu, b1, 1) - nu = _update_moment(updates, state.nu, b2, 2) - count = state.count + jnp.array(1, dtype=jnp.int32) - mu_hat = _update_moment(updates, mu, b1, 1) - mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) - nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) - return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) - - return optax.GradientTransformation(init_fn, update_fn) - - -class ScaleByAdamState(NamedTuple): - """State for the NAdam algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: optax.Updates - nu: optax.Updates - - -def _update_moment(updates, moments, decay, order): - """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) - - -def _bias_correction(moment, decay, count): - """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" - beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) - - -def scale_by_learning_rate(learning_rate, flip_sign=True): - m = -1 if flip_sign else 1 - if callable(learning_rate): - return optax.scale_by_schedule(lambda count: m * learning_rate(count)) - return optax.scale(m * learning_rate) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_params - del model_state - del rng - - def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) - return schedule_fn - - # Create optimizer + LR schedule. - lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) - opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. 
- (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - - if grad_clip is not None: - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters.label_smoothing - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py deleted file mode 100644 index ef0c11c0d..000000000 --- a/reference_algorithms/threshold_baselines/external_tuning/jax_nadamw_target_setting.py +++ /dev/null @@ -1,345 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" - -import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on - -import chex -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - - -# Forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py -def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, -) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch - implementation also follows this). - Current code implements a simpler version with no momentum decay and slightly - different bias correction terms. The exact description can be found here - https://arxiv.org/pdf/1910.05446.pdf (Table 1). - - Args: - learning_rate: A fixed global scaling factor. - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - weight_decay: Strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent with - other frameworks such as PyTorch, but different from (Loshchilov et al, - 2019) where the weight decay is only multiplied with the "schedule - multiplier", but not the base learning rate. - weight_decay_mask: A tree with same structure as (or a prefix of) the params - PyTree, or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Nadam gradient transformations are applied to all parameters. - - Returns: - An (init_fn, update_fn) tuple. 
- """ - return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) - - -# All functions below are forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this). - - Current code implements a simpler version with no momentum decay and slightly - different (standard Adam) bias correction terms. The exact description can be - found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - power: The power to use in the preconditioner (0.5 in default adam). - Returns: - An (init_fn, update_fn) tuple. - """ - raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) - - def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = _update_moment(updates, state.mu, b1, 1) - nu = _update_moment(updates, state.nu, b2, 2) - count = state.count + jnp.array(1, dtype=jnp.int32) - mu_hat = _update_moment(updates, mu, b1, 1) - mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) - nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) - return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) - - return optax.GradientTransformation(init_fn, update_fn) - - -class ScaleByAdamState(NamedTuple): - """State for the NAdam algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: optax.Updates - nu: optax.Updates - - -def _update_moment(updates, moments, decay, order): - """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) - - -def _bias_correction(moment, decay, count): - """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" - beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) - - -def scale_by_learning_rate(learning_rate, flip_sign=True): - m = -1 if flip_sign else 1 - if callable(learning_rate): - return optax.scale_by_schedule(lambda count: m * learning_rate(count)) - return optax.scale(m * learning_rate) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_params - del model_state - del rng - - def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) - return schedule_fn - - # Create optimizer + LR schedule. - lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) - opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. 
- (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - - if grad_clip is not None: - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters.label_smoothing - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py deleted file mode 100644 index 01cffc52e..000000000 --- a/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ /dev/null @@ -1,347 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" - -import math -from typing import Dict, Iterator, List, Tuple - -from absl import logging -import torch -from torch import Tensor -import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. -class NAdamW(torch.optim.Optimizer): - r"""Implements NAdamW algorithm. - - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): - if not 0.0 <= lr: - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= betas[0] < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') - if not 0.0 <= betas[1] < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay - } - super().__init__(params, defaults) - - def __setstate__(self, state): - super().__setstate__(state) - state_values = list(self.state.values()) - step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) - if not step_is_tensor: - for s in state_values: - s['step'] = torch.tensor(float(s['step'])) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. 
- """ - self._cuda_graph_capture_health_check() - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group['betas'] - - for p in group['params']: - if p.grad is None: - continue - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError('NAdamW does not support sparse gradients') - grads.append(p.grad) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = torch.tensor(0.) - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - state_steps.append(state['step']) - - nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) - - return loss - - -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: - r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. - """ - - if not all(isinstance(t, torch.Tensor) for t in state_steps): - raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') - - for i, param in enumerate(params): - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - - # Update step. - step_t += 1 - - # Perform stepweight decay. - param.mul_(1 - lr * weight_decay) - - # Decay the first and second moment running average coefficient. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # Only difference between NAdamW and AdamW in this implementation. - # The official PyTorch implementation of NAdam uses a different algorithm. - # We undo these ops later on, which could cause numerical issues but saves - # us from having to make an extra copy of the gradients. 
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - - step = step_t.item() - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - step_size = lr / bias_correction1 - - bias_correction2_sqrt = math.sqrt(bias_correction2) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - - param.addcdiv_(exp_avg, denom, value=-step_size) - exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_state - del rng - - optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), - } - - def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) - return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) - - optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - if grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - # Log training metrics - loss, grad_norm, batch_size. 
- if global_step <= 100 or global_step % 500 == 0: - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) - - return (optimizer_state, current_param_container, new_model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py deleted file mode 100644 index 530dd3acf..000000000 --- a/reference_algorithms/threshold_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ /dev/null @@ -1,347 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" - -import math -from typing import Dict, Iterator, List, Tuple - -from absl import logging -import torch -from torch import Tensor -import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. -class NAdamW(torch.optim.Optimizer): - r"""Implements NAdamW algorithm. - - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. 
- - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): - if not 0.0 <= lr: - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= betas[0] < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') - if not 0.0 <= betas[1] < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay - } - super().__init__(params, defaults) - - def __setstate__(self, state): - super().__setstate__(state) - state_values = list(self.state.values()) - step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) - if not step_is_tensor: - for s in state_values: - s['step'] = torch.tensor(float(s['step'])) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - self._cuda_graph_capture_health_check() - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group['betas'] - - for p in group['params']: - if p.grad is None: - continue - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError('NAdamW does not support sparse gradients') - grads.append(p.grad) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = torch.tensor(0.) - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - state_steps.append(state['step']) - - nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) - - return loss - - -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: - r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. 
- """ - - if not all(isinstance(t, torch.Tensor) for t in state_steps): - raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') - - for i, param in enumerate(params): - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - - # Update step. - step_t += 1 - - # Perform stepweight decay. - param.mul_(1 - lr * weight_decay) - - # Decay the first and second moment running average coefficient. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # Only difference between NAdamW and AdamW in this implementation. - # The official PyTorch implementation of NAdam uses a different algorithm. - # We undo these ops later on, which could cause numerical issues but saves - # us from having to make an extra copy of the gradients. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - - step = step_t.item() - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - step_size = lr / bias_correction1 - - bias_correction2_sqrt = math.sqrt(bias_correction2) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - - param.addcdiv_(exp_avg, denom, value=-step_size) - exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_state - del rng - - optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), - } - - def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) - return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) - - optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - if hasattr(hyperparameters, 
'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - if grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - # Log training metrics - loss, grad_norm, batch_size. - if global_step <= 100 or global_step % 500 == 0: - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) - - return (optimizer_state, current_param_container, new_model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json b/reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json deleted file mode 100644 index 65562905a..000000000 --- a/reference_algorithms/threshold_baselines/external_tuning/tuning_search_space.json +++ /dev/null @@ -1,50 +0,0 @@ -[ - { - "dropout_rate": 0.0, - "label_smoothing": 0.1, - "learning_rate": 0.001308209823469072, - "one_minus_beta1": 0.02686663061, - "beta2": 0.9981232922116359, - "weight_decay": 0.16375311233774334, - "warmup_factor": 0.1 - }, - { - "dropout_rate": 0.0, - "label_smoothing": 0.2, - "learning_rate": 0.0008445074561975979, - "one_minus_beta1": 0.11042418465, - "beta2": 0.9978504782314613, - "weight_decay": 0.08135402759553023, - "warmup_factor": 0.05 - }, - { - "dropout_rate": 0.0, - "learning_rate": 0.001308209823469072, - "one_minus_beta1": 0.02686663061, - "beta2": 0.9981232922116359, - "weight_decay": 0.16375311233774334, - "warmup_factor": 0.1 - }, - { - "dropout_rate": 0.0, - "learning_rate": 0.004958460849689891, - "one_minus_beta1": 0.13625575743, - "beta2": 0.6291854735396584, - "weight_decay": 0.1147386261512052, - "warmup_factor": 0.02 - }, - { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 - } -] - - - - - - diff --git a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py deleted file mode 100644 index c54202e56..000000000 --- a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_full_budget.py +++ /dev/null @@ -1,360 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" - -import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on - -import chex -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - -HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 -} - - -# Forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py -def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, -) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch - implementation also follows this). - Current code implements a simpler version with no momentum decay and slightly - different bias correction terms. The exact description can be found here - https://arxiv.org/pdf/1910.05446.pdf (Table 1). 
- - Args: - learning_rate: A fixed global scaling factor. - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - weight_decay: Strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent with - other frameworks such as PyTorch, but different from (Loshchilov et al, - 2019) where the weight decay is only multiplied with the "schedule - multiplier", but not the base learning rate. - weight_decay_mask: A tree with same structure as (or a prefix of) the params - PyTree, or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Nadam gradient transformations are applied to all parameters. - - Returns: - An (init_fn, update_fn) tuple. - """ - return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) - - -# All functions below are forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this). - - Current code implements a simpler version with no momentum decay and slightly - different (standard Adam) bias correction terms. The exact description can be - found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - power: The power to use in the preconditioner (0.5 in default adam). - Returns: - An (init_fn, update_fn) tuple. 
- """ - raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) - - def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = _update_moment(updates, state.mu, b1, 1) - nu = _update_moment(updates, state.nu, b2, 2) - count = state.count + jnp.array(1, dtype=jnp.int32) - mu_hat = _update_moment(updates, mu, b1, 1) - mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) - nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) - return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) - - return optax.GradientTransformation(init_fn, update_fn) - - -class ScaleByAdamState(NamedTuple): - """State for the NAdam algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: optax.Updates - nu: optax.Updates - - -def _update_moment(updates, moments, decay, order): - """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) - - -def _bias_correction(moment, decay, count): - """Perform bias correction. This becomes a no-op as count goes to infinity.""" - beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) - - -def scale_by_learning_rate(learning_rate, flip_sign=True): - m = -1 if flip_sign else 1 - if callable(learning_rate): - return optax.scale_by_schedule(lambda count: m * learning_rate(count)) - return optax.scale(m * learning_rate) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_params - del model_state - del rng - del hyperparameters - - hyperparameters = HPARAMS - - def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) - return schedule_fn - - # Create optimizer + LR schedule. 
- lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) - opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - - if grad_clip is not None: - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del hyperparameters - - hyperparameters = HPARAMS - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters.label_smoothing - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, loss, 
grad_norm = outputs - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py deleted file mode 100644 index dd42743e2..000000000 --- a/reference_algorithms/threshold_baselines/self_tuning/jax_nadamw_target_setting.py +++ /dev/null @@ -1,360 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" - -import functools - -# isort: off -# We have to turn off isort here to resolve a conflict between isort and yapf. -from typing import (Any, - Callable, - Dict, - Iterator, - List, - NamedTuple, - Optional, - Tuple, - Union) -# isort: on - -import chex -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - -HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 -} - - -# Forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py -def nadamw( - learning_rate: Union[float, optax.Schedule], - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, -) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch - implementation also follows this). - Current code implements a simpler version with no momentum decay and slightly - different bias correction terms. The exact description can be found here - https://arxiv.org/pdf/1910.05446.pdf (Table 1). 
- - Args: - learning_rate: A fixed global scaling factor. - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - weight_decay: Strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent with - other frameworks such as PyTorch, but different from (Loshchilov et al, - 2019) where the weight decay is only multiplied with the "schedule - multiplier", but not the base learning rate. - weight_decay_mask: A tree with same structure as (or a prefix of) the params - PyTree, or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Nadam gradient transformations are applied to all parameters. - - Returns: - An (init_fn, update_fn) tuple. - """ - return optax.chain( - scale_by_nadam(b1, b2, eps, eps_root, debias), - optax.add_decayed_weights(weight_decay, weight_decay_mask), - scale_by_learning_rate(learning_rate)) - - -# All functions below are forked from -# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py -def scale_by_nadam(b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8, - eps_root: float = 0.0, - debias: bool = True, - power: float = 0.5) -> optax.GradientTransformation: - """Rescale updates according to the NAdam algorithm. - - References: - There seem to be multiple versions of NAdam. The original version is here - https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also - follows this). - - Current code implements a simpler version with no momentum decay and slightly - different (standard Adam) bias correction terms. The exact description can be - found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) - - Args: - b1: Decay rate for the exponentially weighted average of grads. - b2: Decay rate for the exponentially weighted average of squared grads. - eps: Term added to the denominator to improve numerical stability. - eps_root: Term added to the denominator inside the square-root to improve - numerical stability when backpropagating gradients through the rescaling. - debias: Whether to use bias correction. - power: The power to use in the preconditioner (0.5 in default adam). - Returns: - An (init_fn, update_fn) tuple. 
- """ - raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) - - def init_fn(params): - mu = jax.tree_map(jnp.zeros_like, params) # First moment - nu = jax.tree_map(jnp.zeros_like, params) # Second moment - return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) - - def update_fn(updates, state, params=None): - del params - mu = _update_moment(updates, state.mu, b1, 1) - nu = _update_moment(updates, state.nu, b2, 2) - count = state.count + jnp.array(1, dtype=jnp.int32) - mu_hat = _update_moment(updates, mu, b1, 1) - mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) - nu_hat = nu if not debias else _bias_correction(nu, b2, count) - updates = jax.tree_map( - lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) - return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) - - return optax.GradientTransformation(init_fn, update_fn) - - -class ScaleByAdamState(NamedTuple): - """State for the NAdam algorithm.""" - count: chex.Array # shape=(), dtype=jnp.int32. - mu: optax.Updates - nu: optax.Updates - - -def _update_moment(updates, moments, decay, order): - """Compute the exponential moving average of the `order-th` moment.""" - return jax.tree_map( - lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) - - -def _bias_correction(moment, decay, count): - """Perform bias correction. This becomes a no-op as count goes to infinity.""" - beta = 1 - decay**count - return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) - - -def scale_by_learning_rate(learning_rate, flip_sign=True): - m = -1 if flip_sign else 1 - if callable(learning_rate): - return optax.scale_by_schedule(lambda count: m * learning_rate(count)) - return optax.scale(m * learning_rate) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_params - del model_state - del rng - del hyperparameters - - hyperparameters = HPARAMS - - def jax_cosine_warmup(step_hint: int, hyperparameters): - # Create learning rate schedule. - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) - return schedule_fn - - # Create optimizer + LR schedule. 
- lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) - opt_init_fn, opt_update_fn = nadamw( - learning_rate=lr_schedule_fn, - b1=1.0 - hyperparameters.one_minus_beta1, - b2=hyperparameters.beta2, - eps=1e-8, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - grad_clip, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - - if grad_clip is not None: - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del hyperparameters - - hyperparameters = HPARAMS - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters.label_smoothing - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing) - new_optimizer_state, new_params, new_model_state, 
loss, grad_norm = outputs - - # Log loss, grad_norm. - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py deleted file mode 100644 index 57da48167..000000000 --- a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ /dev/null @@ -1,362 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" - -import math -from typing import Dict, Iterator, List, Tuple - -from absl import logging -import torch -from torch import Tensor -import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - -HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 -} - - -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. -class NAdamW(torch.optim.Optimizer): - r"""Implements NAdamW algorithm. - - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. 
- - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): - if not 0.0 <= lr: - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= betas[0] < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') - if not 0.0 <= betas[1] < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay - } - super().__init__(params, defaults) - - def __setstate__(self, state): - super().__setstate__(state) - state_values = list(self.state.values()) - step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) - if not step_is_tensor: - for s in state_values: - s['step'] = torch.tensor(float(s['step'])) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - self._cuda_graph_capture_health_check() - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group['betas'] - - for p in group['params']: - if p.grad is None: - continue - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError('NAdamW does not support sparse gradients') - grads.append(p.grad) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = torch.tensor(0.) - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - state_steps.append(state['step']) - - nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) - - return loss - - -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: - r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. 
- """ - - if not all(isinstance(t, torch.Tensor) for t in state_steps): - raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') - - for i, param in enumerate(params): - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - - # Update step. - step_t += 1 - - # Perform stepweight decay. - param.mul_(1 - lr * weight_decay) - - # Decay the first and second moment running average coefficient. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # Only difference between NAdamW and AdamW in this implementation. - # The official PyTorch implementation of NAdam uses a different algorithm. - # We undo these ops later on, which could cause numerical issues but saves - # us from having to make an extra copy of the gradients. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - - step = step_t.item() - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - step_size = lr / bias_correction1 - - bias_correction2_sqrt = math.sqrt(bias_correction2) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - - param.addcdiv_(exp_avg, denom, value=-step_size) - exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_state - del rng - del hyperparameters - - hyperparameters = HPARAMS - - optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), - } - - def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) - return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) - - optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint, hyperparameters, optimizer_state['optimizer']) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del hyperparameters - - hyperparameters = HPARAMS - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - label_smoothing = ( - hyperparameters.label_smoothing 
if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - if grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - # Log training metrics - loss, grad_norm, batch_size. - if global_step <= 100 or global_step % 500 == 0: - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) - - return (optimizer_state, current_param_container, new_model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - batch = next(input_queue) - return batch diff --git a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py deleted file mode 100644 index ef6e84c94..000000000 --- a/reference_algorithms/threshold_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ /dev/null @@ -1,362 +0,0 @@ -"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" - -import math -from typing import Dict, Iterator, List, Tuple - -from absl import logging -import torch -from torch import Tensor -import torch.distributed.nn as dist_nn -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - -HPARAMS = { - "dropout_rate": 0.1, - "learning_rate": 0.0017486387539278373, - "one_minus_beta1": 0.06733926164, - "beta2": 0.9955159689799007, - "weight_decay": 0.08121616522670176, - "warmup_factor": 0.02 -} - - -# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. -class NAdamW(torch.optim.Optimizer): - r"""Implements NAdamW algorithm. - - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2): - if not 0.0 <= lr: - raise ValueError(f'Invalid learning rate: {lr}') - if not 0.0 <= eps: - raise ValueError(f'Invalid epsilon value: {eps}') - if not 0.0 <= betas[0] < 1.0: - raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') - if not 0.0 <= betas[1] < 1.0: - raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') - if not 0.0 <= weight_decay: - raise ValueError(f'Invalid weight_decay value: {weight_decay}') - defaults = { - 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay - } - super().__init__(params, defaults) - - def __setstate__(self, state): - super().__setstate__(state) - state_values = list(self.state.values()) - step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) - if not step_is_tensor: - for s in state_values: - s['step'] = torch.tensor(float(s['step'])) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. 
- - Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - self._cuda_graph_capture_health_check() - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - params_with_grad = [] - grads = [] - exp_avgs = [] - exp_avg_sqs = [] - state_steps = [] - beta1, beta2 = group['betas'] - - for p in group['params']: - if p.grad is None: - continue - params_with_grad.append(p) - if p.grad.is_sparse: - raise RuntimeError('NAdamW does not support sparse gradients') - grads.append(p.grad) - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = torch.tensor(0.) - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like( - p, memory_format=torch.preserve_format) - - exp_avgs.append(state['exp_avg']) - exp_avg_sqs.append(state['exp_avg_sq']) - state_steps.append(state['step']) - - nadamw( - params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - state_steps, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps']) - - return loss - - -def nadamw(params: List[Tensor], - grads: List[Tensor], - exp_avgs: List[Tensor], - exp_avg_sqs: List[Tensor], - state_steps: List[Tensor], - beta1: float, - beta2: float, - lr: float, - weight_decay: float, - eps: float) -> None: - r"""Functional API that performs NAdamW algorithm computation. - See NAdamW class for details. - """ - - if not all(isinstance(t, torch.Tensor) for t in state_steps): - raise RuntimeError( - 'API has changed, `state_steps` argument must contain a list of' + - ' singleton tensors') - - for i, param in enumerate(params): - grad = grads[i] - exp_avg = exp_avgs[i] - exp_avg_sq = exp_avg_sqs[i] - step_t = state_steps[i] - - # Update step. - step_t += 1 - - # Perform stepweight decay. - param.mul_(1 - lr * weight_decay) - - # Decay the first and second moment running average coefficient. - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # Only difference between NAdamW and AdamW in this implementation. - # The official PyTorch implementation of NAdam uses a different algorithm. - # We undo these ops later on, which could cause numerical issues but saves - # us from having to make an extra copy of the gradients. 
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - - step = step_t.item() - - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - step_size = lr / bias_correction1 - - bias_correction2_sqrt = math.sqrt(bias_correction2) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) - - param.addcdiv_(exp_avg, denom, value=-step_size) - exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_state - del rng - del hyperparameters - - hyperparameters = HPARAMS - - optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta1, - hyperparameters.beta2), - eps=1e-8, - weight_decay=hyperparameters.weight_decay), - } - - def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): - warmup_steps = int(hyperparameters.warmup_factor * step_hint) - warmup = LinearLR( - optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) - cosine_steps = max(step_hint - warmup_steps, 1) - cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) - return SequentialLR( - optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) - - optimizer_state['scheduler'] = pytorch_cosine_warmup( - workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del hyperparameters - - hyperparameters = HPARAMS - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters.grad_clip - else: - grad_clip = None - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. 
- summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - if grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=grad_clip) - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - # Log training metrics - loss, grad_norm, batch_size. - if global_step <= 100 or global_step % 500 == 0: - with torch.no_grad(): - parameters = [p for p in current_model.parameters() if p.grad is not None] - grad_norm = torch.norm( - torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss.item(), - 'grad_norm': grad_norm.item(), - }, global_step) - logging.info('%d) loss = %0.3f, grad_norm = %0.3f', - global_step, - loss.item(), - grad_norm.item()) - - return (optimizer_state, current_param_container, new_model_state) - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb': - return 262_144 - elif workload_name == 'fastmri': - return 32 - elif workload_name == 'imagenet_resnet': - return 1024 - elif workload_name == 'imagenet_vit': - return 1024 - elif workload_name == 'librispeech_conformer': - return 256 - elif workload_name == 'librispeech_deepspeech': - return 256 - elif workload_name == 'ogbg': - return 512 - elif workload_name == 'wmt': - return 128 - elif workload_name == 'mnist': - return 16 - else: - raise ValueError(f'Unsupported workload name: {workload_name}.') - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. 
-  """
-  del workload
-  del optimizer_state
-  del current_param_container
-  del model_state
-  del hyperparameters
-  del global_step
-  del rng
-  batch = next(input_queue)
-  return batch

From b24e7e7284186cd852f4bb5fdb032148ca69659a Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Tue, 21 Nov 2023 23:57:37 +0000
Subject: [PATCH 10/10] rename

---
 .../prize_qualification_baselines/README.md   |  83 ++++
 .../external_tuning/jax_nadamw_full_budget.py | 345 +++++++++++++++++
 .../jax_nadamw_target_setting.py              | 345 +++++++++++++++++
 .../pytorch_nadamw_full_budget.py             | 347 +++++++++++++++++
 .../pytorch_nadamw_target_setting.py          | 347 +++++++++++++++++
 .../external_tuning/tuning_search_space.json  |  50 +++
 .../self_tuning/jax_nadamw_full_budget.py     | 360 +++++++++++++++++
 .../self_tuning/jax_nadamw_target_setting.py  | 360 +++++++++++++++++
 .../self_tuning/pytorch_nadamw_full_budget.py | 362 ++++++++++++++++++
 .../pytorch_nadamw_target_setting.py          | 362 ++++++++++++++++++
 10 files changed, 2961 insertions(+)
 create mode 100644 reference_algorithms/prize_qualification_baselines/README.md
 create mode 100644 reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
 create mode 100644 reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
 create mode 100644 reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
 create mode 100644 reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
 create mode 100644 reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json
 create mode 100644 reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
 create mode 100644 reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
 create mode 100644 reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
 create mode 100644 reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py

diff --git a/reference_algorithms/prize_qualification_baselines/README.md b/reference_algorithms/prize_qualification_baselines/README.md
new file mode 100644
index 000000000..614f87b32
--- /dev/null
+++ b/reference_algorithms/prize_qualification_baselines/README.md
@@ -0,0 +1,83 @@
+# Prize Qualification Baselines
+This directory contains the baseline(s) that submissions must beat to qualify for prizes.
+
+TODO: link back to section in rules.
+
+## Externally Tuned Ruleset
+
+### JAX
+
+The prize qualification baseline submissions for JAX are:
+- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py`
+
+Example command:
+
+```bash
+python3 submission_runner.py \
+    --framework=jax \
+    --data_dir= \
+    --experiment_dir= \
+    --experiment_name= \
+    --workload= \
+    --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py \
+    --tuning_search_space=reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json
+```
+
+### PyTorch
+
+The prize qualification baseline submissions for PyTorch are:
+- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py`
+
+
+Example command:
+
+```bash
+torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
+    --framework=pytorch \
+    --data_dir= \
+    --experiment_dir= \
+    --experiment_name=t \
+    --workload= \
+    --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py \
+    --tuning_search_space=reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json
+```
+
+## Self-tuning Ruleset
+
+### JAX
+
+The prize qualification baseline submissions for JAX are:
+- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py`
+
+Example command:
+```bash
+python3 submission_runner.py \
+    --framework=jax \
+    --data_dir= \
+    --experiment_dir= \
+    --experiment_name= \
+    --workload= \
+    --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py \
+    --tuning_ruleset=self
+```
+
+### PyTorch
+
+The prize qualification baseline submissions for PyTorch are:
+- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py`
+
+Example command:
+```bash
+torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
+    --framework=pytorch \
+    --data_dir= \
+    --experiment_dir= \
+    --experiment_name=t \
+    --workload= \
+    --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py \
+    --tuning_ruleset=self
+```
\ No newline at end of file

diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
new file mode 100644
index 000000000..099613fcf
--- /dev/null
+++ b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
@@ -0,0 +1,345 @@
+"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
+
+import functools
+
+# isort: off
+# We have to turn off isort here to resolve a conflict between isort and yapf.
+from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. 
+ eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. 
+ if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..ef0c11c0d --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,345 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. 
+ eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. 
+ mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. 
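# ---------------------------------------------------------------------------
# [Editorial sketch, not part of the submission diff.] The psum below sums the
# per-device summed loss, valid-example count, and gradient pytree over the
# pmap 'batch' axis; dividing by the global n_valid_examples afterwards then
# yields the true global mean loss and mean gradient. A minimal,
# self-contained illustration of that collective (array values and device
# count are arbitrary assumptions):
import jax
import jax.numpy as jnp

per_device = jnp.ones((jax.local_device_count(),))  # one scalar per device
summed = jax.pmap(lambda x: jax.lax.psum(x, axis_name='batch'),
                  axis_name='batch')(per_device)
# Every entry of `summed` now equals jax.local_device_count().
# ---------------------------------------------------------------------------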
+ (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..01cffc52e --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,347 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. 
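# ---------------------------------------------------------------------------
# [Editorial sketch, not part of the submission diff.] The logging branch
# below computes the global gradient L2 norm as the 2-norm of the stacked
# per-parameter gradient 2-norms, which equals the norm of all gradient
# entries concatenated into a single vector. A small self-contained check of
# that identity, with arbitrary assumed shapes and values:
import torch

grads = [torch.randn(3, 4), torch.randn(5)]
stacked = torch.norm(torch.stack([torch.norm(g, 2) for g in grads]), 2)
flat = torch.norm(torch.cat([g.reshape(-1) for g in grads]), 2)
assert torch.allclose(stacked, flat)
# ---------------------------------------------------------------------------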
+ if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..530dd3acf --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,347 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. 
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. 
+ """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 
'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json b/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json new file mode 100644 index 000000000..65562905a --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -0,0 +1,50 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.2, + "learning_rate": 0.0008445074561975979, + "one_minus_beta1": 0.11042418465, + "beta2": 0.9978504782314613, + "weight_decay": 0.08135402759553023, + "warmup_factor": 0.05 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.004958460849689891, + "one_minus_beta1": 0.13625575743, + "beta2": 0.6291854735396584, + "weight_decay": 0.1147386261512052, + "warmup_factor": 0.02 + }, + { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } +] + + + + + + diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py new file mode 100644 index 000000000..c54202e56 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -0,0 +1,360 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. 
The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, 
grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..dd42743e2 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,360 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. 
The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, 
loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..57da48167 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,362 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. 
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. 
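+  All tensor arguments (params, exp_avgs, exp_avg_sqs and state_steps) are
+  modified in place.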
+ """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing 
if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
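+  This submission does no data selection and simply returns the next batch.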
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..ef6e84c94 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,362 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. 
+ + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
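+    # Concretely: exp_avg currently holds the plain EMA m_t; the line below
+    # overwrites it with the blend beta1 * m_t + (1 - beta1) * g_t used for
+    # the parameter update, and the final sub_/div_ pair restores m_t.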
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. 
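+    # torch.distributed.nn.all_reduce is autograd-aware, so dividing by the
+    # globally summed n_valid_examples is also reflected in the backward
+    # pass (plain torch.distributed.all_reduce does not track gradients).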
+ summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch
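
A minimal, framework-free sketch of a single NAdamW step may make the in-place
functional `nadamw` implementations above easier to follow: they blend the
gradient into the first moment a second time (the Nesterov look-ahead that is
the only difference from AdamW), apply the bias-corrected update, and then undo
the blend so the stored state remains a plain exponential moving average. The
sketch below is illustrative only and is not part of the submission files; the
function name `nadamw_step_sketch` and the numbers in the usage loop are made
up for this note.

# Illustrative sketch only (not part of the patch): one NAdamW step in plain
# Python, following the same hyperparameter conventions as the code above.
import math


def nadamw_step_sketch(param, grad, exp_avg, exp_avg_sq, step,
                       lr=1e-3, beta1=0.9, beta2=0.999,
                       weight_decay=1e-2, eps=1e-8):
  """Returns (param, exp_avg, exp_avg_sq, step) after one NAdamW update."""
  step += 1
  # Decoupled ("step") weight decay, scaled by the learning rate.
  param *= 1 - lr * weight_decay
  # Exponential moving averages of the gradient and its square.
  exp_avg = beta1 * exp_avg + (1 - beta1) * grad
  exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
  # Nesterov look-ahead: blend the gradient into the first moment once more.
  # The in-place version above writes this into exp_avg and undoes it after
  # the update; computing it out of place makes the undo unnecessary.
  lookahead = beta1 * exp_avg + (1 - beta1) * grad
  # Standard Adam bias correction.
  bias_correction1 = 1 - beta1**step
  bias_correction2 = 1 - beta2**step
  denom = math.sqrt(exp_avg_sq) / math.sqrt(bias_correction2) + eps
  param -= (lr / bias_correction1) * lookahead / denom
  return param, exp_avg, exp_avg_sq, step


# Usage example with made-up scalar values.
p, m, v, t = 1.0, 0.0, 0.0, 0
for g in (0.3, -0.1, 0.2):
  p, m, v, t = nadamw_step_sketch(p, g, m, v, t)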