diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 586e03d8c..4e66a2b84 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -199,6 +199,7 @@ def update_params( batch: Dict[str, Tensor], loss_type: LossType, optimizer_state: OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: RandomState @@ -212,6 +213,7 @@ def update_params( - The `loss_fn` produces a loss per example and a summed loss (both only for one device), which both can be used. - Allowed to update state for the optimizer. - Uses the `model_fn` of the `workload` in order to decouple the loss from the model so that model outputs (forward passes) can be reused (by storing them in the optimizer state). +- The submission can access the elapsed training time and further information about the evaluation through `train_state`. - The submission can access the target evaluation metric via the `workload` variable. - **A call to this function will be considered a step** - The time between a call to this function and the next call to this function will be considered the per-step time. diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 4f6c254bd..590f500fa 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -67,10 +67,13 @@ def update_batch_norm_fn(module: spec.ParameterContainer, ) if isinstance(module, bn_layers): if not update_batch_norm: - module.eval() - module.momentum_backup = module.momentum + if not hasattr(module, 'momentum_backup'): + module.momentum_backup = module.momentum + # module.momentum can be float or torch.Tensor. - module.momentum = 0. * module.momentum_backup + if torch.is_tensor(module.momentum_backup): + module.momentum = torch.zeros_like(module.momentum_backup) + else: + module.momentum = 0.0 elif hasattr(module, 'momentum_backup'): module.momentum = module.momentum_backup - module.track_running_stats = update_batch_norm diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index b8be5fcaa..381d52f32 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -403,7 +403,8 @@ def init_optimizer_state(workload: Workload, OptimizerState, List[Tuple[int, float]], int, - RandomState + RandomState, + Optional[Dict[str, Any]] ], UpdateReturn] @@ -424,7 +425,8 @@ def update_params(workload: Workload, optimizer_state: OptimizerState, eval_results: List[Tuple[int, float]], global_step: int, - rng: RandomState) -> UpdateReturn: + rng: RandomState, + train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" pass diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py index 834c93b7a..059352fb6 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py @@ -28,11 +28,16 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - update_batch_norm: bool = True) -> spec.Tensor: + update_batch_norm: bool = True, + use_running_average_bn: Optional[bool] = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, - use_running_average=not update_batch_norm, + 
use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, dtype=self.dtype) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index b019d1cee..8268c6ca3 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -110,7 +110,9 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} @@ -119,14 +121,16 @@ def model_fn( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return logits, new_model_state else: logits = self._model.apply( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py index 99a9b0513..34cd17440 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py @@ -84,11 +84,16 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - update_batch_norm: bool = True) -> spec.Tensor: + update_batch_norm: bool = True, + use_running_average_bn: Optional[bool] = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, - use_running_average=not update_batch_norm, + use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, dtype=self.dtype) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index d8de214f5..2747fc2db 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -148,7 +148,9 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} @@ -157,14 +159,16 @@ def model_fn( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return logits, new_model_state else: logits = self._model.apply( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + 
use_running_average_bn=use_running_average_bn) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index ed05f4335..cb6287c5e 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -454,7 +454,11 @@ def setup(self): self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, + inputs, + input_paddings, + update_batch_norm, + use_running_average_bn): rank = inputs.ndim reduce_over_dims = list(range(0, rank - 1)) @@ -462,7 +466,12 @@ def __call__(self, inputs, input_paddings, train): momentum = self.config.batch_norm_momentum epsilon = self.config.batch_norm_epsilon - if train: + if use_running_average_bn: + mean = self.ra_mean.value + var = self.ra_var.value + + else: + # compute batch statistics mask = 1.0 - padding sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) count_v = jnp.sum( @@ -478,16 +487,13 @@ def __call__(self, inputs, input_paddings, train): var = sum_vv / count_v - self.ra_mean.value = momentum * \ - self.ra_mean.value + (1 - momentum) * mean - self.ra_var.value = momentum * \ - self.ra_var.value + (1 - momentum) * var - else: - mean = self.ra_mean.value - var = self.ra_var.value + if update_batch_norm: + self.ra_mean.value = momentum * \ + self.ra_mean.value + (1 - momentum) * mean + self.ra_var.value = momentum * \ + self.ra_var.value + (1 - momentum) * var inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) - bn_output = (inputs - mean) * inv + self.beta bn_output *= 1.0 - padding @@ -517,7 +523,12 @@ class ConvolutionBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average_bn): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -546,7 +557,10 @@ def __call__(self, inputs, input_paddings, train): kernel_init=nn.initializers.xavier_uniform())( inputs) - inputs = BatchNorm(config)(inputs, input_paddings, train) + inputs = BatchNorm(config)(inputs, + input_paddings, + update_batch_norm, + use_running_average_bn) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': @@ -586,7 +600,12 @@ class ConformerBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) @@ -597,7 +616,12 @@ def __call__(self, inputs, input_paddings, train): inputs, input_paddings, train) inputs = inputs + \ - ConvolutionBlock(config)(inputs, input_paddings, train) + ConvolutionBlock(config)(inputs, + input_paddings, + train, + update_batch_norm, + use_running_average + ) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( inputs, padding_mask, train) @@ -629,12 +653,23 @@ def setup(self): .use_dynamic_time_mask_max_frames) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm: Optional[bool] = 
None, + use_running_average_bn: Optional[bool] = None): config = self.config outputs = inputs output_paddings = input_paddings + # Set BN args if not supplied for backwards compatibility + if update_batch_norm is None: + update_batch_norm = train + if use_running_average_bn is None: + use_running_average_bn = not train + # Compute normalized log mel spectrograms from input audio signal. preprocessing_config = preprocessor.LibrispeechPreprocessingConfig() outputs, output_paddings = preprocessor.MelFilterbankFrontend( @@ -660,7 +695,11 @@ def __call__(self, inputs, input_paddings, train): # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): - outputs = ConformerBlock(config)(outputs, output_paddings, train) + outputs = ConformerBlock(config)(outputs, + output_paddings, + train, + update_batch_norm, + use_running_average_bn) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index f4d1ab0f3..e362f973b 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -107,7 +107,9 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN @@ -118,7 +120,8 @@ def model_fn( input_paddings, train=True, rngs={'dropout' : rng}, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( @@ -126,7 +129,8 @@ def model_fn( inputs, input_paddings, train=False, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return (logits, logit_paddings), model_state def _build_input_queue( diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index 502cb093e..61400806a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -40,7 +40,7 @@ class ConformerConfig: time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True input_dropout_rate: float = 0.1 - batch_norm_momentum: float = 0.999 + batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True attention_temperature: float = 1.0 @@ -369,10 +369,11 @@ def forward(self, inputs, input_paddings): mean = (masked_inp).sum(dim=(0, 1)) / count var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count - self.running_mean = self.momentum * self.running_mean + ( - 1 - self.momentum) * mean.detach() - self.running_var = self.momentum * self.running_var + ( - 1 - self.momentum) * var.detach() + self.running_mean = (1 - self.momentum) * self.running_mean + ( + 
self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() + else: mean = self.running_mean var = self.running_var diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 8473fac0f..a0db6d607 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,5 +1,5 @@ import functools -from typing import Optional +from typing import Dict, Optional, Tuple from flax import jax_utils import jax @@ -56,6 +56,37 @@ def init_model_fn( params = jax_utils.replicate(params) return params, model_state + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + variables = {'params': params, **model_state} + inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] + is_train_mode = mode == spec.ForwardPassMode.TRAIN + if update_batch_norm or is_train_mode: + (logits, logit_paddings), new_model_state = self._model.apply( + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout' : rng}, + mutable=['batch_stats']) + return (logits, logit_paddings), new_model_state + else: + logits, logit_paddings = self._model.apply( + variables, + inputs, + input_paddings, + train=False, + mutable=False) + return (logits, logit_paddings), model_state + def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index a5ee3fa0a..bdf556f1c 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -36,7 +36,7 @@ class DeepspeechConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - batch_norm_momentum: float = 0.999 + batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 # If None, defaults to 0.1. 
input_dropout_rate: Optional[float] = 0.1 @@ -264,10 +264,10 @@ def forward(self, inputs, input_paddings): sum_ = dist_nn.all_reduce(sum_) var = sum_ / count - self.running_mean = self.momentum * self.running_mean + ( - 1 - self.momentum) * mean.detach() - self.running_var = self.momentum * self.running_var + ( - 1 - self.momentum) * var.detach() + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() else: mean = self.running_mean var = self.running_var diff --git a/docker/scripts/startup.sh b/docker/scripts/startup.sh index 5c5a6aa49..527e8306a 100644 --- a/docker/scripts/startup.sh +++ b/docker/scripts/startup.sh @@ -132,6 +132,10 @@ while [ "$1" != "" ]; do shift TEST=$1 ;; + --additional_requirements_path) + shift + ADDITIONAL_REQUIREMENTS_PATH=$1 + ;; *) usage exit 1 @@ -140,6 +144,16 @@ while [ "$1" != "" ]; do shift done + +# Optionally install additional dependencies +if [[ -n ${ADDITIONAL_REQUIREMENTS_PATH+x} ]]; then + echo "Installing additional requirements..." + COMMAND="cd algorithmic-efficiency && pip install -r ${ADDITIONAL_REQUIREMENTS_PATH}" + echo $COMMAND + eval $COMMAND +fi + + if [[ ${TEST} == "true" ]]; then cd algorithmic-efficiency COMMAND="python3 tests/test_traindiffs.py" diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 5f203c5c6..36e7e5607 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -252,20 +252,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 32f4e830e..07281f540 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -252,20 +252,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: 
spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index ba56cd99f..a12523bde 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -224,20 +224,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index e2c44d9c1..93b41987e 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -224,20 +224,23 @@ def pytorch_cosine_warmup(step_hint: int, 
hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 502b7e5b4..0d194ef7a 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -264,20 +264,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 8bc2eed95..60fc25ec4 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -264,20 +264,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - 
global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index bbf548ccb..2dc29acad 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -236,20 +236,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 992f769f3..6cc44cb12 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -236,20 +236,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: 
spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index b2256fc5a..e8e0bf4ac 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for CIFAR10.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -108,21 +108,24 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type del global_step + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state per_device_rngs = jax.random.split(rng, jax.local_device_count()) diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index b55c31afc..c3e7a546b 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for CIFAR10.""" -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch from torch.optim.lr_scheduler import CosineAnnealingLR @@ -53,21 +53,24 @@ def init_optimizer_state(workload: spec.Workload, return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - 
current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del current_params_types del hyperparameters del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index f09886215..b33c0285b 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for MNIST.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -75,20 +75,23 @@ def loss_fn(params): return new_optimizer_state, updated_params, new_model_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del global_step diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index 8b5151c77..b868bc787 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for MNIST.""" -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch @@ -32,21 +32,24 @@ def init_optimizer_state(workload: spec.Workload, return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: 
spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del hyperparameters del loss_type del current_params_types + del train_state del eval_results del global_step diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index ed2ee371f..0fcb9da0f 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an Adafactor optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -110,20 +110,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index 5f6540020..c0eed45ef 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for Adafactor in PyTorch.""" from functools import partial -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -190,20 +190,23 @@ def step(self, closure=None): return loss -def update_params(workload: spec.Workload, - current_param_container: 
spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 5d2107ba6..e80a29693 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -110,20 +110,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 2b42bb5a4..8da4e1671 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch.""" -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -51,20 +51,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: 
spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index e08d5b433..ebcdc9914 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -118,20 +118,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index da5865087..c0ecee69e 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -189,20 +189,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return 
optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index 1ab362dd6..271ef860b 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with HeavyBall momentum optimizer in Jax.""" import functools -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -144,20 +144,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index 999321bd5..272a79b4c 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with HeavyBall momentum optimizer in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from absl import logging import 
optax @@ -67,20 +67,23 @@ def create_lr_schedule_fn( return lr_schedule_fn -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 5f203c5c6..36e7e5607 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -252,20 +252,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index ba56cd99f..a12523bde 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -224,20 +224,23 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - 
current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 20109a9e3..a435643e4 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with Nesterov momentum optimizer in Jax.""" import functools -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -144,20 +144,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index b4b8b77af..aac4146a4 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with Nesterov momentum optimizer in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from absl import logging import optax @@ -67,20 +67,23 @@ def create_lr_schedule_fn( return lr_schedule_fn -def update_params(workload: 
spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 9f12c4f3f..5f45901dd 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SAM optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -197,20 +197,23 @@ def _loss_fn(params, update_batch_norm=True): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index cf5e49f4f..243174d34 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SAM optimizer with warmup+cosine LR in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -131,20 +131,23 @@ def pytorch_cosine_warmup(step_hint: int, 
hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index b596f0bdc..294ad2706 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a Shampoo optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -113,20 +113,23 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 31e8a8850..7a16c07cb 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -1,6 +1,6 @@ """Update submission function in Jax.""" import functools -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import jax from jax import lax @@ -69,20 +69,23 @@ def 
_loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index 549d2dc58..2e2876555 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -1,6 +1,6 @@ """Batch size and update submission functions in PyTorch.""" -from typing import Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from absl import logging import torch @@ -12,20 +12,23 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/scoring/run_workloads.py b/scoring/run_workloads.py index 1c33079d3..e474b6910 100644 --- a/scoring/run_workloads.py +++ b/scoring/run_workloads.py @@ -9,9 +9,11 @@ --tuning_search_space """ +import datetime import json import os import struct +import subprocess import time from absl import app @@ -26,9 +28,11 @@ 'docker_image_url', 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', 'URL to docker image') -flags.DEFINE_integer('run_percentage', - 100, - 'Percentage of max num steps to run for.') +flags.DEFINE_integer( + 'run_percentage', + 
100, + 'Percentage of max num steps to run for. ' + 'Must set the flag enable_step_budget to True for this to take effect.') flags.DEFINE_string('experiment_name', 'my_experiment', 'Name of top sub directory in experiment dir.') @@ -83,10 +87,24 @@ 'If your algorithm has a smaller per step time than our baselines ' 'you may want to increase the number of steps per workload.') flags.DEFINE_string( - 'workload', + 'workloads', None, + 'String representing a comma-separated list of workload names. ' 'If not None, only run this workload, else run all workloads in workload_metadata_path.' ) +flags.DEFINE_string('additional_requirements_path', None, 'Path to requirements.txt if any.') +flags.DEFINE_integer( + 'max_steps', + None, + 'Maximum number of steps to run. Must set flag enable_step_budget. ' + 'This flag takes precedence over the run_percentage flag.') +flags.DEFINE_bool( + 'enable_step_budget', + False, + 'Flag that must be explicitly set to override the time budget with a step budget.' +) FLAGS = flags.FLAGS @@ -106,15 +124,40 @@ def container_running(): return True +def kill_containers(): + docker_client = docker.from_env() + containers = docker_client.containers.list() + for container in containers: + container.kill() + + +def gpu_is_active(): + output = subprocess.check_output([ + 'nvidia-smi', + '--query-gpu=utilization.gpu', + '--format=csv,noheader,nounits' + ]) + return any(int(x) > 0 for x in output.decode().splitlines()) + + def wait_until_container_not_running(sleep_interval=5 * 60): + # Check GPU utilization; + # if GPUs have been inactive for more than 45 minutes, kill the container. + gpu_last_active = datetime.datetime.now().timestamp() + while container_running(): + # Check if GPUs have been inactive > 45 min and, if so, terminate the container. + if gpu_is_active(): + gpu_last_active = datetime.datetime.now().timestamp() + if (datetime.datetime.now().timestamp() - gpu_last_active) > 45 * 60: + print('Killing container: GPUs have been inactive > 45 minutes...') + kill_containers() time.sleep(sleep_interval) return def main(_): framework = FLAGS.framework - run_fraction = FLAGS.run_percentage / 100.
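For reference, the hunk above turns wait_until_container_not_running into a GPU-inactivity watchdog: while any workload container is running, the script polls nvidia-smi and kills all containers once the GPUs have been idle for more than 45 minutes. The following is a minimal standalone sketch of that polling pattern, assuming the Docker SDK (docker) and nvidia-smi are available; the constant names and the 5-minute poll interval are illustrative choices, not values taken from the patch.

import subprocess
import time

import docker

IDLE_LIMIT_S = 45 * 60    # kill containers after 45 min of idle GPUs (matches the diff)
POLL_INTERVAL_S = 5 * 60  # how often to re-check (illustrative)


def gpus_active() -> bool:
  # Ask nvidia-smi for per-GPU utilization and report True if any GPU is busy.
  output = subprocess.check_output([
      'nvidia-smi',
      '--query-gpu=utilization.gpu',
      '--format=csv,noheader,nounits',
  ])
  return any(int(line) > 0 for line in output.decode().splitlines())


def kill_all_containers() -> None:
  # Kill every running container via the Docker SDK.
  client = docker.from_env()
  for container in client.containers.list():
    container.kill()


def watch_containers() -> None:
  # Poll while any container is running; kill them if the GPUs stay idle too long.
  client = docker.from_env()
  gpu_last_active = time.time()
  while client.containers.list():
    if gpus_active():
      gpu_last_active = time.time()
    elif time.time() - gpu_last_active > IDLE_LIMIT_S:
      print('Killing containers: GPUs have been inactive > 45 minutes...')
      kill_all_containers()
    time.sleep(POLL_INTERVAL_S)

Tracking a last-active timestamp rather than counting consecutive idle polls keeps the check robust to irregular polling intervals.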
experiment_name = FLAGS.experiment_name docker_image_url = FLAGS.docker_image_url submission_path = FLAGS.submission_path @@ -132,7 +175,13 @@ def main(_): study_end_index = FLAGS.study_end_index else: study_end_index = num_studies - 1 + + additional_requirements_path_flag = '' + if FLAGS.additional_requirements_path: + additional_requirements_path_flag = f'--additional_requirements_path {FLAGS.additional_requirements_path} ' + submission_id = FLAGS.submission_id + rng_seed = FLAGS.seed if not rng_seed: @@ -144,17 +193,22 @@ def main(_): with open(FLAGS.workload_metadata_path) as f: workload_metadata = json.load(f) + # Get list of all possible workloads workloads = [w for w in workload_metadata.keys()] - # Read held-out workloads + # Read heldout workloads if FLAGS.held_out_workloads_config_path: held_out_workloads = read_held_out_workloads( FLAGS.held_out_workloads_config_path) workloads = workloads + held_out_workloads - # Filter for single workload - if FLAGS.workload and (FLAGS.workload in workloads): - workloads = [FLAGS.workload] + # Filter workloads if explicit workloads specified + if FLAGS.workloads is not None: + workloads = list( + filter(lambda x: x in FLAGS.workloads.split(','), workloads)) + if len(workloads) != len(FLAGS.workloads.split(',')): + unmatched_workloads = set(FLAGS.workloads.split(',')) - set(workloads) + raise ValueError(f'Invalid workload name {unmatched_workloads}') rng_subkeys = prng.split(rng_key, num_studies) @@ -174,14 +228,22 @@ def main(_): "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches print('=' * 100) dataset = workload_metadata[base_workload_name]['dataset'] - max_steps = int(workload_metadata[base_workload_name]['max_steps'] * - run_fraction) + max_steps_flag = '' + if FLAGS.enable_step_budget: + run_fraction = FLAGS.run_percentage / 100. 
+ if FLAGS.max_steps is None: + max_steps = int(workload_metadata[base_workload_name]['max_steps'] * + run_fraction) + else: + max_steps = FLAGS.max_steps + max_steps_flag = f'-m {max_steps}' + mount_repo_flag = '' if FLAGS.local: - mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' - command = ('docker run -t -d -v $HOME/data/:/data/ ' - '-v $HOME/experiment_runs/:/experiment_runs ' - '-v $HOME/experiment_runs/logs:/logs ' + mount_repo_flag = '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency ' + command = ('docker run -t -d -v /home/kasimbeg/data/:/data/ ' + '-v /home/kasimbeg/experiment_runs/:/experiment_runs ' + '-v /home/kasimbeg/experiment_runs/logs:/logs ' f'{mount_repo_flag}' '--gpus all --ipc=host ' f'{docker_image_url} ' @@ -190,9 +252,10 @@ def main(_): f'-s {submission_path} ' f'-w {workload} ' f'-e {study_dir} ' - f'-m {max_steps} ' + f'{max_steps_flag} ' f'--num_tuning_trials {num_tuning_trials} ' f'--rng_seed {run_seed} ' + f'{additional_requirements_path_flag}' '-c false ' '-o true ' '-i true ') @@ -235,4 +298,4 @@ def main(_): if __name__ == '__main__': flags.mark_flag_as_required('workload_metadata_path') - app.run(main) \ No newline at end of file + app.run(main) diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py index 02ad82fc0..1fb39d193 100644 --- a/scoring/score_submissions.py +++ b/scoring/score_submissions.py @@ -88,9 +88,15 @@ def get_summary_df(workload, workload_df, include_test_split=False): summary_df['time to best eval on val (s)'] = workload_df.apply( lambda x: x['accumulated_submission_time'][x['index best eval on val']], axis=1) - summary_df['time to target on val (s)'] = summary_df.apply( - lambda x: x['time to best eval on val (s)'] - if x['val target reached'] else np.inf, + workload_df['val target reached'] = workload_df[validation_metric].apply( + lambda x: target_op(x, validation_target)).apply(np.any) + workload_df['index to target on val'] = workload_df.apply( + lambda x: np.argmax(target_op(x[validation_metric], validation_target)) + if x['val target reached'] else np.nan, + axis=1) + summary_df['time to target on val (s)'] = workload_df.apply( + lambda x: x['accumulated_submission_time'][int(x[ + 'index to target on val'])] if x['val target reached'] else np.inf, axis=1) # test metrics diff --git a/setup.cfg b/setup.cfg index 321020ad9..eb570dafb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -121,6 +121,8 @@ jax_core_deps = chex==0.1.7 ml_dtypes==0.2.0 protobuf==4.25.3 + scipy==1.11.4 + # JAX CPU jax_cpu = diff --git a/submission_runner.py b/submission_runner.py index c396cb027..9f9b8ff42 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -17,11 +17,13 @@ import datetime import gc import importlib +from inspect import signature import itertools import json import os import struct import time +from types import MappingProxyType from typing import Any, Dict, Optional, Tuple from absl import app @@ -274,6 +276,10 @@ def train_once( hyperparameters, opt_init_rng) logging.info('Initializing metrics bundle.') + + # Check if 'train_state' is in the function signature + needs_train_state = 'train_state' in signature(update_params).parameters + # Bookkeeping. 
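The submission_runner.py hunk here and in the next chunk makes the new train_state argument opt-in: the runner inspects the submission's update_params signature and forwards a read-only MappingProxyType view only when the parameter is declared, so older submissions keep working unchanged. Below is a small sketch of that dispatch pattern with a hypothetical stand-in update function; the dictionary keys shown are examples only.

from inspect import signature
from types import MappingProxyType
from typing import Any, Callable, Dict, Optional


def update_params_v2(global_step: int,
                     train_state: Optional[Dict[str, Any]] = None) -> int:
  # Hypothetical submission update function that opts in to train_state.
  if train_state is not None and train_state.get('training_complete', False):
    return global_step  # stop stepping once training is flagged complete
  return global_step + 1


def call_update(update_fn: Callable[..., int],
                global_step: int,
                train_state: Dict[str, Any]) -> int:
  # Forward train_state only if the submission's signature declares it, and
  # pass an immutable view so the runner's bookkeeping cannot be mutated.
  needs_train_state = 'train_state' in signature(update_fn).parameters
  extra = ({'train_state': MappingProxyType(train_state)}
           if needs_train_state else {})
  return update_fn(global_step, **extra)


train_state = {'training_complete': False, 'accumulated_submission_time': 0.0}
next_step = call_update(update_params_v2, global_step=0, train_state=train_state)
print(next_step)  # 1

Passing MappingProxyType(train_state) lets the submission read timing and bookkeeping values while any attempt to assign into the mapping raises a TypeError.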
train_state = { 'validation_goal_reached': False, @@ -361,7 +367,9 @@ def train_once( optimizer_state=optimizer_state, eval_results=eval_results, global_step=global_step, - rng=update_rng) + rng=update_rng, + **({'train_state': MappingProxyType(train_state)} + if needs_train_state else {})) except spec.TrainingCompleteError: train_state['training_complete'] = True global_step += 1 diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 445e1f7cd..20991ab66 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -4,7 +4,7 @@ and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions for guidelines. """ -from typing import Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from algorithmic_efficiency import spec @@ -22,17 +22,19 @@ def init_optimizer_state(workload: spec.Workload, pass -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """ Returns: (new_optimizer_state, update_fn) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 74c06e180..f107be8d7 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -408,6 +408,7 @@ def _test_submission(workload_name, workload_path=workload_metadata['workload_path'], workload_class_name=workload_metadata['workload_class_name'], return_class=True) + print(f'Workload class for {workload_name} is {workload_class}') submission_module_path = workloads.convert_filepath_to_module(submission_path) submission_module = importlib.import_module(submission_module_path) @@ -471,6 +472,7 @@ def _test_submission(workload_name, batch=batch, loss_type=workload.loss_type, optimizer_state=optimizer_state, + train_state={}, eval_results=[], global_step=global_step, rng=update_rng) diff --git a/utils/run_workloads.py b/utils/run_workloads.py deleted file mode 100644 index 39f6a7b6f..000000000 --- a/utils/run_workloads.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Example Usage: -python run_workloads.py \ ---workload_config_path workload_config.json \ ---framework jax \ ---experiment_name my_first_experiment \ ---docker_image_url us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev \ ---run_percentage 10 \ ---workload_config_path workload_config.json \ ---dry_run -""" - -import json -import os -import struct -import time - -from absl import app -from absl import flags -from absl import logging - -import docker - -flags.DEFINE_string( - 'docker_image_url', - 'us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_dev', - 
'URL to docker image') -flags.DEFINE_integer('run_percentage', - 100, - 'Percentage of max num steps to run for.') -flags.DEFINE_string('experiment_name', - 'my_experiment', - 'Name of top sub directory in experiment dir.') -flags.DEFINE_boolean('rsync_data', - True, - 'Whether or not to transfer the data from GCP w rsync.') -flags.DEFINE_boolean('local', False, 'Mount local algorithmic-efficiency repo.') -flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.') -flags.DEFINE_boolean( - 'dry_run', - False, - 'Whether or not to actually run the docker containers. ' - 'If False, simply print the docker run commands. ') -flags.DEFINE_integer('num_studies', 1, 'Number of studies to run') -flags.DEFINE_integer('study_start_index', None, 'Start index for studies.') -flags.DEFINE_integer('study_end_index', None, 'End index for studies.') -flags.DEFINE_integer('num_tuning_trials', 1, 'Number of tuning trials.') -flags.DEFINE_integer('hparam_start_index', - None, - 'Start index for tuning trials.') -flags.DEFINE_integer('hparam_end_index', None, 'End index for tuning trials.') -flags.DEFINE_integer('seed', None, 'Random seed for evaluating a submission.') -flags.DEFINE_integer('submission_id', - 0, - 'Submission ID to generate study and hparam seeds.') -flags.DEFINE_string( - 'workload_config_path', - 'workload_confing.json', - 'Path to config containing dataset and maximum number of steps per workload.' - 'The default values of these are set to the full budgets as determined ' - 'via the target-setting procedure. ' - 'Note that training will be interrupted at either the set maximum number ' - 'of steps or the fixed workload maximum run time, whichever comes first. ' - 'If your algorithm has a smaller per step time than our baselines ' - 'you may want to increase the number of steps per workload.') - -FLAGS = flags.FLAGS - - -def read_workloads(filename): - with open(filename, "r") as f: - held_out_workloads = json.load(f) - return held_out_workloads - - -def container_running(): - docker_client = docker.from_env() - containers = docker_client.containers.list() - if len(containers) == 0: - return False - else: - return True - - -def wait_until_container_not_running(sleep_interval=5 * 60): - while container_running(): - time.sleep(sleep_interval) - return - - -def main(_): - # What Docker image to run the container with - docker_image_url = FLAGS.docker_image_url - - # Framework - framework = FLAGS.framework - - # - run_fraction = FLAGS.run_percentage / 100. 
- experiment_name = FLAGS.experiment_name - - # Get study and trial interval arguments - num_studies = FLAGS.num_studies - study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0 - study_end_index = FLAGS.study_end_index if FLAGS.study_end_index else num_studies - 1 - - # Get trial arguments - num_tuning_trials = FLAGS.num_tuning_trials - hparam_start_index_flag = '' - hparam_end_index_flag = '' - if FLAGS.hparam_start_index: - hparam_start_index_flag = f'--hparam_start_index {FLAGS.hparam_start_index} ' - if FLAGS.hparam_end_index: - hparam_end_index_flag = f'--hparam_end_index {FLAGS.hparam_end_index} ' - - # Generate rng keys from submission_id and seed - submission_id = FLAGS.submission_id - rng_seed = FLAGS.seed - - if not rng_seed: - rng_seed = struct.unpack('I', os.urandom(4))[0] - - logging.info('Using RNG seed %d', rng_seed) - - # Read workload specifications to run - with open(FLAGS.workload_config_path) as f: - workload_config = json.load(f) - workloads = [w for w in workload_config.keys()] - - for study_index in range(study_start_index, study_end_index + 1): - print('-' * 100) - print('*' * 40, f'Starting study {study_index + 1}/{num_studies}', '*' * 40) - print('-' * 100) - study_dir = os.path.join(experiment_name, f'study_{study_index}') - - for workload in workloads: - # For each runnable workload check if there are any containers running - wait_until_container_not_running() - - # Clear caches - os.system("sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") - print('=' * 100) - - # Get workload dataset, max step, algorithm path and tuning search space - dataset = workload_config[workload]['dataset'] - max_steps = int(workload_config[workload]['max_steps'] * run_fraction) - submission_path = workload_config[workload]['submission_path'] - tuning_search_space = workload_config[workload]['tuning_search_space'] - - # Optionally, define flag to mount local algorithmic-efficiency repo - mount_repo_flag = '' - if FLAGS.local: - mount_repo_flag = '-v $HOME/algorithmic-efficiency:/algorithmic-efficiency ' - - command = ('docker run -t -d -v $HOME/data/:/data/ ' - '-v $HOME/experiment_runs/:/experiment_runs ' - '-v $HOME/experiment_runs/logs:/logs ' - f'{mount_repo_flag}' - '--gpus all --ipc=host ' - f'{docker_image_url} ' - f'-d {dataset} ' - f'-f {framework} ' - f'-s {submission_path} ' - f'-w {workload} ' - f'-t {tuning_search_space} ' - f'-e {study_dir} ' - f'-m {max_steps} ' - f'--num_tuning_trials {num_tuning_trials} ' - f'{hparam_start_index_flag} ' - f'{hparam_end_index_flag} ' - f'--rng_seed {rng_seed} ' - '-c false ' - '-o true ' - '-i true ') - if not FLAGS.dry_run: - print('Running docker container command') - print('Container ID: ') - return_code = os.system(command) - else: - return_code = 0 - if return_code == 0: - print( - f'SUCCESS: container for {framework} {workload} launched successfully' - ) - print(f'Command: {command}') - print(f'Results will be logged to {experiment_name}') - else: - print( - f'Failed: container for {framework} {workload} failed with exit code {return_code}.' 
- ) - print(f'Command: {command}') - wait_until_container_not_running() - os.system( - "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches - - print('=' * 100) - - -if __name__ == '__main__': - app.run(main) diff --git a/utils/target_setting_workload_config.json b/utils/target_setting_workload_config.json deleted file mode 100644 index a8c050422..000000000 --- a/utils/target_setting_workload_config.json +++ /dev/null @@ -1,195 +0,0 @@ -{ - "imagenet_resnet": { - "max_steps": 186666, - "dataset": "imagenet", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_adamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json" - }, - "imagenet_resnet_gelu": { - "max_steps": 186666, - "dataset": "imagenet", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_momentum.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json" - }, - "imagenet_resnet_large_bn_init": { - "max_steps": 186666, - "dataset": "imagenet", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_momentum.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json" - }, - "imagenet_resnet_silu": { - "max_steps": 186666, - "dataset": "imagenet", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json" - }, - "imagenet_vit": { - "max_steps": 186666, - "dataset": "imagenet", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_adamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json" - }, - "imagenet_vit_glu": { - "max_steps": 186666, - "dataset": "imagenet", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json" - }, - "imagenet_vit_map": { - "max_steps": 186666, - "dataset": "imagenet", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json" - }, - "imagenet_vit_post_ln": { - "max_steps": 186666, - "dataset": "imagenet", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json" - }, - "fastmri": { - "max_steps": 36189, - "dataset": "fastmri", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nesterov.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json" - }, - "fastmri_layernorm": { - "max_steps": 36189, - "dataset": "fastmri", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json" - }, - "fastmri_model_size": { - "max_steps": 36189, - "dataset": "fastmri", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json" - }, - 
"fastmri_tanh": { - "max_steps": 36189, - "dataset": "fastmri", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json" - }, - "ogbg": { - "max_steps": 80000, - "dataset": "ogbg", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nesterov.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json" - }, - "ogbg_gelu": { - "max_steps": 80000, - "dataset": "ogbg", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json" - }, - "ogbg_model_size": { - "max_steps": 80000, - "dataset": "ogbg", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json" - }, - "ogbg_silu": { - "max_steps": 80000, - "dataset": "ogbg", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json" - }, - "wmt": { - "max_steps": 133333, - "dataset": "wmt", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json" - }, - "wmt_attention_temp": { - "max_steps": 133333, - "dataset": "wmt", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json" - }, - "wmt_glu_tanh": { - "max_steps": 133333, - "dataset": "wmt", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json" - }, - "wmt_post_ln": { - "max_steps": 133333, - "dataset": "wmt", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_adamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json" - }, - "librispeech_deepspeech": { - "max_steps": 48000, - "dataset": "librispeech", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json" - }, - "librispeech_deepspeech_no_resnet": { - "max_steps": 48000, - "dataset": "librispeech", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json" - }, - "librispeech_deepspeech_norm_and_spec_aug": { - "max_steps": 48000, - "dataset": "librispeech", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json" - }, - "librispeech_deepspeech_tanh": { - "max_steps": 48000, - "dataset": "librispeech", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": 
"reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json" - }, - "criteo1tb": { - "max_steps": 10666, - "dataset": "criteo1tb", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json" - }, - "criteo1tb_embed_init": { - "max_steps": 10666, - "dataset": "criteo1tb", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json" - }, - "criteo1tb_layernorm": { - "max_steps": 10666, - "dataset": "criteo1tb", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json" - }, - "criteo1tb_resnet": { - "max_steps": 10666, - "dataset": "criteo1tb", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json" - }, - "librispeech_conformer": { - "max_steps": 80000, - "dataset": "librispeech", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_adamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json" - }, - "librispeech_conformer_attention_temperature": { - "max_steps": 80000, - "dataset": "librispeech", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_adamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json" - }, - "librispeech_conformer_gelu": { - "max_steps": 80000, - "dataset": "librispeech", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json" - }, - "librispeech_conformer_layernorm": { - "max_steps": 80000, - "dataset": "librispeech", - "submission_path": "reference_algorithms/target_setting_algorithms/jax_nadamw.py", - "tuning_search_space": "reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json" - } - -} \ No newline at end of file