diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..b71e42e01 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,5 @@ +# Change log + +## TODO: algorithmic-efficiency 0.1.0 + +First release of AlgoPerf benchmarking code. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 771e77f0a..025cb6d30 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -253,10 +253,9 @@ pylint submission_runner.py pylint tests ``` -### Unit and integration tests - -We run unit tests and integration tests as part of the of github actions as well. -You can also use `python tests/reference_algorithm_tests.py` to run a single model update and two model evals for each workload using the reference algorithm in `reference_algorithms/development_algorithms/`. +## Unit and integration tests +We also run unit tests and integration tests as part of the GitHub actions. +You can also use `python tests/reference_algorithm_tests.py` to run a single model update and two model evals for each workload using the reference algorithm in `reference_algorithms/target_setting_algorithms/`. ### Regression tests diff --git a/README.md b/README.md index 58a62ebd5..dd0d7fe3e 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,19 @@ - [Pytorch](#pytorch) - [Rules](#rules) - [Contributing](#contributing) -- [Note on shared data pipelines between JAX and PyTorch](#note-on-shared-data-pipelines-between-jax-and-pytorch) +- [Shared data pipelines between JAX and PyTorch](#shared-data-pipelines-between-jax-and-pytorch) +- [Setup and Platform](#setup-and-platform) + - [My machine only has one GPU. How can I use this repo?](#my-machine-only-has-one-gpu-how-can-i-use-this-repo) + - [How do I run this on my SLURM cluster?](#how-do-i-run-this-on-my-slurm-cluster) + - [How can I run this on my AWS/GCP/Azure cloud project?](#how-can-i-run-this-on-my-awsgcpazure-cloud-project) +- [Submissions](#submissions) + - [Can submissions be structured using multiple files?](#can-submissions-be-structured-using-multiple-files) + - [Can I install custom dependencies?](#can-i-install-custom-dependencies) + - [How can I know if my code can be run on benchmarking hardware?](#how-can-i-know-if-my-code-can-be-run-on-benchmarking-hardware) + - [Are we allowed to use our own hardware to self-report the results?](#are-we-allowed-to-use-our-own-hardware-to-self-report-the-results) + + + ## Installation @@ -146,9 +158,15 @@ To use the Docker container as an interactive virtual environment, you can run a \ --keep_container_alive true ``` - -2. Open a bash terminal - + Note: You may have to use double quotes around the `algorithmic-efficiency` path in the mounting `-v` flag. If the above command fails, try replacing the following line: + ```bash + -v $HOME/algorithmic-efficiency:/algorithmic-efficiency \ + ``` + with + ```bash + -v $HOME"/algorithmic-efficiency:/algorithmic-efficiency" \ + ``` + - Open a bash terminal ```bash docker exec -it <container_id> /bin/bash ``` @@ -258,9 +276,66 @@ The rules for the MLCommons Algorithmic Efficency benchmark can be found in the If you are interested in contributing to the work of the working group, feel free to [join the weekly meetings](https://mlcommons.org/en/groups/research-algorithms/), open issues. See our [CONTRIBUTING.md](CONTRIBUTING.md) for MLCommons contributing guidelines and setup and workflow instructions. 
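[Editorial aside, not part of the patch] The "Shared data pipelines between JAX and PyTorch" note added below explains that the PyTorch workloads run the TensorFlow input pipeline only in the `rank == 0` process and broadcast each batch to the other DDP processes. The following is a rough, self-contained sketch of that broadcast pattern; the dummy data source and tensor shape are assumptions for illustration, not the benchmark's actual pipeline code.

```python
import torch
import torch.distributed as dist


def get_broadcasted_batch(rank: int, batch_shape=(32, 80)) -> torch.Tensor:
  """Sketch of the rank-0 broadcast strategy.

  Assumes dist.init_process_group has already been called with one process
  per GPU (as DistributedDataParallel requires).
  """
  device = torch.device(f'cuda:{rank}')
  if rank == 0:
    # In the benchmark workloads this batch would come from the shared
    # TensorFlow input pipeline, which only runs in the rank-0 process.
    batch = torch.randn(batch_shape, device=device)
  else:
    # The other ranks allocate an empty buffer of the known global shape.
    batch = torch.empty(batch_shape, device=device)
  # Collective call: every process participates, rank 0 is the source.
  dist.broadcast(batch, src=0)
  return batch
```

This trades a per-batch communication overhead for avoiding the thread explosion caused by running a TensorFlow pipeline in every process, as described in the note below.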
-## Note on shared data pipelines between JAX and PyTorch + +# Disclaimers + +## Shared data pipelines between JAX and PyTorch The JAX and PyTorch versions of the Criteo, FastMRI, Librispeech, OGBG, and WMT workloads are using the same TensorFlow input pipelines. Due to differences in how Jax and PyTorch distribute computations across devices, the PyTorch workloads have an additional overhead for these workloads. Since we use PyTorch's [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implementation, there is one Python process for each device. Depending on the hardware and the settings of the cluster, running a TensorFlow input pipeline in each Python process can lead to errors, since too many threads are created in each process. See [this PR thread](https://github.com/mlcommons/algorithmic-efficiency/pull/85) for more details. While this issue might not affect all setups, we currently implement a different strategy: we only run the TensorFlow input pipeline in one Python process (with `rank == 0`), and [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) the batches to all other devices. This introduces an additional communication overhead for each batch. See the [implementation for the WMT workload](https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py#L215-L288) as an example. + +## PyTorch Conformer CUDA OOM + +The Conformer PyTorch workload may currently run out of memory. If you encounter this issue, please set the `submission_runner.py` flag `reduce_pytorch_max_split_size` to `True` as a temporary workaround. This will set `max_split_size_mb:256`. Note that this will adversely impact the performance of the submission on this workload. See the [tracking issue](https://github.com/mlcommons/algorithmic-efficiency/issues/497). + + +# FAQs + +## Setup and Platform + +### My machine only has one GPU. How can I use this repo? +You can run this repo on a machine with an arbitrary number of GPUs. However, the default batch sizes in our reference algorithms `algorithmic-efficiency/baselines` and `algorithmic-efficiency/reference_algorithms` are tuned for a machine with 8 16GB V100 GPUs. You may run into OOMs if you run these algorithms with fewer than 8 GPUs. If you run into these issues because you are using a machine with less total GPU memory, please reduce the batch sizes for the submission. Note that your final submission must 'fit' +on the benchmarking hardware, so if you are using fewer +GPUs with higher per-GPU memory, please monitor your memory usage +to make sure it will fit on 8xV100 GPUs with 16GB of VRAM per card. + +### How do I run this on my SLURM cluster? +You may run into issues with `sudo` and `docker` on a SLURM cluster. To run the workloads on a SLURM cluster, you can use Apptainer (previously Singularity); see this [section](#using-singularityapptainer-instead-of-docker). +### How can I run this on my AWS/GCP/Azure cloud project? +Depending on your virtual machine, you may have to install the correct GPU drivers and the NVIDIA Docker toolkit. For example, on GCP you will have to do the following: +1. If you don't have a VM instance yet, we recommend creating a +new Compute Instance with the "Deep Learning on Linux" Image in Boot disk options. +2. 
To install the NVIDIA Docker toolkit, you can use `scripts/cloud-startup.sh` as a startup script for the VM. This will automate the installation of the NVIDIA GPU drivers and the NVIDIA Docker toolkit. + +## Submissions +### Can submissions be structured using multiple files? +Yes, your submission can be structured using multiple files. +### Can I install custom dependencies? +You may use custom dependencies as long as they do not conflict with any of the pinned packages in `algorithmic-efficiency/setup.cfg`. +To include your custom dependencies in your submission, please include them in a `requirements.txt` file. Please refer to the [Software dependencies](https://github.com/mlcommons/algorithmic-efficiency/blob/main/RULES.md#software-dependencies) section of our rules. +### How can I know if my code can be run on benchmarking hardware? +The benchmarking hardware specifications are documented in the [Getting Started Document](./getting_started.md). +We recommend monitoring your submission's memory usage so that it does not exceed the available memory +on the competition hardware. We also recommend doing a dry run using a cloud instance. +### Are we allowed to use our own hardware to self-report the results? +You only have to use the competition hardware for runs that are directly involved in the scoring procedure. This includes all runs for the self-tuning ruleset, but only the runs of the best hyperparameter configuration in each study for the external tuning ruleset. For example, you could use your own (different) hardware to tune your submission and identify the best hyperparameter configuration (in each study) and then only run this configuration (i.e., 5 runs, one for each study) on the competition hardware. + +# Citing AlgoPerf Benchmark +If you use the **AlgoPerf** Benchmark in your work, please consider citing: + +> [George E. Dahl, Frank Schneider, Zachary Nado, et al.
+> **Benchmarking Neural Network Training Algorithms**
+> *arXiv 2306.07179*](http://arxiv.org/abs/2306.07179) + +```bibtex +@misc{dahl2023algoperf, + title={{Benchmarking Neural Network Training Algorithms}}, + author={Dahl, George E. and Schneider, Frank and Nado, Zachary and Agarwal, Naman and Sastry, Chandramouli Shama and Hennig, Philipp and Medapati, Sourabh and Eschenhagen, Runa and Kasimbeg, Priya and Suo, Daniel and Bae, Juhan and Gilmer, Justin and Peirson, Abel L. and Khan, Bilal and Anil, Rohan and Rabbat, Mike and Krishnan, Shankar and Snider, Daniel and Amid, Ehsan and Chen, Kongtao and Maddison, Chris J. and Vasudev, Rakshith and Badura, Michal and Garg, Ankush and Mattson, Peter}, + year={2023}, + eprint={2306.07179}, + archivePrefix={arXiv}, + primaryClass={cs.LG} +} +``` \ No newline at end of file diff --git a/SUBMISSION_PROCESS_RULES.md b/SUBMISSION_PROCESS_RULES.md index da54cbbec..227d6128b 100644 --- a/SUBMISSION_PROCESS_RULES.md +++ b/SUBMISSION_PROCESS_RULES.md @@ -1,6 +1,6 @@ # MLCommons™ AlgoPerf: Submission Process Rules -**Version:** 0.0.2 *(Last updated 03 Oktober 2023)* +**Version:** 0.0.3 *(Last updated 10 October 2023)* ## Table of Contents @@ -39,8 +39,8 @@ Three additional documents complement this document: ### Dates -- **Publication of the call for submission: 17. Oktober 2023 (08:00 AM UTC)** -- Registration deadline for submissions: 15. December 2023 (08:00 AM UTC) +- **Publication of the call for submission: 17. October 2023 (08:00 AM UTC)** +- Registration deadline to express non-binding intent to submit: 15. December 2023 (08:00 AM UTC) - Version freeze for the benchmark codebase: 17. January 2024 (08:00 AM UTC) - **Submission deadline: 15. February 2024 (08:00 AM UTC)** - Sampling the held-out workloads and hyperparameters: 16. February 2024 (08:00 AM UTC) @@ -68,19 +68,18 @@ For a guide on the technical steps and details on how to write a submission, ple In the following, we describe the logistical steps required to submit a training algorithm to the AlgoPerf: Training Algorithms Benchmark. -### Register a submission +### Register an intent to submit All submitters need to register an intent to submit before the submission registration deadline. This registration is mandatory, i.e. required for all submissions, but not binding, i.e. you don't have to submit a registered submission. This registration is necessary, to estimate the number of submissions and provide support for potential submitters. -To register a submission, please fill out this [online form](https://forms.gle/iY1bUhwSjj1JZ4fa9) with the following information +To register an intent to submit, please fill out this [online form](https://forms.gle/iY1bUhwSjj1JZ4fa9) with the following information: - Name of the submission (e.g. name of the algorithm, or any other arbitrary identifier). - Ruleset under which the submission will be scored. -- Name of all submitters associated with this submission. -- Email of all submitters associated with this submission. -- Affiliations of all submitters associated with this submission. +- Name, email, and affiliations of all submitters associated with this submission. +- Interest in compute support. -In return, the submission will be issued a unique **submission ID** that will be used throughout the submission process. +The submission will be issued a unique **submission ID** that will be used throughout the submission process. 
### How to submit @@ -143,9 +142,9 @@ The spirit jury may then hear the justifications of the submitters, inspect the ## Awards and prize money -An awards committee will award a prize for the "*Best Performance*" in each ruleset as well as a "*Jury Award*". The prize for the best-performing submission will take into account the [benchmark score](RULES.md#benchmark-score-using-performance-profiles) on the full benchmark. The "*Jury Award*" will favor more out-of-the-box ideas that show great potential, even though the method may not be of practical value with the current landscape of models, software, etc. +An awards committee will award a prize for the "*Best Performance*" in each ruleset as well as an "*Innovative Submission Award*". The prize for the best-performing submission will take into account the [benchmark score](RULES.md#benchmark-score-using-performance-profiles) on the full benchmark. The "*Innovative Submission Award*" will favor more out-of-the-box ideas that show great potential, even though the method may not be of practical value with the current landscape of models, software, etc. -The prize money for "*Best Performance*" in a ruleset is $20,000 each. The winner of the "*Jury Award*" will be awarded $10,000. We reserve the right to split the prize money and distribute it among multiple submissions. +The prize money for "*Best Performance*" in a ruleset is $20,000 each. The winner of the "*Innovative Submission Award*" will be awarded $10,000. We reserve the right to split the prize money and distribute it among multiple submissions. If a submission is ineligible to win prize money it can still win an award. The prize money will then go to the highest-ranking eligible submission. @@ -159,10 +158,10 @@ The awards committee will be responsible for awarding prize money to submissions To ensure a fair process and avoid conflicts of interest, some individuals and institutions are ineligible to win prize money. This includes: -- The chairs of the MLCommons Algorithms Working Group (presently *George Dahl* and *Frank Schneider*) and their institutions (currently *Google Inc.* and the *University of Tübingen*) -- All individuals serving on the awards committee and their institutions. +- The chairs of the MLCommons Algorithms Working Group (presently *George Dahl* and *Frank Schneider*) and their associated institutions (currently *Google Inc.* and the *University of Tübingen*) +- All individuals serving on the awards committee and their associated institutions. -A submission with at least one ineligible submitter may still win an award, but the prize money will then be awarded to the top-ranked submission that is eligible for prize money. +A submission with at least one participating ineligible entity may still win an award, but the prize money will then be given to the top-ranked submission that does not contain ineligible entities. 
Additionally, we require members of the spirit jury to abstain from being involved in a review if: diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 2b3cf86f6..b7bde226a 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -9,7 +9,7 @@ import shutil import subprocess import sys -from typing import Any, Optional +from typing import Any, Dict, Optional from absl import flags from clu import metric_writers @@ -96,14 +96,14 @@ def write_hparams(hparams: spec.Hyperparameters, return hparams -def write_json(name: str, log_dict: dict, indent: int = 2) -> None: +def write_json(name: str, log_dict: Dict, indent: int = 2) -> None: if RANK == 0: with open(name, 'w') as f: f.write(json.dumps(log_dict, indent=indent)) def write_to_csv( - metrics: dict, + metrics: Dict, csv_path: str, ) -> None: try: @@ -120,7 +120,7 @@ def write_to_csv( return -def _get_utilization() -> dict: +def _get_utilization() -> Dict: util_data = {} # CPU @@ -180,7 +180,7 @@ def _get_utilization() -> dict: return util_data -def _get_system_hardware_info() -> dict: +def _get_system_hardware_info() -> Dict: system_hardware_info = {} try: system_hardware_info['cpu_model_name'] = _get_cpu_model_name() @@ -200,7 +200,7 @@ def _get_system_hardware_info() -> dict: return system_hardware_info -def _get_system_software_info() -> dict: +def _get_system_software_info() -> Dict: system_software_info = {} system_software_info['os_platform'] = \ @@ -243,7 +243,7 @@ def _is_primitive_type(item: Any) -> bool: return isinstance(item, primitive) -def _get_workload_properties(workload: spec.Workload) -> dict: +def _get_workload_properties(workload: spec.Workload) -> Dict: workload_properties = {} skip_list = ['param_shapes', 'model_params_types'] keys = [ @@ -262,7 +262,8 @@ def _get_workload_properties(workload: spec.Workload) -> dict: return workload_properties -def get_meta_data(workload: spec.Workload) -> dict: +def get_meta_data(workload: spec.Workload, + rng_seed: Optional[int] = None) -> Dict: meta_data = {} workload_properties = _get_workload_properties(workload) meta_data.update(workload_properties) @@ -272,15 +273,11 @@ def get_meta_data(workload: spec.Workload) -> dict: meta_data.update(system_software_info) system_hardware_info = _get_system_hardware_info() meta_data.update(system_hardware_info) + if rng_seed is not None: + meta_data.update({'rng_seed': rng_seed}) return meta_data -def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): - meta_data = get_meta_data(workload) - meta_data.update({'rng_seed': rng_seed}) - write_json(meta_file_name, meta_data) - - class MetricLogger(object): """Used to log all measurements during training. @@ -308,7 +305,7 @@ def __init__(self, wandb.config.update(hyperparameters._asdict()) def append_scalar_metrics(self, - metrics: dict, + metrics: Dict, global_step: int, preemption_count: Optional[int] = None, is_eval: bool = False) -> None: diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 1b8612fd1..4f6c254bd 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -27,9 +27,8 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None: # Make sure no GPU memory is preallocated to Jax. os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # Only use CPU for Jax to avoid memory issues. 
- # Setting the corresponding environment variable here has no effect; it has to - # be done before jax and tensorflow (!) are imported for the first time. jax.config.update('jax_platforms', 'cpu') + jax.config.update('jax_platform_name', 'cpu') # From the docs: "(...) causes cuDNN to benchmark multiple convolution # algorithms and select the fastest." torch.backends.cudnn.benchmark = True diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index ba8db9ced..a76a70289 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -6,6 +6,7 @@ from flax import jax_utils import jax import jax.numpy as jnp +import numpy as np from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec @@ -147,7 +148,8 @@ def _eval_batch(self, batch: Dict[str, spec.Tensor]) -> spec.Tensor: # We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of # shape (local_device_count,) will all be different values. - return self._eval_batch_pmapped(params, batch).sum() + return np.array( + self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64) class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload): diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index ef971bb75..13bd308fb 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -35,14 +35,14 @@ def has_reached_validation_target(self, eval_result: Dict[str, @property def validation_target_value(self) -> float: - return 0.123649 + return 0.123735 def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: return eval_result['test/loss'] < self.test_target_value @property def test_target_value(self) -> float: - return 0.126060 + return 0.126041 @property def loss_type(self) -> spec.LossType: diff --git a/algorithmic_efficiency/workloads/fastmri/workload.py b/algorithmic_efficiency/workloads/fastmri/workload.py index ecfa27547..d1d07e70e 100644 --- a/algorithmic_efficiency/workloads/fastmri/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/workload.py @@ -19,14 +19,14 @@ def has_reached_validation_target(self, eval_result: float) -> bool: @property def validation_target_value(self) -> float: - return 0.7344 + return 0.726999 def has_reached_test_target(self, eval_result: float) -> bool: return eval_result['test/ssim'] > self.test_target_value @property def test_target_value(self) -> float: - return 0.741652 + return 0.744254 @property def loss_type(self) -> spec.LossType: @@ -51,7 +51,7 @@ def num_validation_examples(self) -> int: @property def num_test_examples(self) -> int: - return 3581 + return 3548 @property def eval_batch_size(self) -> int: diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 45d77ede4..bc7eae3b8 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -144,7 +144,7 @@ def _build_input_queue( } padded_batch = data_utils.shard_and_maybe_pad_np( - numpy_batch, padding_value=1.0, global_batch_size=global_batch_size) + numpy_batch, padding_value=1.0) yield padded_batch 
# Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index 665c3c894..2da7dcfb3 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -71,7 +71,10 @@ def forward(self, x): class Subsample(nn.Module): - def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0): + def __init__(self, + encoder_dim: int = 0, + input_dropout_rate: float = 0.0, + num_bins: int = 80): super().__init__() self.encoder_dim = encoder_dim self.input_dropout_rate = input_dropout_rate @@ -81,7 +84,10 @@ def __init__(self, encoder_dim: int = 0, input_dropout_rate: float = 0.0): self.conv2 = Conv2dSubsampling( input_channels=encoder_dim, output_channels=encoder_dim) - self.linear = nn.LazyLinear(out_features=self.encoder_dim, bias=True) + self.linear = nn.Linear( + in_features=self.encoder_dim * num_bins // 4, + out_features=self.encoder_dim, + bias=True) self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim) self.dropout = nn.Dropout(p=self.input_dropout_rate) @@ -123,6 +129,7 @@ def __init__(self, self.kernel = nn.Parameter( torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape))) self.bias = nn.Parameter(torch.zeros(output_channels)) + self.register_buffer('paddings_kernel', torch.ones([1, 1, 1])) def get_same_padding(self, input_shape): in_height, in_width = input_shape[2:] @@ -162,15 +169,11 @@ def forward(self, inputs, paddings): input_length = paddings.shape[1] stride = self.filter_stride[0] pad_len = (input_length + stride - 1) // stride * stride - input_length - padded_paddings = torch.cat([ - paddings[:, None, :], - torch.zeros( - size=(paddings.shape[0], 1, pad_len), device=paddings.device) - ], - dim=2) + padded_paddings = F.pad( + paddings[:, None, :], (0, pad_len), mode='constant', value=0) out_padding = F.conv1d( input=padded_paddings, - weight=torch.ones([1, 1, 1], device=paddings.device), + weight=self.paddings_kernel, stride=self.filter_stride[:1]) out_padding = out_padding.squeeze(dim=1) outputs = outputs * (1 - out_padding[:, None, :, None]) @@ -184,11 +187,15 @@ def __init__(self, config: ConformerConfig): self.config = config self.ln = LayerNorm(dim=config.encoder_dim) - self.linear1 = nn.LazyLinear( + self.linear1 = nn.Linear( + in_features=config.encoder_dim, out_features=config.encoder_dim * config.feed_forward_expansion_factor, bias=True) self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate) - self.linear2 = nn.LazyLinear(out_features=config.encoder_dim, bias=True) + self.linear2 = nn.Linear( + in_features=config.encoder_dim * config.feed_forward_expansion_factor, + out_features=config.encoder_dim, + bias=True) if config.feed_forward_residual_dropout_rate is None: feed_forward_residual_dropout_rate = 0.1 @@ -253,217 +260,32 @@ def forward(self, inputs): return inputs * scale -class MHSAwithQS(nn.MultiheadAttention): - # pylint: disable=locally-disabled, use-a-generator, line-too-long, invalid-name +class MHSAwithQS(nn.Module): + def __init__(self, config: ConformerConfig): - super().__init__( - embed_dim=config.encoder_dim, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout_rate, - bias=True, - batch_first=True) + super().__init__() + self.embed_dim = config.encoder_dim + self.num_heads 
= config.num_attention_heads + self.dropout = config.attention_dropout_rate + self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim) + self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim) self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads) - def _scaled_in_proj_weight(self): - # Scale the query projection weight. - qs_input = self.in_proj_weight[:self.embed_dim].view( - self.num_heads, self.embed_dim // self.num_heads, -1).transpose(1, 2) - in_proj_queryW_scaled = self.qs(qs_input).transpose( - 1, 2).view(*self.in_proj_weight[:self.embed_dim].shape) - in_proj_weight = torch.cat( - [in_proj_queryW_scaled, self.in_proj_weight[self.embed_dim:]]) - return in_proj_weight - - def _scaled_in_proj_bias(self): - # Scale the query bias. - in_proj_queryb_scaled = self.qs(self.in_proj_bias[:self.embed_dim].view( - self.num_heads, self.embed_dim // self.num_heads)).view(-1) - in_proj_bias = torch.cat( - [in_proj_queryb_scaled, self.in_proj_bias[self.embed_dim:]]) - return in_proj_bias - - def forward(self, - query, - key, - value, - key_padding_mask=None, - need_weights: bool = True, - attn_mask=None, - average_attn_weights: bool = True): - r""" - Args: - query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` - or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, - :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. - Queries are compared against key-value pairs to produce the output. - See "Attention Is All You Need" for more details. - key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` - or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, - :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. - See "Attention Is All You Need" for more details. - value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when - ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source - sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. - See "Attention Is All You Need" for more details. - key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` - to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. - Binary and byte masks are supported. - For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for - the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. - need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. - Default: ``True``. - attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape - :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, - :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be - broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. - Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the - corresponding position is not allowed to attend. 
For a byte mask, a non-zero value indicates that the - corresponding position is not allowed to attend. For a float mask, the mask values will be added to - the attention weight. - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across - heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an - effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) - - Outputs: - - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, - :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, - where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the - embedding dimension ``embed_dim``. - - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, - returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or - :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and - :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per - head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. - - .. note:: - `batch_first` argument is ignored for unbatched inputs. - """ - is_batched = query.dim() == 3 - if key_padding_mask is not None: - _kpm_dtype = key_padding_mask.dtype - if _kpm_dtype != torch.bool and not torch.is_floating_point( - key_padding_mask): - raise AssertionError( - "only bool and floating types of key_padding_mask are supported") - why_not_fast_path = '' - if not is_batched: - why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" - elif query is not key or key is not value: - # When lifting this restriction, don't forget to either - # enforce that the dtypes all match or test cases where - # they don't! - why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" - elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: - why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" - elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype: - # this case will fail anyway, but at least they'll get a useful error message. 
- why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" - elif self.training: - why_not_fast_path = "training is enabled" - elif not self.batch_first: - why_not_fast_path = "batch_first was not True" - elif self.bias_k is not None: - why_not_fast_path = "self.bias_k was not None" - elif self.bias_v is not None: - why_not_fast_path = "self.bias_v was not None" - elif self.dropout: - why_not_fast_path = f"dropout was {self.dropout}, required zero" - elif self.add_zero_attn: - why_not_fast_path = "add_zero_attn was enabled" - elif not self._qkv_same_embed_dim: - why_not_fast_path = "_qkv_same_embed_dim was not True" - elif attn_mask is not None: - why_not_fast_path = "attn_mask was not None" - elif query.is_nested and key_padding_mask is not None: - why_not_fast_path = "key_padding_mask is not supported with NestedTensor input" - elif self.num_heads % 2 == 1: - why_not_fast_path = "num_heads is odd" - elif torch.is_autocast_enabled(): - why_not_fast_path = "autocast is enabled" - - if not why_not_fast_path: - tensor_args = ( - query, - key, - value, - self.in_proj_weight, - self.in_proj_bias, - self.out_proj.weight, - self.out_proj.bias, - ) - # We have to use list comprehensions below because TorchScript does not support - # generator expressions. - if torch.overrides.has_torch_function(tensor_args): - why_not_fast_path = "some Tensor argument has_torch_function" - elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) - for x in tensor_args]): - why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" - elif torch.is_grad_enabled() and any( - [x is not None and x.requires_grad for x in tensor_args]): - why_not_fast_path = ( - "grad is enabled and at least one of query or the " - "input/output projection weights or biases requires_grad") - if not why_not_fast_path: - # Scale the query bias parameter and the query projection weight. - in_proj_weight = self._scaled_in_proj_weight() - in_proj_bias = self._scaled_in_proj_bias() - return torch._native_multi_head_attention( - query, - key, - value, - self.embed_dim, - self.num_heads, - in_proj_weight, - in_proj_bias, - self.out_proj.weight, - self.out_proj.bias, - key_padding_mask if key_padding_mask is not None else attn_mask, - need_weights, - average_attn_weights, - 1 if key_padding_mask is not None else - 0 if attn_mask is not None else None) - any_nested = query.is_nested or key.is_nested or value.is_nested - assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. 
" + - f"The fast path was not hit because {why_not_fast_path}") - - if self.batch_first and is_batched: - # make sure that the transpose op does not affect the "is" property - if key is value: - if query is key: - query = key = value = query.transpose(1, 0) - else: - query, key = [x.transpose(1, 0) for x in (query, key)] - value = key - else: - query, key, value = [x.transpose(1, 0) for x in (query, key, value)] - - if not self._qkv_same_embed_dim: - attn_output, attn_output_weights = F.multi_head_attention_forward( - query, key, value, self.embed_dim, self.num_heads, - self.in_proj_weight, self.in_proj_bias, - self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, self.out_proj.weight, self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, need_weights=need_weights, - attn_mask=attn_mask, use_separate_proj_weight=True, - q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, - v_proj_weight=self.v_proj_weight, average_attn_weights=average_attn_weights) - else: - # Scale the query bias parameter and the query projection weight. - in_proj_weight = self._scaled_in_proj_weight() - in_proj_bias = self._scaled_in_proj_bias() - attn_output, attn_output_weights = F.multi_head_attention_forward( - query, key, value, self.embed_dim, self.num_heads, - in_proj_weight, in_proj_bias, - self.bias_k, self.bias_v, self.add_zero_attn, - self.dropout, self.out_proj.weight, self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, need_weights=need_weights, - attn_mask=attn_mask, average_attn_weights=average_attn_weights) - if self.batch_first and is_batched: - return attn_output.transpose(1, 0), attn_output_weights - else: - return attn_output, attn_output_weights + def forward(self, inputs, key_padding_mask=None): + batch_size, seq_len, embed_dim = inputs.shape + q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2) + q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2) + k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) + v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2) + out = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + attn_mask=~key_padding_mask[:, None, None], + dropout_p=self.dropout, + ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim) + out = self.out_proj(out) + return out class MultiHeadedSelfAttention(nn.Module): @@ -483,12 +305,9 @@ def __init__(self, config: ConformerConfig): def forward(self, outputs, paddings): outputs = self.ln(outputs) - outputs, _ = self.self_attention( - query=outputs, - key=outputs, - value=outputs, - key_padding_mask=paddings==1, - need_weights=False, + outputs = self.self_attention( + outputs, + key_padding_mask=paddings == 1, ) outputs = self.dropout(outputs) return outputs @@ -504,18 +323,29 @@ def __init__(self, config: ConformerConfig): self.register_buffer('running_var', running_var) self.scale = nn.Parameter(torch.zeros(config.encoder_dim)) self.bias = nn.Parameter(torch.zeros(config.encoder_dim)) - self.register_buffer('momentum', - torch.FloatTensor([config.batch_norm_momentum])) - self.register_buffer('epsilon', - torch.FloatTensor([config.batch_norm_epsilon])) + self.register_buffer('dim', torch.FloatTensor([config.encoder_dim])) - # self.momentum = config.batch_norm_momentum - # self.epsilon = config.batch_norm_epsilon - # self.dim = config.encoder_dim + self.momentum = config.batch_norm_momentum + self.epsilon = config.batch_norm_epsilon def forward(self, inputs, input_paddings): #inputs: 
NHD #padding: NH + """ + Alternatively: + inputs[input_paddings==0] = F.batch_norm( + input = inputs[input_paddings==0], + running_mean = self.running_mean, + running_var = self.running_var, + weight = 1+self.scale, + bias = self.bias, + training = self.training, + momentum=1-self.momentum, + eps=self.epsilon + ) + inputs.masked_fill(input_paddings[...,None] != 0, 0) + return inputs + """ mask = 1 - input_paddings[:, :, None] if self.training: count = mask.sum() @@ -627,7 +457,9 @@ def __init__(self, config: ConformerConfig): else: input_dropout_rate = config.input_dropout_rate self.subsample = Subsample( - encoder_dim=config.encoder_dim, input_dropout_rate=input_dropout_rate) + encoder_dim=config.encoder_dim, + input_dropout_rate=input_dropout_rate, + num_bins=preprocessing_config.num_bins) self.conformers = nn.ModuleList( [ConformerBlock(config) for _ in range(config.num_encoder_layers)]) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 24f4eb1fc..c4f4a1247 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -47,8 +47,11 @@ def init_model_fn( input_dropout_rate. """ torch.random.manual_seed(rng[0]) - # Disable cudnn benchmark to avoid OOM errors. + # Configure torch backends to avoid OOM errors. torch.backends.cudnn.benchmark = False + torch.backends.cuda.enable_flash_sdp(False) + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.backends.cuda.enable_math_sdp(True) model = conformer_model.ConformerEncoderDecoder( conformer_model.ConformerConfig( attention_residual_dropout_rate=dropout_rate, @@ -57,13 +60,6 @@ def init_model_fn( input_dropout_rate=aux_dropout_rate, use_specaug=self.use_specaug)) self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none') - # Run model once to initialize lazy layers. - # Run the initialization in eval mode to disable BN tracking. 
- model = model.eval() - t = MAX_INPUT_LENGTH - wave = torch.randn((2, t)) - pad = torch.zeros_like(wave) - _ = model(wave, pad) conformer_model.initialize(model) self._param_shapes = param_utils.pytorch_param_shapes(model) self._param_types = param_utils.pytorch_param_types(self._param_shapes) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py index dc7fb912b..2ad355975 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py @@ -19,14 +19,14 @@ def has_reached_validation_target(self, eval_result: Dict[str, @property def validation_target_value(self) -> float: - return 0.084952 + return 0.085884 def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: return eval_result['test/wer'] < self.test_target_value @property def test_target_value(self) -> float: - return 0.053000 + return 0.052981 @property def loss_type(self) -> spec.LossType: diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 4086a5841..ac6005225 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -50,3 +50,20 @@ def init_model_fn( def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' + + @property + def validation_target_value(self) -> float: + return 0.119936 + + @property + def test_target_value(self) -> float: + return 0.074143 + + @property + def step_hint(self) -> int: + """Max num steps the baseline algo was given to reach the target.""" + return 48_000 + + @property + def max_allowed_runtime_sec(self) -> int: + return 55_506 # ~15.4 hours diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 1bb649ba8..bcdd78fb5 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -59,3 +59,20 @@ def init_model_fn( def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] + + @property + def validation_target_value(self) -> float: + return 0.119936 + + @property + def test_target_value(self) -> float: + return 0.074143 + + @property + def step_hint(self) -> int: + """Max num steps the baseline algo was given to reach the target.""" + return 48_000 + + @property + def max_allowed_runtime_sec(self) -> int: + return 55_506 # ~15.4 hours diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py deleted file mode 100644 index f9fd30b0d..000000000 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py +++ /dev/null @@ -1,21 +0,0 @@ -from algorithmic_efficiency.workloads.librispeech_conformer import workload - - -class BaseDeepspeechLibrispeechWorkload(workload.BaseLibrispeechWorkload): - - @property - def validation_target_value(self) -> float: - return 0.118232 - - @property - def test_target_value(self) -> float: - return 0.073397 - - @property - def step_hint(self) -> int: - 
"""Max num steps the baseline algo was given to reach the target.""" - return 48_000 - - @property - def max_allowed_runtime_sec(self) -> int: - return 55_506 # ~15.4 hours diff --git a/datasets/README.md b/datasets/README.md index 5ff0e18a7..586895022 100644 --- a/datasets/README.md +++ b/datasets/README.md @@ -28,7 +28,7 @@ make sure the data directory is mounted to a directory on your host with -v flag. If you are following instructions from the README you will have used the `-v $HOME/data:/data` flag in the `docker run` command. This will mount the `$HOME/data` directory to the `/data` directory in the container. -In this case set --data_dir to `\data`. +In this case set --data_dir to `/data`. ```bash DATA_DIR='/data' ``` diff --git a/reference_algorithms/development_algorithms/README.md b/reference_algorithms/development_algorithms/README.md deleted file mode 100644 index 12b1b1f8e..000000000 --- a/reference_algorithms/development_algorithms/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Development Algorithms - -These are various algorithms used during the testing and development of the codebase. - -These are not valid submissions, because they use a different hyperparameter settings and algorithms per workload. diff --git a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/__init__.py b/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/submission.py b/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/submission.py deleted file mode 100644 index 4dea0c321..000000000 --- a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_jax/submission.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Training algorithm track submission functions for Criteo1TB DLRM-Small.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. 
- del workload_name - return 524_288 // 2 - - -def create_learning_rate_fn(workload: spec.Workload, - hparams: spec.Hyperparameters): - """Create learning rate schedule.""" - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hparams.learning_rate, - transition_steps=hparams.warmup_steps) - cosine_fn = optax.cosine_decay_schedule( - init_value=hparams.learning_rate, - decay_steps=(workload.step_hint - hparams.warmup_steps)) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], boundaries=[hparams.warmup_steps]) - return schedule_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - learning_rate_fn = create_learning_rate_fn(workload, hyperparameters) - opt_init_fn, opt_update_fn = optax.adamw( - learning_rate=learning_rate_fn, - b1=hyperparameters.beta1, - weight_decay=hyperparameters.weight_decay) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=False) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - return loss, new_model_state - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (loss, new_model_state), grad = grad_fn(current_param_container) - (loss, grad) = lax.pmean((loss, grad), axis_name='batch') - grad_norm = jnp.sqrt( - sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_model_state, new_optimizer_state, updated_params, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - # del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - new_model_state, new_optimizer_state, new_params, loss, grad_norm = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs) - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, global_step) - return (new_optimizer_state, opt_update_fn), 
new_params, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/__init__.py b/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/submission.py b/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/submission.py deleted file mode 100644 index d9d9c29b5..000000000 --- a/reference_algorithms/development_algorithms/criteo1tb/criteo1tb_pytorch/submission.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Dict, Iterator, List, Tuple - -import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - batch_sizes = {'criteo1tb': 524_288} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del rng - del model_state - - base_lr = hyperparameters.learning_rate - - optimizer_state = { - 'optimizer': - torch.optim.AdamW( - model_params.parameters(), - lr=base_lr, - weight_decay=hyperparameters.weight_decay, - betas=(hyperparameters.beta1, 0.999)), - } - - scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-12, - end_factor=1., - total_iters=hyperparameters.warmup_steps) - - scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], - T_max=(workload.step_hint - hyperparameters.warmup_steps), - ) - - optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_steps]) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - dropout_rate=None, - aux_dropout_rate=None, - update_batch_norm=False) - - loss_dict = workload.loss_fn( - 
label_batch=batch['targets'], logits_batch=logits_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/criteo1tb/tuning_search_space.json b/reference_algorithms/development_algorithms/criteo1tb/tuning_search_space.json deleted file mode 100644 index a30292bdb..000000000 --- a/reference_algorithms/development_algorithms/criteo1tb/tuning_search_space.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "learning_rate": { - "feasible_points": [ - 0.0065686501947063445 - ] - }, - "beta1": { - "feasible_points": [ - 0.8743797750166902 - ] - }, - "beta2": { - "feasible_points": [ - 0.9980006182116233 - ] - }, - "warmup_steps": { - "feasible_points": [ - 800 - ] - }, - "weight_decay": { - "feasible_points": [ - 1.5301171352729387e-5 - ] - } -} diff --git a/reference_algorithms/development_algorithms/fastmri/__init__.py b/reference_algorithms/development_algorithms/fastmri/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/fastmri/fastmri_jax/__init__.py b/reference_algorithms/development_algorithms/fastmri/fastmri_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/fastmri/fastmri_jax/submission.py b/reference_algorithms/development_algorithms/fastmri/fastmri_jax/submission.py deleted file mode 100644 index 73b020112..000000000 --- a/reference_algorithms/development_algorithms/fastmri/fastmri_jax/submission.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Training algorithm track submission functions for FastMRI in Jax.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. 
- del workload_name - return 64 - - -def create_learning_rate_fn(hparams: spec.Hyperparameters, - steps_per_epoch: int): - """Create learning rate schedule.""" - max_num_train_steps = 500 * steps_per_epoch - decay_epoch_period = hparams.lr_step_size * steps_per_epoch - decay_events = range(decay_epoch_period, - max_num_train_steps, - decay_epoch_period) - schedule_fn = optax.piecewise_constant_schedule( - init_value=hparams.learning_rate, - boundaries_and_scales={t: hparams.lr_gamma for t in decay_events}) - return schedule_fn - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - steps_per_epoch = num_train_examples // get_batch_size('imagenet_resnet') - learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) - opt_init_fn, opt_update_fn = optax.rmsprop( - learning_rate=learning_rate_fn, - decay=0.99) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1, 4)) -def pmapped_train_step(workload, - opt_update_fn, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - logits, _ = workload.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) - weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 - loss = loss + weight_penalty - return loss - - grad_fn = jax.grad(_loss_fn) - grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del model_state - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, optimizer_state, - current_param_container, hyperparameters, batch, per_device_rngs) 
- return (new_optimizer_state, opt_update_fn), new_params, None - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/__init__.py b/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/submission.py b/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/submission.py deleted file mode 100644 index 38828d4c3..000000000 --- a/reference_algorithms/development_algorithms/fastmri/fastmri_pytorch/submission.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Training algorithm track submission functions for FastMRI.""" - -from typing import Dict, Iterator, List, Tuple - -import torch -from torch.optim.lr_scheduler import StepLR - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - batch_sizes = {'fastmri': 8} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - - base_lr = hyperparameters.learning_rate * get_batch_size('fastmri') - optimizer_state = { - 'optimizer': - torch.optim.RMSprop( - model_params.parameters(), - lr=base_lr, - weight_decay=hyperparameters.l2), - } - - optimizer_state['scheduler'] = StepLR( - optimizer_state['optimizer'], - step_size=hyperparameters.lr_step_size, - gamma=hyperparameters.lr_gamma) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del loss_type - del eval_results - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - outputs_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=outputs_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - optimizer_state['optimizer'].step() - steps_per_epoch = workload.num_train_examples // get_batch_size('fastmri') - if (global_step + 1) % 
steps_per_epoch == 0: - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/fastmri/tuning_search_space.json b/reference_algorithms/development_algorithms/fastmri/tuning_search_space.json deleted file mode 100644 index 01e4e00c2..000000000 --- a/reference_algorithms/development_algorithms/fastmri/tuning_search_space.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "learning_rate": {"feasible_points": [0.001]}, - "num_epochs": {"feasible_points": [50]}, - "l2": {"feasible_points": [0.0]}, - "lr_step_size": {"feasible_points": [40]}, - "lr_gamma": {"feasible_points": [0.1]} -} \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/__init__.py b/reference_algorithms/development_algorithms/imagenet_resnet/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/__init__.py b/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/submission.py b/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/submission.py deleted file mode 100644 index 9c686d524..000000000 --- a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_jax/submission.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Training algorithm track submission functions for ImageNet.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def create_learning_rate_fn(hparams: spec.Hyperparameters, - steps_per_epoch: int): - """Create learning rate schedule.""" - base_learning_rate = hparams.learning_rate * \ - get_batch_size('imagenet_resnet') / 256. 
- warmup_fn = optax.linear_schedule( - init_value=0., - end_value=base_learning_rate, - transition_steps=hparams.warmup_epochs * steps_per_epoch) - cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=base_learning_rate, - decay_steps=cosine_epochs * steps_per_epoch) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hparams.warmup_epochs * steps_per_epoch]) - return schedule_fn - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - steps_per_epoch = num_train_examples // get_batch_size('imagenet_resnet') - learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) - opt_init_fn, opt_update_fn = optax.sgd( - nesterov=True, - momentum=hyperparameters.momentum, - learning_rate=learning_rate_fn) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - variables = {'params': params, **model_state} - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - weight_penalty_params = jax.tree_util.tree_leaves(variables['params']) - weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) - weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 - loss = loss + weight_penalty - return loss, (new_model_state, logits) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - aux, grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - new_model_state, _ = aux[1] - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - 
new_model_state, new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, batch, per_device_rngs) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/__init__.py b/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/submission.py b/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/submission.py deleted file mode 100644 index 694e924f7..000000000 --- a/reference_algorithms/development_algorithms/imagenet_resnet/imagenet_pytorch/submission.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Training algorithm track submission functions for ImageNet.""" -from typing import Dict, Iterator, List, Tuple - -import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del rng - - batch_size = get_batch_size('imagenet_resnet') - base_lr = hyperparameters.learning_rate * batch_size / 256. 
- optimizer_state = { - 'optimizer': - torch.optim.SGD( - model_params.parameters(), - lr=base_lr, - momentum=hyperparameters.momentum, - weight_decay=hyperparameters.l2, - nesterov=True), - } - - steps_per_epoch = workload.num_train_examples // batch_size - scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_epochs * steps_per_epoch) - cosine_epochs = max( - hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1) - scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], T_max=cosine_epochs * steps_per_epoch) - - optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_epochs * steps_per_epoch]) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del hyperparameters - del loss_type - del eval_results - del global_step - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=logits_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/imagenet_resnet/tuning_search_space.json b/reference_algorithms/development_algorithms/imagenet_resnet/tuning_search_space.json deleted file mode 100644 index da969416b..000000000 --- a/reference_algorithms/development_algorithms/imagenet_resnet/tuning_search_space.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "learning_rate": {"feasible_points": [0.1]}, - "warmup_epochs": {"feasible_points": [5]}, - "num_epochs": {"feasible_points": [100]}, - "l2": {"feasible_points": [1e-4]}, - "momentum": {"feasible_points": [0.9]} -} \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/imagenet_vit/__init__.py b/reference_algorithms/development_algorithms/imagenet_vit/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/__init__.py b/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/submission.py b/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/submission.py deleted file mode 100644 index 4d65d9675..000000000 --- a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_jax/submission.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Training algorithm track submission functions for ImageNet.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def create_learning_rate_fn(hparams: spec.Hyperparameters, - steps_per_epoch: int): - """Create learning rate schedule.""" - base_learning_rate = hparams.learning_rate * \ - get_batch_size('imagenet_vit') / 1024. 
- warmup_fn = optax.linear_schedule( - init_value=0., - end_value=base_learning_rate, - transition_steps=hparams.warmup_epochs * steps_per_epoch) - cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=base_learning_rate, - decay_steps=cosine_epochs * steps_per_epoch) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hparams.warmup_epochs * steps_per_epoch]) - return schedule_fn - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - steps_per_epoch = num_train_examples // get_batch_size('imagenet_vit') - learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch) - opt_init_fn, opt_update_fn = optax.adam( - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=hyperparameters.epsilon, - learning_rate=learning_rate_fn) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - - def _loss_fn(params): - """loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], logits) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum(jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1) - weight_penalty = hyperparameters.l2 * 0.5 * weight_l2 - loss = loss + weight_penalty - return loss, (new_model_state, logits) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - aux, grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - new_model_state, _ = aux[1] - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - new_model_state, new_optimizer_state, 
new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, batch, per_device_rngs) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/__init__.py b/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/submission.py b/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/submission.py deleted file mode 100644 index eee2a01db..000000000 --- a/reference_algorithms/development_algorithms/imagenet_vit/imagenet_pytorch/submission.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Training algorithm track submission functions for ImageNet.""" -from typing import Dict, Iterator, List, Tuple - -import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del rng - - batch_size = get_batch_size('imagenet_vit') - base_lr = hyperparameters.learning_rate * batch_size / 1024. 
- optimizer_state = { - 'optimizer': - torch.optim.Adam( - model_params.parameters(), - lr=base_lr, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=hyperparameters.epsilon), - } - - steps_per_epoch = workload.num_train_examples // batch_size - scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_epochs * steps_per_epoch) - cosine_epochs = max( - hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1) - scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], T_max=cosine_epochs * steps_per_epoch) - - optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_epochs * steps_per_epoch]) - - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del loss_type - del eval_results - del global_step - - current_model = current_param_container - current_param_container.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn( - label_batch=batch['targets'], logits_batch=logits_batch) - loss = loss_dict['summed'] / loss_dict['n_valid_examples'] - - loss.backward() - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/imagenet_vit/tuning_search_space.json b/reference_algorithms/development_algorithms/imagenet_vit/tuning_search_space.json deleted file mode 100644 index e6cf84733..000000000 --- a/reference_algorithms/development_algorithms/imagenet_vit/tuning_search_space.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "learning_rate": {"feasible_points": [1e-3]}, - "beta1": {"feasible_points": [0.9]}, - "beta2": {"feasible_points": [0.999]}, - "epsilon": {"feasible_points": [1e-8]}, - "num_epochs": {"feasible_points": [100]}, - "warmup_epochs": {"feasible_points": [5]}, - "l2": {"feasible_points": [1e-1]} -} \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/__init__.py b/reference_algorithms/development_algorithms/librispeech_conformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/__init__.py b/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/submission.py b/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/submission.py deleted file mode 100644 index ea314b820..000000000 --- a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_jax/submission.py +++ /dev/null @@ -1,211 +0,0 @@ -"""Training algorithm track submission functions for LibriSpeech.""" -import functools -from typing import Dict, Iterator, List, Tuple - -from absl import logging -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import numpy as np -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 256 - - -def get_learning_rate(step, hyperparams): - warmup_steps = hyperparams.warmup_steps - if step < warmup_steps: - current_lr = (step * hyperparams.base_lr) / warmup_steps - else: - decay_factor = (1 + np.cos(step / hyperparams.training_steps * np.pi)) * 0.5 - current_lr = hyperparams.base_lr * decay_factor - return current_lr - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - opt_init_fn, opt_update_fn = optax.inject_hyperparams(optax.adamw)( - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=hyperparameters.epsilon, - weight_decay=hyperparameters.weight_decay, - learning_rate=0.0) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del rng - - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -def l2_regularization(params, l2_decay_rank_threshold): - """Computes the squared l2 norm of the given parameters. 
- - This function will only filter for parameters with - rank >= l2_decay_rank_threshold. So if this threshold is set to 2, then all - 1d (and lower) parameter arrays, including all bias and batch norm params, - will be ignored in this computation. - - - Args: - params: Pytree containing parameters. - l2_decay_rank_threshold: The calculation will only include parameters with - param.ndim >= l2_decay_rank_threshold. Set to 2 to ignore all bias and - batch_norm params in the model. - - Returns: - weight_l2: the squared l2 norm of all params matching the threshold. - """ - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum( - jnp.sum(x**2) - for x in weight_penalty_params - if x.ndim >= l2_decay_rank_threshold) - return weight_l2 - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0, None), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng, - lr): - optimizer_state.hyperparams['learning_rate'] = lr - - def _loss_fn(params): - """loss function used for training.""" - (logits, logit_paddings), new_model_state = workload.model_fn( - params, - batch, - model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - loss_dict = workload.loss_fn(batch['targets'], (logits, logit_paddings)) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_clip = hyperparameters.grad_clip - grad_norm = jnp.sqrt(l2_regularization(grad, 0)) - scaled_grad = jax.tree_map( - lambda x: x / (grad_norm + _GRAD_CLIP_EPS) * grad_clip, grad) - grad = jax.lax.cond(grad_norm > grad_clip, - lambda _: scaled_grad, - lambda _: grad, - None) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del loss_type - - lr = get_learning_rate(global_step, hyperparameters) - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - per_device_rngs, - lr) - new_model_state, new_optimizer_state, new_params, loss, grad_norm = outputs - - if global_step <= 1000 or global_step % 100 == 
0: - logging.info('%d) loss = %0.3f, grad_norm = %0.3f lr = %0.6f', - global_step, - loss.mean(), - grad_norm.mean(), - lr) - - if workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'train_step_ctc_loss': loss.mean(), - 'grad_norm': grad_norm.mean(), - 'learning_rate': lr, - }, - global_step) - - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/__init__.py b/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/submission.py b/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/submission.py deleted file mode 100644 index ce38d7509..000000000 --- a/reference_algorithms/development_algorithms/librispeech_conformer/librispeech_pytorch/submission.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Training algorithm track submission functions for LibriSpeech.""" -from typing import Dict, Iterator, List, Tuple - -import numpy as np -import torch -import torch.distributed.nn as dist_nn - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - -device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu") -ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none') - - -def get_batch_size(workload_name): - # Return the global batch size. 
- del workload_name - return 256 - - -def get_learning_rate(step, hyperparams): - warmup_steps = hyperparams.warmup_steps - if step < warmup_steps: - current_lr = (step * hyperparams.base_lr) / warmup_steps - else: - decay_factor = (1 + np.cos(step / hyperparams.training_steps * np.pi)) * 0.5 - current_lr = hyperparams.base_lr * decay_factor - return current_lr - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - optimizer = torch.optim.AdamW( - params=model_params.parameters(), - lr=0.0, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=hyperparameters.epsilon, - weight_decay=hyperparameters.weight_decay) - return {'optimizer': optimizer} - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del model_state - del loss_type - optimizer = optimizer_state['optimizer'] - optimizer.zero_grad() - current_model = current_param_container - - (logits, logits_padding), _ = workload.model_fn( - current_model, - batch, - None, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn(batch['targets'], (logits, logits_padding)) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - for g in optimizer.param_groups: - g['lr'] = get_learning_rate(global_step, hyperparameters) - if hasattr(hyperparameters, 'grad_clip'): - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=hyperparameters.grad_clip) - optimizer.step() - return optimizer_state, current_param_container, None - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/librispeech_conformer/tuning_search_space.json b/reference_algorithms/development_algorithms/librispeech_conformer/tuning_search_space.json deleted file mode 100644 index 821288415..000000000 --- a/reference_algorithms/development_algorithms/librispeech_conformer/tuning_search_space.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - "base_lr": {"feasible_points": [0.001997]}, - "beta1": {"feasible_points": [0.7132]}, - "beta2": {"feasible_points": [0.9982]}, - "epsilon": {"feasible_points": [1e-9]}, - "weight_decay": {"feasible_points":[0.026595]}, - "grad_clip": {"feasible_points": [5.0]}, - "warmup_steps" : {"feasible_points": [10000]}, - "training_steps" : {"feasible_points": [100000]} -} - diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/__init__.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/__init__.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/submission.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/submission.py deleted file mode 100644 index f8a368f3f..000000000 --- a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_jax/submission.py +++ /dev/null @@ -1,206 +0,0 @@ -"""Training algorithm track submission functions for LibriSpeech.""" -import functools -from typing import Dict, Iterator, List, Tuple - -from absl import logging -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import numpy as np -import optax - -from algorithmic_efficiency import spec - -_GRAD_CLIP_EPS = 1e-6 - - -def get_batch_size(workload_name): - # Return the global batch size. 
- del workload_name - return 256 - - -def get_learning_rate(step, hyperparams): - warmup_steps = hyperparams.warmup_steps - if step < warmup_steps: - current_lr = (step * hyperparams.base_lr) / warmup_steps - else: - decay_factor = (1 + np.cos(step / hyperparams.training_steps * np.pi)) * 0.5 - current_lr = hyperparams.base_lr * decay_factor - return current_lr - - -def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int): - opt_init_fn, opt_update_fn = optax.inject_hyperparams(optax.adamw)( - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=hyperparameters.epsilon, - weight_decay=hyperparameters.weight_decay, - learning_rate=0.0) - return opt_init_fn, opt_update_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_state - del rng - - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = optimizer(hyperparameters, - workload.num_train_examples) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -def l2_regularization(params, l2_decay_rank_threshold): - """Computes the squared l2 norm of the given parameters. - - This function will only filter for parameters with - rank >= l2_decay_rank_threshold. So if this threshold is set to 2, then all - 1d (and lower) parameter arrays, including all bias and batch norm params, - will be ignored in this computation. - - - Args: - params: Pytree containing parameters. - l2_decay_rank_threshold: The calculation will only include parameters with - param.ndim >= l2_decay_rank_threshold. Set to 2 to ignore all bias and - batch_norm params in the model. - - Returns: - weight_l2: the squared l2 norm of all params matching the threshold. - """ - weight_penalty_params = jax.tree_util.tree_leaves(params) - weight_l2 = sum( - jnp.sum(x**2) - for x in weight_penalty_params - if x.ndim >= l2_decay_rank_threshold) - return weight_l2 - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0, None), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng, - lr): - optimizer_state.hyperparams['learning_rate'] = lr - - def _loss_fn(params): - """loss function used for training.""" - (logits, logit_paddings), new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn(batch['targets'], (logits, logit_paddings)) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean loss and grad. 
- (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - loss = summed_loss / n_valid_examples - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - grad_norm = jnp.sqrt(l2_regularization(grad, 0)) - grad_clip = hyperparameters.grad_clip - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - - grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) - - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params, loss, grad_norm - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del loss_type - - lr = get_learning_rate(global_step, hyperparameters) - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - outputs = pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - per_device_rngs, - lr) - new_model_state, new_optimizer_state, new_params, loss, grad_norm = outputs - - if global_step <= 1000 or global_step % 100 == 0: - logging.info('%d) loss = %0.3f, grad_norm = %0.3f lr = %0.6f', - global_step, - loss.mean(), - grad_norm.mean(), - lr) - if workload.summary_writer is not None: - workload.summary_writer.scalar('train_step_ctc_loss', - loss.mean(), - global_step) - workload.summary_writer.scalar('grad_norm', grad_norm.mean(), global_step) - workload.summary_writer.scalar('learning_rate', lr, global_step) - - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. 
- """ - del optimizer_state - del current_param_container - del global_step - del rng - del hyperparameters - del workload - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/__init__.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/submission.py b/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/submission.py deleted file mode 100644 index 9170086a5..000000000 --- a/reference_algorithms/development_algorithms/librispeech_deepspeech/librispeech_pytorch/submission.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Training algorithm track submission functions for LibriSpeech.""" -from typing import Dict, Iterator, List, Tuple - -import numpy as np -import torch -import torch.distributed.nn as dist_nn - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 256 - - -def get_learning_rate(step, hyperparams): - warmup_steps = hyperparams.warmup_steps - if step < warmup_steps: - current_lr = (step * hyperparams.base_lr) / warmup_steps - else: - decay_factor = (1 + np.cos(step / hyperparams.training_steps * np.pi)) * 0.5 - current_lr = hyperparams.base_lr * decay_factor - return current_lr - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - optimizer = torch.optim.AdamW( - params=model_params.parameters(), - lr=0.0, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=hyperparameters.epsilon, - weight_decay=hyperparameters.weight_decay) - return {'optimizer': optimizer} - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del model_state - del loss_type - optimizer = optimizer_state['optimizer'] - optimizer.zero_grad() - current_model = current_param_container - - (logits, logits_padding), _ = workload.model_fn( - current_model, - batch, - None, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn(batch['targets'], (logits, logits_padding)) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. 
- summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - for g in optimizer.param_groups: - g['lr'] = get_learning_rate(global_step, hyperparameters) - if hasattr(hyperparameters, 'grad_clip'): - torch.nn.utils.clip_grad_norm_( - current_model.parameters(), max_norm=hyperparameters.grad_clip) - optimizer.step() - return optimizer_state, current_param_container, None - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/librispeech_deepspeech/tuning_search_space.json b/reference_algorithms/development_algorithms/librispeech_deepspeech/tuning_search_space.json deleted file mode 100644 index d337200c7..000000000 --- a/reference_algorithms/development_algorithms/librispeech_deepspeech/tuning_search_space.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "base_lr": {"feasible_points": [0.002632520052132928]}, - "beta1": {"feasible_points": [0.9945481149103774]}, - "beta2": {"feasible_points": [0.996379002889742]}, - "epsilon": {"feasible_points": [1e-8]}, - "weight_decay": {"feasible_points":[0.107175616660346]}, - "grad_clip": {"feasible_points": [5.0]}, - "warmup_steps" : {"feasible_points": [3000]}, - "training_steps" : {"feasible_points": [60000]} -} \ No newline at end of file diff --git a/reference_algorithms/development_algorithms/ogbg/__init__.py b/reference_algorithms/development_algorithms/ogbg/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/ogbg/ogbg_jax/__init__.py b/reference_algorithms/development_algorithms/ogbg/ogbg_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/ogbg/ogbg_jax/submission.py b/reference_algorithms/development_algorithms/ogbg/ogbg_jax/submission.py deleted file mode 100644 index 28b512589..000000000 --- a/reference_algorithms/development_algorithms/ogbg/ogbg_jax/submission.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - # Return the global batch size. 
- batch_sizes = {'ogbg': 2048} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates an Adam optimizer.""" - del model_params - del model_state - del rng - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - opt_init_fn, opt_update_fn = opt_init_fn, opt_update_fn = optax.adam( - learning_rate=hyperparameters.learning_rate) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -def train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - hyperparameters, - batch, - rng): - del hyperparameters - - def _loss_fn(params): - logits_batch, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True) - mask_batch = batch['weights'] - loss_dict = workload.loss_fn(batch['targets'], logits_batch, mask_batch) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container) - # Get correct global mean grad. - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - pmapped_train_step = jax.pmap( - train_step, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), - static_broadcasted_argnums=(0, 1)) - dropout_rngs = jax.random.split(rng, jax.local_device_count()) - new_model_state, new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, hyperparameters, batch, dropout_rngs) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/__init__.py b/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/submission.py b/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/submission.py deleted file mode 100644 index 04f4baf9a..000000000 --- a/reference_algorithms/development_algorithms/ogbg/ogbg_pytorch/submission.py +++ /dev/null @@ -1,99 +0,0 @@ -from typing import Dict, Iterator, List, Tuple - -import torch -import torch.distributed.nn as dist_nn - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -def get_batch_size(workload_name): - # Return the global batch size. - batch_sizes = {'ogbg': 32768} - return batch_sizes[workload_name] - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates an Adam optimizer.""" - del workload - del model_state - del rng - optimizer_state = { - 'optimizer': - torch.optim.Adam( - model_params.parameters(), lr=hyperparameters.learning_rate), - } - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del hyperparameters - del loss_type - del eval_results - del global_step - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=True) - - loss_dict = workload.loss_fn(batch['targets'], logits, batch['weights']) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - optimizer_state['optimizer'].step() - - return optimizer_state, current_param_container, new_model_state - - -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. 
- """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/ogbg/tuning_search_space.json b/reference_algorithms/development_algorithms/ogbg/tuning_search_space.json deleted file mode 100644 index d50cc00c5..000000000 --- a/reference_algorithms/development_algorithms/ogbg/tuning_search_space.json +++ /dev/null @@ -1 +0,0 @@ -{"learning_rate": {"feasible_points": [1e-3]}} diff --git a/reference_algorithms/development_algorithms/wmt/__init__.py b/reference_algorithms/development_algorithms/wmt/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/wmt/tuning_search_space.json b/reference_algorithms/development_algorithms/wmt/tuning_search_space.json deleted file mode 100644 index ba3b24f8e..000000000 --- a/reference_algorithms/development_algorithms/wmt/tuning_search_space.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "learning_rate": {"feasible_points": [0.0625]}, - "one_minus_beta_1": {"feasible_points": [0.1]}, - "dropout_rate": {"feasible_points": [0.1]}, - "aux_dropout_rate": {"feasible_points": [0.1]}, - "epsilon": {"feasible_points": [1e-9]} -} - diff --git a/reference_algorithms/development_algorithms/wmt/wmt_jax/__init__.py b/reference_algorithms/development_algorithms/wmt/wmt_jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/wmt/wmt_jax/submission.py b/reference_algorithms/development_algorithms/wmt/wmt_jax/submission.py deleted file mode 100644 index 9ef1580b2..000000000 --- a/reference_algorithms/development_algorithms/wmt/wmt_jax/submission.py +++ /dev/null @@ -1,194 +0,0 @@ -"""Training algorithm track submission functions for WMT.""" - -import functools -from typing import Dict, Iterator, List, Tuple - -from flax import jax_utils -import jax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec - - -def get_batch_size(workload_name): - batch_sizes = {'wmt': 128} - return batch_sizes[workload_name] - - -def create_learning_rate_scheduler( - factors='constant * linear_warmup * rsqrt_decay', - base_learning_rate=0.5, - warmup_steps=1000, - decay_factor=0.5, - steps_per_decay=20000, - steps_per_cycle=100000): - """Creates learning rate schedule. - - Interprets factors in the factors string which can consist of: - * constant: interpreted as the constant value, - * linear_warmup: interpreted as linear warmup until warmup_steps, - * rsqrt_decay: divide by square root of max(step, warmup_steps) - * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1) - * decay_every: Every k steps decay the learning rate by decay_factor. - * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter. - - Args: - factors: string, factors separated by "*" that defines the schedule. - base_learning_rate: float, the starting constant for the lr schedule. - warmup_steps: int, how many steps to warm up for in the warmup schedule. - decay_factor: float, the amount to decay the learning rate by. - steps_per_decay: int, how often to decay the learning rate. - steps_per_cycle: int, steps per cycle when using cosine decay. - - Returns: - a function learning_rate(step): float -> {"learning_rate": float}, the - step-dependent lr. 
- """ - factors = [n.strip() for n in factors.split('*')] - - def step_fn(step): - """Step to learning rate function.""" - ret = 1.0 - for name in factors: - if name == 'constant': - ret *= base_learning_rate - elif name == 'linear_warmup': - ret *= jnp.minimum(1.0, step / warmup_steps) - elif name == 'rsqrt_decay': - ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) - elif name == 'rsqrt_normalized_decay': - ret *= jnp.sqrt(warmup_steps) - ret /= jnp.sqrt(jnp.maximum(step, warmup_steps)) - elif name == 'decay_every': - ret *= (decay_factor**(step // steps_per_decay)) - elif name == 'cosine_decay': - progress = jnp.maximum(0.0, - (step - warmup_steps) / float(steps_per_cycle)) - ret *= jnp.maximum(0.0, - 0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0)))) - else: - raise ValueError(f'Unknown factor {name}.') - return jnp.asarray(ret, dtype=jnp.float32) - - return step_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del model_params - del model_state - del rng - learning_rate_fn = create_learning_rate_scheduler( - base_learning_rate=hyperparameters.learning_rate, warmup_steps=1000) - opt_init_fn, opt_update_fn = optax.adam( - b1=1.0 - hyperparameters.one_minus_beta_1, - b2=0.98, - eps=hyperparameters.epsilon, - learning_rate=learning_rate_fn) - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - in_axes=(None, None, 0, 0, 0, 0, None), - axis_name='batch', - static_broadcasted_argnums=(0, 1, 6)) -def pmapped_train_step(workload, - opt_update_fn, - optimizer_state, - current_param_container, - batch, - dropout_rng, - hyperparameters): - """Perform a single training step.""" - del hyperparameters - - def _loss_fn(params): - """Loss function used for training.""" - logits, _ = workload.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.TRAIN, - rng=dropout_rng, - update_batch_norm=False) - targets = batch['targets'] - weights = batch['weights'] - loss_dict = workload.loss_fn(targets, logits, weights, label_smoothing=0.1) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, n_valid_examples - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, n_valid_examples), grad = grad_fn(current_param_container) - # Get correct global mean loss and grad. 
- (summed_loss, n_valid_examples, grad) = jax.lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') - grad = jax.tree_map(lambda x: x / n_valid_examples, grad) - - updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del eval_results - del global_step - del model_state - del loss_type - - optimizer_state, opt_update_fn = optimizer_state - dropout_rngs = jax.random.split(rng, jax.local_device_count()) - new_optimizer_state, updated_params = pmapped_train_step( - workload, - opt_update_fn, - optimizer_state, - current_param_container, - batch, - dropout_rngs, - hyperparameters) - return (new_optimizer_state, opt_update_fn), updated_params, None - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/development_algorithms/wmt/wmt_pytorch/__init__.py b/reference_algorithms/development_algorithms/wmt/wmt_pytorch/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/reference_algorithms/development_algorithms/wmt/wmt_pytorch/submission.py b/reference_algorithms/development_algorithms/wmt/wmt_pytorch/submission.py deleted file mode 100644 index 2df681273..000000000 --- a/reference_algorithms/development_algorithms/wmt/wmt_pytorch/submission.py +++ /dev/null @@ -1,167 +0,0 @@ -from typing import Dict, Iterator, List, Tuple - -import numpy as np -import torch -import torch.distributed.nn as dist_nn - -from algorithmic_efficiency import spec -from algorithmic_efficiency.pytorch_utils import pytorch_setup - -USE_PYTORCH_DDP = pytorch_setup()[0] - - -def get_batch_size(workload_name): - batch_sizes = {'wmt': 128} - return batch_sizes[workload_name] - - -def create_learning_rate_scheduler( - factors='constant * linear_warmup * rsqrt_decay', - base_learning_rate=0.5, - warmup_steps=1000, - decay_factor=0.5, - steps_per_decay=20000, - steps_per_cycle=100000): - """Creates learning rate schedule. 
- Interprets factors in the factors string which can consist of: - * constant: interpreted as the constant value, - * linear_warmup: interpreted as linear warmup until warmup_steps, - * rsqrt_decay: divide by square root of max(step, warmup_steps) - * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1) - * decay_every: Every k steps decay the learning rate by decay_factor. - * cosine_decay: Cyclic cosine decay, uses steps_per_cycle parameter. - Args: - factors: string, factors separated by "*" that defines the schedule. - base_learning_rate: float, the starting constant for the lr schedule. - warmup_steps: int, how many steps to warm up for in the warmup schedule. - decay_factor: float, the amount to decay the learning rate by. - steps_per_decay: int, how often to decay the learning rate. - steps_per_cycle: int, steps per cycle when using cosine decay. - Returns: - a function learning_rate(step): float -> {"learning_rate": float}, the - step-dependent lr. - """ - factors = [n.strip() for n in factors.split('*')] - - def step_fn(step): - """Step to learning rate function.""" - ret = 1.0 - for name in factors: - if name == 'constant': - ret *= base_learning_rate - elif name == 'linear_warmup': - ret *= np.minimum(1.0, step / warmup_steps) - elif name == 'rsqrt_decay': - ret /= np.sqrt(np.maximum(step, warmup_steps)) - elif name == 'rsqrt_normalized_decay': - ret *= np.sqrt(warmup_steps) - ret /= np.sqrt(np.maximum(step, warmup_steps)) - elif name == 'decay_every': - ret *= (decay_factor**(step // steps_per_decay)) - elif name == 'cosine_decay': - progress = np.maximum(0.0, - (step - warmup_steps) / float(steps_per_cycle)) - ret *= np.maximum(0.0, 0.5 * (1.0 + np.cos(np.pi * (progress % 1.0)))) - else: - raise ValueError(f'Unknown factor {name}.') - return ret - - return step_fn - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - del workload - del model_state - del rng - - optimizer_state = { - 'optimizer': - torch.optim.Adam( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(1.0 - hyperparameters.one_minus_beta_1, 0.98), - eps=hyperparameters.epsilon), - } - - optimizer_state['scheduler'] = create_learning_rate_scheduler( - base_learning_rate=hyperparameters.learning_rate) - return optimizer_state - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params).""" - del current_params_types - del hyperparameters - del loss_type - del eval_results - - current_model = current_param_container - current_param_container.train() - optimizer = optimizer_state['optimizer'] - optimizer.zero_grad() - - logits, _ = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - update_batch_norm=False) - - targets = batch['targets'] - weights = batch['weights'] - loss_dict = workload.loss_fn(targets, logits, weights, label_smoothing=0.1) - summed_loss = 
loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - if USE_PYTORCH_DDP: - # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. - summed_loss = dist_nn.all_reduce(summed_loss) - n_valid_examples = dist_nn.all_reduce(n_valid_examples) - loss = summed_loss / n_valid_examples - - loss.backward() - - lr = optimizer_state['scheduler'](global_step).item() - for g in optimizer.param_groups: - g['lr'] = lr - optimizer.step() - - return (optimizer_state, current_param_container, None) - - -# Not allowed to update the model parameters, hyperparameters, global step, or -# optimzier state. -def data_selection(workload: spec.Workload, - input_queue: Iterator[Dict[str, spec.Tensor]], - optimizer_state: spec.OptimizerState, - current_param_container: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - global_step: int, - rng: spec.RandomState) -> Dict[str, spec.Tensor]: - """Select data from the infinitely repeating, pre-shuffled input queue. - - Each element of the queue is a batch of training examples and labels. - """ - del workload - del optimizer_state - del current_param_container - del model_state - del hyperparameters - del global_step - del rng - return next(input_queue) diff --git a/reference_algorithms/target_setting_algorithms/README.md b/reference_algorithms/target_setting_algorithms/README.md index 117bbed3c..822907ba3 100644 --- a/reference_algorithms/target_setting_algorithms/README.md +++ b/reference_algorithms/target_setting_algorithms/README.md @@ -113,7 +113,7 @@ python3 submission_runner.py \ --experiment_dir=$ROOT_DIR \ --experiment_name=target_setting \ --workload=librispeech_conformer \ - --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py \ + --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py \ --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json ``` ```bash @@ -123,7 +123,7 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc --experiment_dir=$ROOT_DIR \ --experiment_name=target_setting \ --workload=librispeech_conformer \ - --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py \ + --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py \ --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json ``` diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json index 13bf07b4b..482a28931 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json @@ -1,27 +1,27 @@ { "learning_rate": { "feasible_points": [ - 0.001308209823469072 + 0.002106913873888147 ] }, "beta1": { "feasible_points": [ - 0.9731333693827139 + 0.8231189937738506 ] }, "beta2": { "feasible_points": [ - 0.9981232922116359 + 0.8774571227688758 ] }, "warmup_steps": { "feasible_points": [ - 9999 + 1199 ] }, "weight_decay": { "feasible_points": [ - 0.16375311233774334 + 0.27590534177690645 ] } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json 
b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json index 106e124a0..0a9bfb3cf 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json @@ -16,7 +16,7 @@ }, "warmup_steps": { "feasible_points": [ - 1200 + 720 ] }, "weight_decay": { diff --git a/scoring/score_submission.py b/scoring/score_submission.py index 42a605dac..e8a6ac010 100644 --- a/scoring/score_submission.py +++ b/scoring/score_submission.py @@ -5,8 +5,7 @@ from absl import logging import scoring_utils -from algorithmic_efficiency import workloads -import scoring +from scoring import scoring flags.DEFINE_string( 'experiment_path', diff --git a/scoring/scoring.py b/scoring/scoring.py index fff152255..dba254233 100644 --- a/scoring/scoring.py +++ b/scoring/scoring.py @@ -40,6 +40,12 @@ WORKLOADS = workloads_registry.WORKLOADS WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)' BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/' +# These global variables have to be set according to the current set of +# workloads and rules for the scoring to be correct. +# We do not use the workload registry since it contains test and development +# workloads as well. +NUM_WORKLOADS = 8 +NUM_TRIALS = 5 MIN_EVAL_METRICS = [ 'ce_loss', @@ -47,13 +53,14 @@ 'ctc_loss', 'wer', 'l1_loss', + 'loss', ] -MAX_EVAL_METRICS = ['average_precision', 'ssim', 'accuracy', 'bleu_score'] +MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu'] def generate_eval_cols(metrics): - splits = ['train', 'validation', 'test'] + splits = ['train', 'validation'] return [f'{split}/{col}' for split, col in itertools.product(splits, metrics)] @@ -108,15 +115,13 @@ def get_index_that_reaches_best(workload_df, metric_col): def get_index_that_reaches_target(workload_df, validation_metric, - test_metric, - validation_target, - test_target): + validation_target): """Get the eval index in which a workload reaches the target metric_col. Args: workload_df: A subset of a submission's trials DataFrame that includes only the trials in a single workload. - metric_col: Name of array column in workload_df (e.g., `validation/l1_loss`). + metric_col: Name of array column in workload_df (e.g. `validation/l1_loss`). target: Target value for metric_col. Returns: @@ -125,26 +130,19 @@ def get_index_that_reaches_target(workload_df, """ is_minimized = check_if_minimized(validation_metric) validation_series = workload_df[validation_metric] - test_series = workload_df[test_metric] - validation_series = validation_series[validation_series != np.nan] - validation_series = validation_series[test_series != np.nan] - test_series = test_series[validation_series != np.nan] - test_series = test_series[test_series != np.nan] op = operator.le if is_minimized else operator.ge validation_target_reached = validation_series.apply( lambda x: op(x, validation_target)) - test_target_reached = test_series.apply(lambda x: op(x, test_target)) - - target_reached = pd.Series(validation_target_reached[0] - & test_target_reached[0]) + target_reached = pd.Series(validation_target_reached) # Remove trials that never reach the target target_reached = target_reached[target_reached.apply(np.any)] - # If we have no trials that have reached the target, return -1. Else, return - # the eval index of the earliest point the target is reached. 
- if target_reached.empty: + # If less than 3 trials reach the target, the submission will be scored as + # missing the target on this workload; return -1. Else, return the eval index + # of the earliest point the target is reached. + if len(target_reached) < 3: return -1, -1 else: index_reached = target_reached.apply(np.argmax) @@ -188,12 +186,10 @@ def get_times_for_submission(submission, workload_init_kwargs=workload_init_kwargs) metric_name = workload_obj.target_metric_name validation_metric = f'validation/{metric_name}' - test_metric = f'test/{metric_name}' validation_target = workload_obj.validation_target_value - test_target = workload_obj.test_target_value trial_idx, time_idx = get_index_that_reaches_target( - group, validation_metric, test_metric, validation_target, test_target) + group, validation_metric, validation_target) if time_idx > -1: time_val = group[time_col].loc[trial_idx][time_idx] else: @@ -298,7 +294,7 @@ def compute_performance_profiles(results, np.log10(min_tau), np.log10(max_tau), num=num_points, base=10.0) def rho(r, tau): - return (r <= tau).sum(axis=1) / len(r.columns) + return (r <= tau).sum(axis=1) / NUM_WORKLOADS perf_df = pd.concat([rho(df, tau) for tau in points], axis=1) diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py index 37db73dd4..1a15db2f5 100644 --- a/scoring/scoring_utils.py +++ b/scoring/scoring_utils.py @@ -1,10 +1,14 @@ import json import os import re +import warnings from absl import logging import pandas as pd +from scoring.scoring import NUM_TRIALS +from scoring.scoring import NUM_WORKLOADS + TRIAL_LINE_REGEX = '(.*) --- Tuning run (\d+)/(\d+) ---' METRICS_LINE_REGEX = '(.*) Metrics: ({.*})' TRIAL_DIR_REGEX = 'trial_(\d+)' @@ -103,8 +107,7 @@ def get_trials_df_dict(logfile): """ trials_dict = get_trials_dict(logfile) trials_df_dict = {} - for trial in trials_dict.keys(): - metrics = trials_dict[trial] + for trial, metrics in trials_dict.items(): trials_df_dict[trial] = pd.DataFrame(metrics) return trials_df_dict @@ -156,6 +159,10 @@ def get_experiment_df(experiment_dir): """ df = pd.DataFrame() workload_dirs = os.listdir(experiment_dir) + num_workloads = len(workload_dirs) + if num_workloads != NUM_WORKLOADS: + warnings.warn(f'There should be {NUM_WORKLOADS} workloads but there are ' + f'{num_workloads}.') for workload in workload_dirs: data = { 'workload': workload, @@ -164,6 +171,7 @@ def get_experiment_df(experiment_dir): t for t in os.listdir(os.path.join(experiment_dir, workload)) if re.match(TRIAL_DIR_REGEX, t) ] + workload_df = pd.DataFrame() for trial in trial_dirs: eval_measurements_filepath = os.path.join( experiment_dir, @@ -173,7 +181,7 @@ def get_experiment_df(experiment_dir): ) try: trial_df = pd.read_csv(eval_measurements_filepath) - except FileNotFoundError as e: + except FileNotFoundError: logging.info(f'Could not read {eval_measurements_filepath}') continue data['trial'] = trial @@ -181,5 +189,10 @@ def get_experiment_df(experiment_dir): values = trial_df[column].to_numpy() data[column] = values trial_df = pd.DataFrame([data]) - df = pd.concat([df, trial_df], ignore_index=True) + workload_df = pd.concat([workload_df, trial_df], ignore_index=True) + num_trials = len(workload_df) + if num_trials != NUM_TRIALS: + warnings.warn(f'There should be {NUM_TRIALS} trials for workload ' + f'{workload} but there are only {num_trials}.') + df = pd.concat([df, workload_df], ignore_index=True) return df diff --git a/scoring/test_scoring_utils.py b/scoring/test_scoring_utils.py index b766a04d7..fbb21958c 100644 --- 
a/scoring/test_scoring_utils.py +++ b/scoring/test_scoring_utils.py @@ -1,8 +1,11 @@ from absl.testing import absltest -import scoring_utils -TEST_LOGFILE = 'test_data/adamw_fastmri_jax_04-18-2023-13-10-58.log' -TEST_DIR = 'test_data/experiment_dir' +from scoring import scoring_utils +from scoring.scoring import NUM_TRIALS +from scoring.scoring import NUM_WORKLOADS + +TEST_LOGFILE = 'scoring/test_data/adamw_fastmri_jax_04-18-2023-13-10-58.log' +TEST_DIR = 'scoring/test_data/experiment_dir' NUM_EVALS = 18 @@ -14,8 +17,7 @@ def test_get_trials_dict(self): def test_get_trials_df_dict(self): trials_dict = scoring_utils.get_trials_df_dict(TEST_LOGFILE) - for trial in trials_dict: - df = trials_dict[trial] + for df in trials_dict.values(): self.assertEqual(len(df.index), NUM_EVALS) def test_get_trials_df(self): @@ -24,7 +26,18 @@ def test_get_trials_df(self): self.assertEqual(len(df.at['1', column]), NUM_EVALS) def test_get_experiment_df(self): - df = scoring_utils.get_experiment_df(TEST_DIR) + _ = scoring_utils.get_experiment_df(TEST_DIR) + self.assertWarnsRegex( + Warning, + f'There should be {NUM_WORKLOADS} workloads but there are 1.', + scoring_utils.get_experiment_df, + TEST_DIR) + self.assertWarnsRegex( + Warning, + f'There should be {NUM_TRIALS} trials for workload mnist_jax but there ' + 'are only 1.', + scoring_utils.get_experiment_df, + TEST_DIR) if __name__ == '__main__': diff --git a/submission_runner.py b/submission_runner.py index 717ea2dc4..d92732145 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -28,10 +28,16 @@ from absl import flags from absl import logging import jax -import tensorflow as tf import torch import torch.distributed as dist +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings. +import tensorflow as tf + +# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make +# it unavailable to JAX. +tf.config.set_visible_devices([], 'GPU') + from algorithmic_efficiency import checkpoint_utils from algorithmic_efficiency import halton from algorithmic_efficiency import logger_utils @@ -44,10 +50,6 @@ from algorithmic_efficiency.pytorch_utils import sync_ddp_time from algorithmic_efficiency.workloads import workloads -# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make -# it unavailable to JAX. -tf.config.set_visible_devices([], 'GPU') - # disable only for deepspeech if it works fine for other workloads. os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false' @@ -147,6 +149,9 @@ None, 'Value of rng seed. 
If None, a random seed will' 'be generated from hardware.') +flags.DEFINE_boolean('set_pytorch_max_split_size', + False, + 'If true, set pytorch max_split_size_mb to 256') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -215,10 +220,8 @@ def train_once( model_params, model_state = workload.init_model_fn( model_init_rng, dropout_rate, aux_dropout_rate) if FLAGS.framework == 'pytorch' and FLAGS.torch_compile: - compile_error_workloads = ['ogbg', 'criteo1tb'] - eager_backend_workloads = [ - 'librispeech_conformer', 'librispeech_deepspeech' - ] + compile_error_workloads = ['librispeech_conformer', 'ogbg', 'criteo1tb'] + eager_backend_workloads = ['librispeech_deepspeech'] aot_eager_backend_workloads = [] if FLAGS.workload in compile_error_workloads: logging.warning( @@ -237,7 +240,6 @@ def train_once( else: logging.info('Performing `torch.compile`.') model_params = torch.compile(model_params) - logging.info('Initializing optimizer.') with profiler.profile('Initializing optimizer'): optimizer_state = init_optimizer_state(workload, @@ -284,7 +286,8 @@ def train_once( checkpoint_dir=log_dir) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - logger_utils.save_meta_data(workload, rng_seed, preemption_count) + meta_data = logger_utils.get_meta_data(workload, rng_seed) + logger_utils.write_json(meta_file_name, meta_data) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict()) @@ -340,9 +343,12 @@ def train_once( train_state['accumulated_submission_time'] += ( train_step_end_time - train_state['last_step_end_time']) + # Use 3x the runtime budget for the self-tuning ruleset. + max_allowed_runtime_sec = ( + workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' + else 3 * workload.max_allowed_runtime_sec) train_state['is_time_remaining'] = ( - train_state['accumulated_submission_time'] < - workload.max_allowed_runtime_sec) + train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. 
if ((train_step_end_time - train_state['last_eval_time']) >= workload.eval_period_time_sec or train_state['training_complete']): @@ -557,16 +563,18 @@ def score_submission_on_workload(workload: spec.Workload, save_checkpoints=save_checkpoints,) all_timings.append(timing) all_metrics.append(metrics) - score = min(all_timings) - for ti, _ in tuning_search_space_iter: - logging.info(f'Tuning trial {ti + 1}/{num_tuning_trials}') - logging.info(f'Hyperparameters: {tuning_search_space[ti]}') - logging.info(f'Metrics: {all_metrics[ti]}') - logging.info(f'Timing: {all_timings[ti]}') - num_evals = len(all_metrics[ti]['eval_results']) + logging.info(f'Tuning trial {hi + 1}/{num_tuning_trials}') + logging.info(f'Hyperparameters: {tuning_search_space[hi]}') + logging.info(f'Metrics: {all_metrics[hi]}') + logging.info(f'Timing: {all_timings[hi]}') + num_evals = len(all_metrics[hi]['eval_results']) logging.info(f'Total number of evals: {num_evals}') logging.info('=' * 20) + score = min(all_timings) else: + if tuning_search_space is not None: + raise ValueError( + 'Cannot provide a tuning search space when using self tuning.') if not rng_seed: rng_seed = struct.unpack('q', os.urandom(8))[0] rng = prng.PRNGKey(rng_seed) @@ -597,6 +605,9 @@ def main(_): if FLAGS.workload == 'librispeech_conformer': os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.85' + if FLAGS.set_pytorch_max_split_size: + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:256' + # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( BASE_WORKLOADS_DIR, diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py index 1d243d83e..d414001dd 100644 --- a/tests/modeldiffs/librispeech_conformer/compare.py +++ b/tests/modeldiffs/librispeech_conformer/compare.py @@ -38,11 +38,15 @@ def sd_transform(sd): out = {} for k in sd: if 'Attention' in ''.join(k): - if 'in_proj' in k[-1]: - new_key = k[:-1] + if 'Dense_0' in k[-2]: + # In-proj + new_key = k[:-2] chunks = sd[k].chunk(3) for t, c in zip(['query', 'key', 'value'], chunks): - out[new_key + (t, k[-1].split('_')[-1])] = c + out[new_key + (t, k[-1])] = c + elif 'Dense_1' in k[-2]: + # Out-proj + out[(*k[:-2], 'out', k[-1])] = sd[k] else: out[k] = sd[k] else: diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index ae834f1f4..5c43b233b 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -9,8 +9,8 @@ Assumes that each reference submission is using the external tuning ruleset and that it is defined in: # pylint: disable=line-too-long -"reference_algorithms/development_algorithms/{workload}/{workload}_{framework}/submission.py" -"reference_algorithms/development_algorithms/{workload}/tuning_search_space.json". +"reference_algorithms/target_setting_algorithms/{workload}/{workload}_{framework}/submission.py" +"reference_algorithms/target_setting_algorithms/{workload}/tuning_search_space.json". 
python3 tests/reference_algorithm_tests.py \ --workload=criteo1tb \ @@ -19,6 +19,7 @@ --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py \ --tuning_search_space=reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json """ + import copy import functools import importlib @@ -499,10 +500,10 @@ def _make_paths(repo_location, framework, workload_name): else: dataset_name = workload_name workload_dir = ( - f'{repo_location}/reference_algorithms/development_algorithms/' + f'{repo_location}/reference_algorithms/target_setting_algorithms/' f'{workload_name}') search_space_path = f'{workload_dir}/tuning_search_space.json' - submission_path = (f'reference_algorithms/development_algorithms/' + submission_path = (f'reference_algorithms/target_setting_algorithms/' f'{workload_name}/{dataset_name}_{framework}/' 'submission.py') full_submission_path = f'{repo_location}/{submission_path}' @@ -534,7 +535,7 @@ def test_submission(self): if FLAGS.tuning_search_space: raise ValueError('Cannot set --tuning_search_space and --all.') references_dir = ( - f'{repo_location}/reference_algorithms/development_algorithms') + f'{repo_location}/reference_algorithms/target_setting_algorithms') for workload_name in os.listdir(references_dir): for framework in ['jax', 'pytorch']: if framework == 'pytorch':