diff --git a/reference_algorithms/prize_qualification_baselines/README.md b/reference_algorithms/prize_qualification_baselines/README.md
new file mode 100644
index 000000000..614f87b32
--- /dev/null
+++ b/reference_algorithms/prize_qualification_baselines/README.md
@@ -0,0 +1,83 @@
+# Prize Qualification Baselines
+This directory contains the baseline(s) that submissions must beat to qualify for prizes.
+
+TODO: link back to section in rules.
+
+## Externally Tuned Ruleset
+
+### JAX
+
+The prize qualification baseline submissions for JAX are:
+- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py`
+
+Example command:
+
+```bash
+python3 submission_runner.py \
+    --framework=jax \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py \
+    --tuning_search_space=reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json
+```
+
+### PyTorch
+
+The prize qualification baseline submissions for PyTorch are:
+- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py`
+
+Example command:
+
+```bash
+torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
+    --framework=pytorch \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py \
+    --tuning_search_space=reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json
+```
+
+## Self-tuning Ruleset
+
+### JAX
+
+The prize qualification baseline submissions for JAX are:
+- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py`
+
+Example command:
+
+```bash
+python3 submission_runner.py \
+    --framework=jax \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py \
+    --tuning_ruleset=self
+```
+
+### PyTorch
+
+The prize qualification baseline submissions for PyTorch are:
+- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py`
+- `reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py`
+
+Example command:
+
+```bash
+torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \
+    --framework=pytorch \
+    --data_dir=<data_dir> \
+    --experiment_dir=<experiment_dir> \
+    --experiment_name=<experiment_name> \
+    --workload=<workload> \
+    --submission_path=reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py \
+    --tuning_ruleset=self
+```
\ No newline at end of file
diff --git
a/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py new file mode 100644 index 000000000..099613fcf --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -0,0 +1,345 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. 
+ + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. 
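+    # The schedule ramps linearly from 0 to hyperparameters.learning_rate over
+    # the first `warmup_factor * step_hint` steps, then follows a cosine decay
+    # back towards 0 over the remaining steps (optax.cosine_decay_schedule
+    # defaults to a final value of 0).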
+ warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. 
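+  # lax.psum over the 'batch' (device) axis sums the per-device summed losses,
+  # valid-example counts and gradients, so dividing by the global
+  # n_valid_examples below yields the true global mean loss/gradient rather
+  # than a mean of per-device means.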
+ (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..ef0c11c0d --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,345 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. 
+ """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. + """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. 
This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. + lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. 
+ (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..01cffc52e --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,347 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. 
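+  # grad_norm below is the global L2 norm over all parameter gradients (the
+  # 2-norm of the stacked per-parameter norms); metrics are logged for the
+  # first 100 steps and every 500 steps thereafter.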
+ if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..530dd3acf --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,347 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. 
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. 
+ """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 
'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json b/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json new file mode 100644 index 000000000..65562905a --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/external_tuning/tuning_search_space.json @@ -0,0 +1,50 @@ +[ + { + "dropout_rate": 0.0, + "label_smoothing": 0.1, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "label_smoothing": 0.2, + "learning_rate": 0.0008445074561975979, + "one_minus_beta1": 0.11042418465, + "beta2": 0.9978504782314613, + "weight_decay": 0.08135402759553023, + "warmup_factor": 0.05 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.001308209823469072, + "one_minus_beta1": 0.02686663061, + "beta2": 0.9981232922116359, + "weight_decay": 0.16375311233774334, + "warmup_factor": 0.1 + }, + { + "dropout_rate": 0.0, + "learning_rate": 0.004958460849689891, + "one_minus_beta1": 0.13625575743, + "beta2": 0.6291854735396584, + "weight_decay": 0.1147386261512052, + "warmup_factor": 0.02 + }, + { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 + } +] + + + + + + diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py new file mode 100644 index 000000000..c54202e56 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -0,0 +1,360 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. 
The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, loss, 
grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py new file mode 100644 index 000000000..dd42743e2 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -0,0 +1,360 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax.""" + +import functools + +# isort: off +# We have to turn off isort here to resolve a conflict between isort and yapf. +from typing import (Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union) +# isort: on + +import chex +from flax import jax_utils +import jax +from jax import lax +import jax.numpy as jnp +import optax + +from algorithmic_efficiency import spec + +_GRAD_CLIP_EPS = 1e-6 + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py +def nadamw( + learning_rate: Union[float, optax.Schedule], + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], + Any]]] = None, +) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the official PyTorch + implementation also follows this). + Current code implements a simpler version with no momentum decay and slightly + different bias correction terms. 
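One thing to note about the self-tuning variants: `HPARAMS` is declared as a plain dict, yet the code reads it with attribute access (`hyperparameters.learning_rate`, `hyperparameters.warmup_factor`) after rebinding `hyperparameters = HPARAMS`. Unless the dict is converted elsewhere, those reads would not resolve, and `hasattr(hyperparameters, 'label_smoothing')` on a dict ignores keys, so the fallbacks always trigger. A minimal sketch of one way to keep the attribute syntax (an illustration, not the code the submission ships):

```python
from types import SimpleNamespace

# Illustrative only: expose the HPARAMS entries as attributes so that
# hyperparameters.learning_rate and hasattr(hyperparameters, 'grad_clip') behave as intended.
HPARAMS = SimpleNamespace(**{
    "dropout_rate": 0.1,
    "learning_rate": 0.0017486387539278373,
    "one_minus_beta1": 0.06733926164,
    "beta2": 0.9955159689799007,
    "weight_decay": 0.08121616522670176,
    "warmup_factor": 0.02,
})
```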
The exact description can be found here + https://arxiv.org/pdf/1910.05446.pdf (Table 1). + + Args: + learning_rate: A fixed global scaling factor. + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent with + other frameworks such as PyTorch, but different from (Loshchilov et al, + 2019) where the weight decay is only multiplied with the "schedule + multiplier", but not the base learning rate. + weight_decay_mask: A tree with same structure as (or a prefix of) the params + PyTree, or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Nadam gradient transformations are applied to all parameters. + + Returns: + An (init_fn, update_fn) tuple. + """ + return optax.chain( + scale_by_nadam(b1, b2, eps, eps_root, debias), + optax.add_decayed_weights(weight_decay, weight_decay_mask), + scale_by_learning_rate(learning_rate)) + + +# All functions below are forked from +# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py +def scale_by_nadam(b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + debias: bool = True, + power: float = 0.5) -> optax.GradientTransformation: + """Rescale updates according to the NAdam algorithm. + + References: + There seem to be multiple versions of NAdam. The original version is here + https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ (the pytorch imp. also + follows this). + + Current code implements a simpler version with no momentum decay and slightly + different (standard Adam) bias correction terms. The exact description can be + found here https://arxiv.org/pdf/1910.05446.pdf (Table 1) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + debias: Whether to use bias correction. + power: The power to use in the preconditioner (0.5 in default adam). + Returns: + An (init_fn, update_fn) tuple. 
+ """ + raise_power = jnp.sqrt if power == 0.5 else lambda x: jnp.power(x, power) + + def init_fn(params): + mu = jax.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = _update_moment(updates, state.mu, b1, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + mu_hat = _update_moment(updates, mu, b1, 1) + mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count) + nu_hat = nu if not debias else _bias_correction(nu, b2, count) + updates = jax.tree_map( + lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count, mu=mu, nu=nu) + + return optax.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the NAdam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + + +def _update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order-th` moment.""" + return jax.tree_map( + lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments) + + +def _bias_correction(moment, decay, count): + """Perform bias correction. This becomes a no-op as count goes to infinity.""" + beta = 1 - decay**count + return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment) + + +def scale_by_learning_rate(learning_rate, flip_sign=True): + m = -1 if flip_sign else 1 + if callable(learning_rate): + return optax.scale_by_schedule(lambda count: m * learning_rate(count)) + return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + def jax_cosine_warmup(step_hint: int, hyperparameters): + # Create learning rate schedule. + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]) + return schedule_fn + + # Create optimizer + LR schedule. 
+ lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=1.0 - hyperparameters.one_minus_beta1, + b2=hyperparameters.beta2, + eps=1e-8, + weight_decay=hyperparameters.weight_decay) + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn + + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), + static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4)) +def pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + loss = summed_loss / n_valid_examples + grad = jax.tree_map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + outputs = pmapped_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing) + new_optimizer_state, new_params, new_model_state, 
loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py new file mode 100644 index 000000000..57da48167 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -0,0 +1,362 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. 
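The "decoupled" weight decay referenced above is applied directly to the weights before the adaptive update (the `param.mul_(1 - lr * weight_decay)` line further down), rather than being folded into the gradient. A schematic scalar contrast, with illustrative names that are not part of the submission:

```python
def l2_regularized_step(param, grad, lr, wd):
    # L2 regularization: the decay term enters the gradient and hence the preconditioner.
    return param - lr * (grad + wd * param)

def decoupled_step(param, adam_update, lr, wd):
    # Decoupled (AdamW/NAdamW): decay the weights directly; only the
    # bias-corrected, preconditioned update is scaled by the learning rate.
    param = param * (1 - lr * wd)
    return param - lr * adam_update
```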
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. 
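A quick scalar check of the in-place trick used in the loop below: the momentum buffer is advanced a second time to form the Nesterov look-ahead, and the closing `sub_(...).div_(beta1)` restores the buffer exactly, so no extra copy of `exp_avg` is needed (sketch only; the variable names are illustrative):

```python
beta1, grad, exp_avg = 0.9, 0.123, 0.456  # exp_avg: first-moment buffer after its regular EMA update

lookahead = beta1 * exp_avg + (1 - beta1) * grad      # second momentum application (Nesterov look-ahead)
# ... lookahead is bias-corrected and used for the parameter update ...
restored = (lookahead - (1 - beta1) * grad) / beta1   # the sub_/div_ "undo" at the end of the loop
assert abs(restored - exp_avg) < 1e-9
```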
+ """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing 
if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. 
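When `grad_clip` is set, both code paths clip by the global L2 norm over all gradients: the PyTorch files call `torch.nn.utils.clip_grad_norm_`, and the JAX train step scales the gradient tree manually. The effective scaling is roughly the following framework-free sketch (names are illustrative; the two implementations differ only in small epsilon details):

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale a list of scalar gradients so their global L2 norm is at most ~max_norm."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (global_norm + eps))
    return [g * scale for g in grads]
```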
+ """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py new file mode 100644 index 000000000..ef6e84c94 --- /dev/null +++ b/reference_algorithms/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -0,0 +1,362 @@ +"""Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" + +import math +from typing import Dict, Iterator, List, Tuple + +from absl import logging +import torch +from torch import Tensor +import torch.distributed.nn as dist_nn +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + +from algorithmic_efficiency import spec +from algorithmic_efficiency.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +HPARAMS = { + "dropout_rate": 0.1, + "learning_rate": 0.0017486387539278373, + "one_minus_beta1": 0.06733926164, + "beta2": 0.9955159689799007, + "weight_decay": 0.08121616522670176, + "warmup_factor": 0.02 +} + + +# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. +class NAdamW(torch.optim.Optimizer): + r"""Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). + For further details regarding the algorithm we refer to + `Decoupled Weight Decay Regularization`_. + + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2): + if not 0.0 <= lr: + raise ValueError(f'Invalid learning rate: {lr}') + if not 0.0 <= eps: + raise ValueError(f'Invalid epsilon value: {eps}') + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f'Invalid beta parameter at index 0: {betas[0]}') + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f'Invalid beta parameter at index 1: {betas[1]}') + if not 0.0 <= weight_decay: + raise ValueError(f'Invalid weight_decay value: {weight_decay}') + defaults = { + 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay + } + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + state_values = list(self.state.values()) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor( + state_values[0]['step']) + if not step_is_tensor: + for s in state_values: + s['step'] = torch.tensor(float(s['step'])) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. 
+ + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + self._cuda_graph_capture_health_check() + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_steps = [] + beta1, beta2 = group['betas'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('NAdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = torch.tensor(0.) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like( + p, memory_format=torch.preserve_format) + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + state_steps.append(state['step']) + + nadamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + state_steps, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps']) + + return loss + + +def nadamw(params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + state_steps: List[Tensor], + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float) -> None: + r"""Functional API that performs NAdamW algorithm computation. + See NAdamW class for details. + """ + + if not all(isinstance(t, torch.Tensor) for t in state_steps): + raise RuntimeError( + 'API has changed, `state_steps` argument must contain a list of' + + ' singleton tensors') + + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + step_t = state_steps[i] + + # Update step. + step_t += 1 + + # Perform stepweight decay. + param.mul_(1 - lr * weight_decay) + + # Decay the first and second moment running average coefficient. + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # Only difference between NAdamW and AdamW in this implementation. + # The official PyTorch implementation of NAdam uses a different algorithm. + # We undo these ops later on, which could cause numerical issues but saves + # us from having to make an extra copy of the gradients. 
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + step = step_t.item() + + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + + step_size = lr / bias_correction1 + + bias_correction2_sqrt = math.sqrt(bias_correction2) + denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + param.addcdiv_(exp_avg, denom, value=-step_size) + exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(1.0 - hyperparameters.one_minus_beta1, + hyperparameters.beta2), + eps=1e-8, + weight_decay=hyperparameters.weight_decay), + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']) + + return optimizer_state + + +def update_params(workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True) + + label_smoothing = ( + hyperparameters.label_smoothing if hasattr(hyperparameters, + 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. 
+ summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, global_step) + logging.info('%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item()) + + return (optimizer_state, current_param_container, new_model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection(workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch
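Purely as orientation, the entry points defined in each of these files compose roughly as follows. This schematic loop is an assumption about how a harness could drive them, written against the signatures above; it is not the AlgoPerf `submission_runner`, and it assumes the functions are importable from the submission module.

```python
def schematic_training_loop(workload, input_queue, model_params, model_state,
                            hyperparameters, rng, num_steps):
    """Illustrative only: how init_optimizer_state, data_selection and update_params fit together."""
    optimizer_state = init_optimizer_state(
        workload, model_params, model_state, hyperparameters, rng)
    params, eval_results = model_params, []
    for global_step in range(num_steps):
        batch = data_selection(workload, input_queue, optimizer_state, params,
                               model_state, hyperparameters, global_step, rng)
        # current_params_types and loss_type are unused by these submissions, so None is passed here.
        optimizer_state, params, model_state = update_params(
            workload, params, None, model_state, hyperparameters, batch,
            None, optimizer_state, eval_results, global_step, rng)
    return params
```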