Merge pull request #575 from mlcommons/dev
dev -> main
priyakasimbeg authored Nov 15, 2023
2 parents 0c5a61f + a74e52c commit d59cf2b
Showing 9 changed files with 147 additions and 65 deletions.
13 changes: 9 additions & 4 deletions README.md
@@ -70,11 +70,14 @@ You can install this package and dependencies in a [python virtual environment](#
pip3 install -e '.[jax_cpu]'
pip3 install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'
pip3 install -e '.[full]'
```

### Python virtual environment

-Note: Python minimum requirement >= 3.8
+#### Prerequisites
+- Python minimum requirement >= 3.8
+- CUDA 11.8
+- NVIDIA Driver version 535.104.05

To set up a virtual environment and install this repository:

@@ -115,12 +118,14 @@ pip3 install -e '.[full]'

### Docker

-We recommend using a Docker container to ensure a similar environment to our scoring and testing environments.
+We recommend using a Docker container to ensure a similar environment to our scoring and testing environments.
+Alternatively, a Singularity/Apptainer container can also be used (see instructions below).

-We recommend using a Docker container to ensure a similar environment to our scoring and testing environments.

-**Prerequisites for NVIDIA GPU set up**: You may have to install the NVIDIA Container Toolkit so that the containers can locate the NVIDIA drivers and GPUs.
+#### Prerequisites
+- NVIDIA Driver version 535.104.05
+- NVIDIA Container Toolkit so that the containers can locate the NVIDIA drivers and GPUs.
See instructions [here](https://github.com/NVIDIA/nvidia-docker).
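
As a quick sanity check that the driver and container toolkit are wired up correctly, something like the following Python sketch can be used. This is an illustration only: the CUDA base-image tag and the reliance on `nvidia-smi` being on the PATH are assumptions, not part of this repository.

```python
# Hedged sketch: verify the host NVIDIA driver is visible and that Docker can
# pass GPUs through (GPU passthrough requires the NVIDIA Container Toolkit).
import subprocess

# Print the host driver version via nvidia-smi (assumes it is on the PATH).
subprocess.run(
    ['nvidia-smi', '--query-gpu=driver_version', '--format=csv,noheader'],
    check=True)

# If the toolkit is installed, `--gpus all` should expose the GPUs inside the
# container; the CUDA base image tag below is an illustrative choice.
subprocess.run(
    ['docker', 'run', '--rm', '--gpus', 'all',
     'nvidia/cuda:11.8.0-base-ubuntu22.04', 'nvidia-smi'],
    check=True)
```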

#### Building Docker Image
6 changes: 3 additions & 3 deletions algorithmic_efficiency/workloads/fastmri/workload.py
@@ -19,14 +19,14 @@ def has_reached_validation_target(self, eval_result: float) -> bool:

@property
def validation_target_value(self) -> float:
-return 0.726999
+return 0.727120

def has_reached_test_target(self, eval_result: float) -> bool:
return eval_result['test/ssim'] > self.test_target_value

@property
def test_target_value(self) -> float:
-return 0.744254
+return 0.744296

@property
def loss_type(self) -> spec.LossType:
Expand All @@ -51,7 +51,7 @@ def num_validation_examples(self) -> int:

@property
def num_test_examples(self) -> int:
-return 3548
+return 3581

@property
def eval_batch_size(self) -> int:
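
For context, the target values updated above feed the threshold checks shown at the top of this diff. The snippet below is a minimal, standalone illustration of that comparison, not the repository's actual class; the `validation/ssim` key is assumed by analogy with the `test/ssim` key visible in the diff.

```python
# Simplified stand-ins for the updated FastMRI targets in this commit.
VALIDATION_TARGET_SSIM = 0.727120
TEST_TARGET_SSIM = 0.744296

def has_reached_validation_target(eval_result: dict) -> bool:
  # Mirrors the class method above: higher SSIM is better, so the target is
  # reached once the eval metric exceeds the threshold.
  return eval_result['validation/ssim'] > VALIDATION_TARGET_SSIM

def has_reached_test_target(eval_result: dict) -> bool:
  return eval_result['test/ssim'] > TEST_TARGET_SSIM

print(has_reached_validation_target({'validation/ssim': 0.7305}))  # True
print(has_reached_test_target({'test/ssim': 0.7441}))              # False
```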
4 changes: 2 additions & 2 deletions datasets/README.md
Expand Up @@ -96,7 +96,7 @@ Therefore, you will have to specify the framework (pytorch or jax) through thefr

```bash
python3 datasets/dataset_setup.py \
-  --data_dir /data \
+  --data_dir $DATA_DIR \
--imagenet \
--temp_dir $DATA_DIR/tmp \
--imagenet_train_url <imagenet_train_url> \
@@ -133,7 +133,7 @@ downloading has finished.
To download, train a tokenizer, and preprocess the librispeech dataset:
```bash
python3 datasets/dataset_setup.py \
-  --data_dir librispeech \
+  --data_dir $DATA_DIR \
--temp_dir $DATA_DIR/tmp \
--librispeech
```
20 changes: 17 additions & 3 deletions datasets/dataset_setup.py
@@ -424,6 +424,7 @@ def setup_fastmri(data_dir, src_data_dir):


def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url):
"""Downloads and returns the download dir."""
imagenet_train_filepath = os.path.join(data_dir, IMAGENET_TRAIN_TAR_FILENAME)
imagenet_val_filepath = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME)

@@ -506,7 +507,20 @@ def setup_imagenet_pytorch(data_dir):
val_tar_file_path = os.path.join(data_dir, IMAGENET_VAL_TAR_FILENAME)
test_dir_path = os.path.join(data_dir, 'imagenet_v2')

-  # Setup jax dataset dir
+  # Check if downloaded data has been moved
+  manual_download_dir = os.path.join(data_dir, 'jax', 'downloads', 'manual')
+  if not os.path.exists(train_tar_file_path):
+    if os.path.exists(
+        os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)):
+      train_tar_file_path = os.path.join(manual_download_dir,
+                                         IMAGENET_TRAIN_TAR_FILENAME)
+  if not os.path.exists(val_tar_file_path):
+    if os.path.exists(
+        os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)):
+      val_tar_file_path = os.path.join(manual_download_dir,
+                                       IMAGENET_VAL_TAR_FILENAME)
+
+  # Setup pytorch dataset dir
imagenet_pytorch_data_dir = os.path.join(data_dir, 'pytorch')
os.makedirs(imagenet_pytorch_data_dir)
os.makedirs(os.path.join(imagenet_pytorch_data_dir, 'train'))
@@ -519,9 +533,9 @@ def setup_imagenet_pytorch(data_dir):
logging.info('Moving {} to {}'.format(val_tar_file_path,
imagenet_pytorch_data_dir))
shutil.move(val_tar_file_path, imagenet_pytorch_data_dir)
-  if not os.path.exists(os.path.join(imagenet_jax_data_dir, 'imagenet_v2')):
+  if not os.path.exists(os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')):
logging.info('Moving imagenet_v2 to {}'.format(
-        os.path.join(imagenet_jax_data_dir, 'imagenet_v2')))
+        os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')))
shutil.move(test_dir_path,
os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2'))

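
The block added above falls back to a TFDS-style manual-download directory when the ImageNet tarballs are not at the top of `data_dir`. A condensed sketch of that lookup pattern follows; the tar filename constant is an assumption standing in for the module's own.

```python
import os

# Assumption: stands in for dataset_setup.py's IMAGENET_TRAIN_TAR_FILENAME.
TRAIN_TAR = 'ILSVRC2012_img_train.tar'

def resolve_tar_path(data_dir: str, filename: str) -> str:
  primary = os.path.join(data_dir, filename)
  fallback = os.path.join(data_dir, 'jax', 'downloads', 'manual', filename)
  # Prefer the top-level file; otherwise use the copy a previous JAX setup
  # may have left in its manual-download directory.
  return primary if os.path.exists(primary) else fallback

print(resolve_tar_path('/data/imagenet', TRAIN_TAR))
```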
46 changes: 20 additions & 26 deletions scoring/scoring.py → scoring/performance_profile.py
@@ -25,17 +25,18 @@
The keys in this dictionary should match the workload identifiers used in
the dictionary of submissions.
"""

import itertools
import operator
import os
import re

from absl import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import algorithmic_efficiency.workloads.workloads as workloads_registry
+from scoring import scoring_utils

WORKLOADS = workloads_registry.WORKLOADS
WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)'
Expand Down Expand Up @@ -153,7 +154,8 @@ def get_index_that_reaches_target(workload_df,
def get_times_for_submission(submission,
submission_tag,
time_col='global_step',
-verbosity=1):
+verbosity=1,
+self_tuning_ruleset=False):
"""Get times to target for each workload in a submission.
Args:
@@ -168,25 +170,16 @@
"""
workloads = []
submission_name = submission_tag.split('.')[1]

+  num_workloads = len(submission.groupby('workload'))
+  if num_workloads != NUM_WORKLOADS:
+    logging.warning(f'Expecting {NUM_WORKLOADS} workloads '
+                    f'but found {num_workloads} workloads.')
for workload, group in submission.groupby('workload'):
-    workload_name = re.match(WORKLOAD_NAME_PATTERN, workload).group(1)
-    framework = re.match(WORKLOAD_NAME_PATTERN, workload).group(2)
-    workload_metadata = WORKLOADS[workload_name]
-
-    # Extend path according to framework.
-    workload_metadata['workload_path'] = os.path.join(
-        BASE_WORKLOADS_DIR,
-        workload_metadata['workload_path'] + f'{framework}',
-        'workload.py')
-    workload_init_kwargs = {}
-    workload_obj = workloads_registry.import_workload(
-        workload_path=workload_metadata['workload_path'],
-        workload_class_name=workload_metadata['workload_class_name'],
-        workload_init_kwargs=workload_init_kwargs)
-    metric_name = workload_obj.target_metric_name
-    validation_metric = f'validation/{metric_name}'
-    validation_target = workload_obj.validation_target_value
+    num_trials = len(group)
+    if num_trials != NUM_TRIALS and not self_tuning_ruleset:
+      logging.warning(f'Expecting {NUM_TRIALS} trials for workload '
+                      f'{workload} but found {num_trials} trials.')
+    validation_metric, validation_target = scoring_utils.get_workload_validation_target(workload)

trial_idx, time_idx = get_index_that_reaches_target(
group, validation_metric, validation_target)
@@ -250,21 +243,22 @@ def compute_performance_profiles(results,
dfs = []

for submission_tag, result in results.items():
-    print(f'\nComputing performance profile with respect to `{time_col}` for '
-          f'{submission_tag}')
+    logging.info(
+        f'\nComputing performance profile with respect to `{time_col}` for '
+        f'{submission_tag}')
dfs.append(
get_times_for_submission(result, submission_tag, time_col, verbosity))
df = pd.concat(dfs)

if verbosity > 0:
-    print(f'\n`{time_col}` to reach target:')
+    logging.info(f'\n`{time_col}` to reach target:')
with pd.option_context('display.max_rows',
None,
'display.max_columns',
None,
'display.width',
1000):
-      print(df)
+      logging.info(df)

# Divide by the fastest.
if reference_submission_tag is None:
@@ -273,14 +267,14 @@
df.update(df.div(df.loc[reference_submission_tag, :], axis=1))

if verbosity > 0:
-    print(f'\n`{time_col}` to reach target normalized to best:')
+    logging.info(f'\n`{time_col}` to reach target normalized to best:')
with pd.option_context('display.max_rows',
None,
'display.max_columns',
None,
'display.width',
1000):
-      print(df)
+      logging.info(df)

# If no max_tau is supplied, choose the value of tau that would plot all non
# inf or nan data.
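
As a rough guide to what this renamed module computes: each submission's time-to-target is divided by the per-workload best, and the profile value at a ratio tau is the fraction of workloads solved within tau times the best. The self-contained sketch below illustrates that final step under those assumptions; it is not the module's exact code.

```python
import numpy as np
import pandas as pd

def performance_profile(ratios: pd.DataFrame, taus: np.ndarray) -> pd.DataFrame:
  # ratios: rows = submissions, columns = workloads, entries = time-to-target
  # divided by the best time on that workload (np.inf if never reached).
  # Each column of the result is the fraction of workloads solved within tau.
  return pd.DataFrame({tau: ratios.le(tau).mean(axis=1) for tau in taus})

ratios = pd.DataFrame({'wmt': [1.0, 1.3], 'fastmri': [1.2, np.inf]},
                      index=['submission_a', 'submission_b'])
print(performance_profile(ratios, np.linspace(1.0, 2.0, 5)))
```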
77 changes: 61 additions & 16 deletions scoring/score_submission.py
@@ -1,11 +1,15 @@
+import operator
import os

from absl import app
from absl import flags
from absl import logging
+import numpy as np
+import pandas as pd
import scoring_utils
+from tabulate import tabulate

-from scoring import scoring
+from scoring import performance_profile

flags.DEFINE_string(
'experiment_path',
@@ -15,29 +19,70 @@
flags.DEFINE_string('output_dir',
'scoring_results',
'Path to save performance profile table and plot.')
+flags.DEFINE_boolean('compute_performance_profiles',
+                     False,
+                     'Whether or not to compute the performance profiles.')
FLAGS = flags.FLAGS


+def get_summary_df(workload, workload_df):
+  validation_metric, validation_target = scoring_utils.get_workload_validation_target(workload)
+  is_minimized = performance_profile.check_if_minimized(validation_metric)
+  target_op = operator.le if is_minimized else operator.ge
+  best_op = min if is_minimized else max
+  idx_op = np.argmin if is_minimized else np.argmax
+
+  summary_df = pd.DataFrame()
+  summary_df['workload'] = workload_df['workload']
+  summary_df['trial'] = workload_df['trial']
+  summary_df['target metric name'] = validation_metric
+  summary_df['target metric value'] = validation_target
+
+  summary_df['target reached'] = workload_df[validation_metric].apply(
+      lambda x: target_op(x, validation_target)).apply(np.any)
+  summary_df['best target'] = workload_df[validation_metric].apply(
+      lambda x: best_op(x))
+  workload_df['index best eval'] = workload_df[validation_metric].apply(
+      lambda x: idx_op(x))
+  summary_df['submission time'] = workload_df.apply(
+      lambda x: x['accumulated_submission_time'][x['index best eval']], axis=1)
+  summary_df['score'] = summary_df.apply(
+      lambda x: x['submission time'] if x['target reached'] else np.inf, axis=1)
+
+  return summary_df


def main(_):
df = scoring_utils.get_experiment_df(FLAGS.experiment_path)
results = {
FLAGS.submission_tag: df,
}
-  performance_profile_df = scoring.compute_performance_profiles(
-      results,
-      time_col='score',
-      min_tau=1.0,
-      max_tau=None,
-      reference_submission_tag=None,
-      num_points=100,
-      scale='linear',
-      verbosity=0)
-  if not os.path.exists(FLAGS.output_dir):
-    os.mkdir(FLAGS.output_dir)
-  scoring.plot_performance_profiles(
-      performance_profile_df, 'score', save_dir=FLAGS.output_dir)
-
-  logging.info(performance_profile_df)

+  dfs = []
+  for workload, group in df.groupby('workload'):
+    summary_df = get_summary_df(workload, group)
+    dfs.append(summary_df)
+
+  df = pd.concat(dfs)
+  logging.info(tabulate(df, headers='keys', tablefmt='psql'))

+  if FLAGS.compute_performance_profiles:
+    performance_profile_df = performance_profile.compute_performance_profiles(
+        results,
+        time_col='score',
+        min_tau=1.0,
+        max_tau=None,
+        reference_submission_tag=None,
+        num_points=100,
+        scale='linear',
+        verbosity=0)
+    if not os.path.exists(FLAGS.output_dir):
+      os.mkdir(FLAGS.output_dir)
+    performance_profile.plot_performance_profiles(
+        performance_profile_df, 'score', save_dir=FLAGS.output_dir)
+    perf_df = tabulate(
+        performance_profile_df.T, headers='keys', tablefmt='psql')
+    logging.info(f'Performance profile:\n {perf_df}')


if __name__ == '__main__':
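
One design point in the new `get_summary_df` worth spelling out: whether an eval "reaches" its target depends on whether the metric is minimized (e.g. WER) or maximized (e.g. SSIM), which is why the comparison, best, and argbest operators are all selected together. A small hedged illustration of that dispatch, with invented values:

```python
import operator
import numpy as np

def pick_ops(is_minimized: bool):
  # For a minimized metric the target is reached with <=; for a maximized
  # one with >=. The best-value and best-index operators follow suit.
  target_op = operator.le if is_minimized else operator.ge
  best_op = min if is_minimized else max
  idx_op = np.argmin if is_minimized else np.argmax
  return target_op, best_op, idx_op

target_op, best_op, idx_op = pick_ops(is_minimized=False)  # SSIM-like metric
evals = [0.70, 0.72, 0.731]
print(target_op(best_op(evals), 0.727120))  # True: best eval reaches target
print(idx_op(evals))                        # 2: index of the best eval
```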