Skip to content

Commit

Permalink
Merge branch 'dev' into UpdateTuningRules
Browse files Browse the repository at this point in the history
  • Loading branch information
fsschneider authored Nov 6, 2023
2 parents f155287 + 0943802 commit d31489e
Show file tree
Hide file tree
Showing 75 changed files with 336 additions and 2,744 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Change log

## TODO: algorithmic-efficiency 0.1.0

First release of AlgoPerf benchmarking code.
7 changes: 3 additions & 4 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,10 +253,9 @@ pylint submission_runner.py
pylint tests
```

### Unit and integration tests

We run unit tests and integration tests as part of the of github actions as well.
You can also use `python tests/reference_algorithm_tests.py` to run a single model update and two model evals for each workload using the reference algorithm in `reference_algorithms/development_algorithms/`.
## Unit and integration tests
We run unit tests and integration tests as part of the of github actions as well.
You can also use `python tests/reference_algorithm_tests.py` to run a single model update and two model evals for each workload using the reference algorithm in `reference_algorithms/target_setting_algorithms/`.

### Regression tests

Expand Down
85 changes: 80 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,19 @@
- [Pytorch](#pytorch)
- [Rules](#rules)
- [Contributing](#contributing)
- [Note on shared data pipelines between JAX and PyTorch](#note-on-shared-data-pipelines-between-jax-and-pytorch)
- [Shared data pipelines between JAX and PyTorch](#shared-data-pipelines-between-jax-and-pytorch)
- [Setup and Platform](#setup-and-platform)
- [My machine only has one GPU. How can I use this repo?](#my-machine-only-has-one-gpu-how-can-i-use-this-repo)
- [How do I run this on my SLURM cluster?](#how-do-i-run-this-on-my-slurm-cluster)
- [How can I run this on my AWS/GCP/Azure cloud project?](#how-can-i-run-this-on-my-awsgcpazure-cloud-project)
- [Submissions](#submissions)
- [Can submission be structured using multiple files?](#can-submission-be-structured-using-multiple-files)
- [Can I install custom dependencies?](#can-i-install-custom-dependencies)
- [How can I know if my code can be run on benchmarking hardware?](#how-can-i-know-if-my-code-can-be-run-on-benchmarking-hardware)
- [Are we allowed to use our own hardware to self-report the results?](#are-we-allowed-to-use-our-own-hardware-to-self-report-the-results)




## Installation

Expand Down Expand Up @@ -146,9 +158,15 @@ To use the Docker container as an interactive virtual environment, you can run a
<docker_image_name> \
--keep_container_alive true
```

2. Open a bash terminal

Note: You may have to use double quotes around `algorithmic-efficiency` [path] in the mounting `-v` flag. If the above command fails try replacing the following line:
```bash
-v $HOME/algorithmic-efficiency:/algorithmic-efficiency2 \
```
with
```
-v $HOME"/algorithmic-efficiency:/algorithmic-efficiency" \
```
- Open a bash terminal
```bash
docker exec -it <container_id> /bin/bash
```
Expand Down Expand Up @@ -258,9 +276,66 @@ The rules for the MLCommons Algorithmic Efficency benchmark can be found in the

If you are interested in contributing to the work of the working group, feel free to [join the weekly meetings](https://mlcommons.org/en/groups/research-algorithms/), open issues. See our [CONTRIBUTING.md](CONTRIBUTING.md) for MLCommons contributing guidelines and setup and workflow instructions.

## Note on shared data pipelines between JAX and PyTorch

# Disclaimers

## Shared data pipelines between JAX and PyTorch

The JAX and PyTorch versions of the Criteo, FastMRI, Librispeech, OGBG, and WMT workloads are using the same TensorFlow input pipelines. Due to differences in how Jax and PyTorch distribute computations across devices, the PyTorch workloads have an additional overhead for these workloads.

Since we use PyTorch's [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implementation, there is one Python process for each device. Depending on the hardware and the settings of the cluster, running a TensorFlow input pipeline in each Python process can lead to errors, since too many threads are created in each process. See [this PR thread](https://github.com/mlcommons/algorithmic-efficiency/pull/85) for more details.
While this issue might not affect all setups, we currently implement a different strategy: we only run the TensorFlow input pipeline in one Python process (with `rank == 0`), and [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) the batches to all other devices. This introduces an additional communication overhead for each batch. See the [implementation for the WMT workload](https://github.com/mlcommons/algorithmic-efficiency/blob/main/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py#L215-L288) as an example.

## Pytorch Conformer CUDA OOM

The conformer pytorch workload may run out of memory in current state. Please set the `submission_runner.py` flag `reduce_pytorch_max_split_size` to `True` as a temporary workaround if you encounter this issue. This will set 'max_split_size_mb:256'. Note that this will adversely impact the performance of the submission on this workload. See [tracking issue](https://github.com/mlcommons/algorithmic-efficiency/issues/497).


# FAQS

## Setup and Platform

### My machine only has one GPU. How can I use this repo?
You can run this repo on a machine with an arbitrary number of GPUs. However, the default batch sizes in our reference algorithms `algorithmic-efficiency/baselines` and `algorithmic-efficiency/reference_algorithms` are tuned for a machine with 8 16GB V100 GPUs. You may run into OOMs if you run these algorithms with fewer than 8 GPUs. If you run into these issues because you are using a machine with less total GPU memory, please reduce the batch sizes for the submission. Note that your final submission must 'fit'
on the benchmarking hardware, so if you are using fewer
GPUs with higher per GPU memory, please monitor your memory usage
to make make sure it will fit on 8xV100 GPUs with 16GB of VRAM per card.

### How do I run this on my SLURM cluster?
You may run into issues with `sudo` and `docker` on a SLURM cluster. To run the workloads in a SLURM cluster you can use Apptainer (previously Singularity), see this [section](using-singularity/apptainer-instead-of-docker).
### How can I run this on my AWS/GCP/Azure cloud project?
Depending on your virtual machine, you may have to install the correct GPU drivers and the NVIDIA Docker toolkit. For example, in GCP you will have to do the following.
1. If you don't have a VM instance yet, we recommend creating a
new Compute Instance with the "Deep Learning on Linux" Image in Boot disk options.
2. To install the NVIDIA Docker toolkit, you can use `scripts/cloud-startup.sh` as a startup script for the VM. This will automate the installation of the NVIDIA GPU Drivers and NVIDIA Docker toolkit.

## Submissions
### Can submission be structured using multiple files?
Yes, your submission can be structured using multiple files.
### Can I install custom dependencies?
You may use custom dependencies as long as they do not conflict with any of the pinned packages in `algorithmic-efficiency/setup.cfg`.
To include your custom dependencies in your submission, please include them in a requirements.txt file. Please refer to the [Software dependencies](https://github.com/mlcommons/algorithmic-efficiency/blob/main/RULES.md#software-dependencies) section of our rules.
### How can I know if my code can be run on benchmarking hardware?
The benchmarking hardware specifications are documented in the [Getting Started Document](./getting_started.md).
We recommend monitoring your submission's memory usage so that it does not exceed the available memory
on the competition hardware. We also recommend to do a dry run using a cloud instance.
### Are we allowed to use our own hardware to self-report the results?
You only have to use the competition hardware for runs that are directly involved in the scoring procedure. This includes all runs for the self-tuning ruleset, but only the runs of the best hyperparameter configuration in each study for the external tuning ruleset. For example, you could use your own (different) hardware to tune your submission and identify the best hyperparameter configuration (in each study) and then only run this configuration (i.e. 5 runs, one for each study) on the competition hardware.

# Citing AlgoPerf Benchmark
If you use the **AlgoPerf** Benchmark in your work, please consider citing:

> [George E. Dahl, Frank Schneider, Zachary Nado, et al.<br/>
> **Benchmarking Neural Network Training Algorithms**<br/>
> *arXiv 2306.07179*](http://arxiv.org/abs/2306.07179)
```bibtex
@misc{dahl2023algoperf,
title={{Benchmarking Neural Network Training Algorithms}},
author={Dahl, George E. and Schneider, Frank and Nado, Zachary and Agarwal, Naman and Sastry, Chandramouli Shama and Hennig, Philipp and Medapati, Sourabh and Eschenhagen, Runa and Kasimbeg, Priya and Suo, Daniel and Bae, Juhan and Gilmer, Justin and Peirson, Abel L. and Khan, Bilal and Anil, Rohan and Rabbat, Mike and Krishnan, Shankar and Snider, Daniel and Amid, Ehsan and Chen, Kongtao and Maddison, Chris J. and Vasudev, Rakshith and Badura, Michal and Garg, Ankush and Mattson, Peter},
year={2023},
eprint={2306.07179},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
27 changes: 13 additions & 14 deletions SUBMISSION_PROCESS_RULES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MLCommons™ AlgoPerf: Submission Process Rules

**Version:** 0.0.2 *(Last updated 03 Oktober 2023)*
**Version:** 0.0.3 *(Last updated 10 Oktober 2023)*

## Table of Contents <!-- omit from toc -->

Expand Down Expand Up @@ -39,8 +39,8 @@ Three additional documents complement this document:

### Dates

- **Publication of the call for submission: 17. Oktober 2023 (08:00 AM UTC)**
- Registration deadline for submissions: 15. December 2023 (08:00 AM UTC)
- **Publication of the call for submission: 17. October 2023 (08:00 AM UTC)**
- Registration deadline to express non-binding intent to submit: 15. December 2023 (08:00 AM UTC)
- Version freeze for the benchmark codebase: 17. January 2024 (08:00 AM UTC)
- **Submission deadline: 15. February 2024 (08:00 AM UTC)**
- Sampling the held-out workloads and hyperparameters: 16. February 2024 (08:00 AM UTC)
Expand Down Expand Up @@ -68,19 +68,18 @@ For a guide on the technical steps and details on how to write a submission, ple

In the following, we describe the logistical steps required to submit a training algorithm to the AlgoPerf: Training Algorithms Benchmark.

### Register a submission
### Register an intent to submit

All submitters need to register an intent to submit before the submission registration deadline. This registration is mandatory, i.e. required for all submissions, but not binding, i.e. you don't have to submit a registered submission. This registration is necessary, to estimate the number of submissions and provide support for potential submitters.

To register a submission, please fill out this [online form](https://forms.gle/iY1bUhwSjj1JZ4fa9) with the following information
To register an intent to submission, please fill out this [online form](https://forms.gle/iY1bUhwSjj1JZ4fa9) with the following information

- Name of the submission (e.g. name of the algorithm, or any other arbitrary identifier).
- Ruleset under which the submission will be scored.
- Name of all submitters associated with this submission.
- Email of all submitters associated with this submission.
- Affiliations of all submitters associated with this submission.
- Name, email, and affiliations of all submitters associated with this submission.
- Interest in compute support.

In return, the submission will be issued a unique **submission ID** that will be used throughout the submission process.
The submission will be issued a unique **submission ID** that will be used throughout the submission process.

### How to submit

Expand Down Expand Up @@ -143,9 +142,9 @@ The spirit jury may then hear the justifications of the submitters, inspect the

## Awards and prize money

An awards committee will award a prize for the "*Best Performance*" in each ruleset as well as a "*Jury Award*". The prize for the best-performing submission will take into account the [benchmark score](RULES.md#benchmark-score-using-performance-profiles) on the full benchmark. The "*Jury Award*" will favor more out-of-the-box ideas that show great potential, even though the method may not be of practical value with the current landscape of models, software, etc.
An awards committee will award a prize for the "*Best Performance*" in each ruleset as well as a "*Innovative Submission Award*". The prize for the best-performing submission will take into account the [benchmark score](RULES.md#benchmark-score-using-performance-profiles) on the full benchmark. The "*Innovative Submission Award*" will favor more out-of-the-box ideas that show great potential, even though the method may not be of practical value with the current landscape of models, software, etc.

The prize money for "*Best Performance*" in a ruleset is $20,000 each. The winner of the "*Jury Award*" will be awarded $10,000. We reserve the right to split the prize money and distribute it among multiple submissions.
The prize money for "*Best Performance*" in a ruleset is $20,000 each. The winner of the "*Innovative Submission Award*" will be awarded $10,000. We reserve the right to split the prize money and distribute it among multiple submissions.

If a submission is ineligible to win prize money it can still win an award. The prize money will then go to the highest-ranking eligible submission.

Expand All @@ -159,10 +158,10 @@ The awards committee will be responsible for awarding prize money to submissions

To ensure a fair process and avoid conflicts of interest, some individuals and institutions are ineligible to win prize money. This includes:

- The chairs of the MLCommons Algorithms Working Group (presently *George Dahl* and *Frank Schneider*) and their institutions (currently *Google Inc.* and the *University of Tübingen*)
- All individuals serving on the awards committee and their institutions.
- The chairs of the MLCommons Algorithms Working Group (presently *George Dahl* and *Frank Schneider*) and their associated institutions (currently *Google Inc.* and the *University of Tübingen*)
- All individuals serving on the awards committee and their associated institutions.

A submission with at least one ineligible submitter may still win an award, but the prize money will then be awarded to the top-ranked submission that is eligible for prize money.
A submission with at least one participating ineligible entity may still win an award, but the prize money will then be given to the top-ranked submission that does not contain ineligible entities.

Additionally, we require members of the spirit jury to abstain from being involved in a review if:

Expand Down
27 changes: 12 additions & 15 deletions algorithmic_efficiency/logger_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import shutil
import subprocess
import sys
from typing import Any, Optional
from typing import Any, Dict, Optional

from absl import flags
from clu import metric_writers
Expand Down Expand Up @@ -96,14 +96,14 @@ def write_hparams(hparams: spec.Hyperparameters,
return hparams


def write_json(name: str, log_dict: dict, indent: int = 2) -> None:
def write_json(name: str, log_dict: Dict, indent: int = 2) -> None:
if RANK == 0:
with open(name, 'w') as f:
f.write(json.dumps(log_dict, indent=indent))


def write_to_csv(
metrics: dict,
metrics: Dict,
csv_path: str,
) -> None:
try:
Expand All @@ -120,7 +120,7 @@ def write_to_csv(
return


def _get_utilization() -> dict:
def _get_utilization() -> Dict:
util_data = {}

# CPU
Expand Down Expand Up @@ -180,7 +180,7 @@ def _get_utilization() -> dict:
return util_data


def _get_system_hardware_info() -> dict:
def _get_system_hardware_info() -> Dict:
system_hardware_info = {}
try:
system_hardware_info['cpu_model_name'] = _get_cpu_model_name()
Expand All @@ -200,7 +200,7 @@ def _get_system_hardware_info() -> dict:
return system_hardware_info


def _get_system_software_info() -> dict:
def _get_system_software_info() -> Dict:
system_software_info = {}

system_software_info['os_platform'] = \
Expand Down Expand Up @@ -243,7 +243,7 @@ def _is_primitive_type(item: Any) -> bool:
return isinstance(item, primitive)


def _get_workload_properties(workload: spec.Workload) -> dict:
def _get_workload_properties(workload: spec.Workload) -> Dict:
workload_properties = {}
skip_list = ['param_shapes', 'model_params_types']
keys = [
Expand All @@ -262,7 +262,8 @@ def _get_workload_properties(workload: spec.Workload) -> dict:
return workload_properties


def get_meta_data(workload: spec.Workload) -> dict:
def get_meta_data(workload: spec.Workload,
rng_seed: Optional[int] = None) -> Dict:
meta_data = {}
workload_properties = _get_workload_properties(workload)
meta_data.update(workload_properties)
Expand All @@ -272,15 +273,11 @@ def get_meta_data(workload: spec.Workload) -> dict:
meta_data.update(system_software_info)
system_hardware_info = _get_system_hardware_info()
meta_data.update(system_hardware_info)
if rng_seed is not None:
meta_data.update({'rng_seed': rng_seed})
return meta_data


def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str):
meta_data = get_meta_data(workload)
meta_data.update({'rng_seed': rng_seed})
write_json(meta_file_name, meta_data)


class MetricLogger(object):
"""Used to log all measurements during training.
Expand Down Expand Up @@ -308,7 +305,7 @@ def __init__(self,
wandb.config.update(hyperparameters._asdict())

def append_scalar_metrics(self,
metrics: dict,
metrics: Dict,
global_step: int,
preemption_count: Optional[int] = None,
is_eval: bool = False) -> None:
Expand Down
3 changes: 1 addition & 2 deletions algorithmic_efficiency/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None:
# Make sure no GPU memory is preallocated to Jax.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Only use CPU for Jax to avoid memory issues.
# Setting the corresponding environment variable here has no effect; it has to
# be done before jax and tensorflow (!) are imported for the first time.
jax.config.update('jax_platforms', 'cpu')
jax.config.update('jax_platform_name', 'cpu')
# From the docs: "(...) causes cuDNN to benchmark multiple convolution
# algorithms and select the fastest."
torch.backends.cudnn.benchmark = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flax import jax_utils
import jax
import jax.numpy as jnp
import numpy as np

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
Expand Down Expand Up @@ -147,7 +148,8 @@ def _eval_batch(self,
batch: Dict[str, spec.Tensor]) -> spec.Tensor:
# We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of
# shape (local_device_count,) will all be different values.
return self._eval_batch_pmapped(params, batch).sum()
return np.array(
self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64)


class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
Expand Down
4 changes: 2 additions & 2 deletions algorithmic_efficiency/workloads/criteo1tb/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ def has_reached_validation_target(self, eval_result: Dict[str,

@property
def validation_target_value(self) -> float:
return 0.123649
return 0.123735

def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool:
return eval_result['test/loss'] < self.test_target_value

@property
def test_target_value(self) -> float:
return 0.126060
return 0.126041

@property
def loss_type(self) -> spec.LossType:
Expand Down
6 changes: 3 additions & 3 deletions algorithmic_efficiency/workloads/fastmri/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ def has_reached_validation_target(self, eval_result: float) -> bool:

@property
def validation_target_value(self) -> float:
return 0.7344
return 0.726999

def has_reached_test_target(self, eval_result: float) -> bool:
return eval_result['test/ssim'] > self.test_target_value

@property
def test_target_value(self) -> float:
return 0.741652
return 0.744254

@property
def loss_type(self) -> spec.LossType:
Expand All @@ -51,7 +51,7 @@ def num_validation_examples(self) -> int:

@property
def num_test_examples(self) -> int:
return 3581
return 3548

@property
def eval_batch_size(self) -> int:
Expand Down
Loading

0 comments on commit d31489e

Please sign in to comment.