diff --git a/.gitignore b/.gitignore
index 95d9fa6c1..d2e212366 100644
--- a/.gitignore
+++ b/.gitignore
@@ -20,6 +20,7 @@
 algorithmic_efficiency/workloads/librispeech_conformer/work_dir
 *.vocab
 wandb/
 *.txt
+scoring/plots/
 !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv
 !scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv
\ No newline at end of file
diff --git a/CALL_FOR_SUBMISSIONS.md b/CALL_FOR_SUBMISSIONS.md
index 0e21f0e9c..c5fff46d7 100644
--- a/CALL_FOR_SUBMISSIONS.md
+++ b/CALL_FOR_SUBMISSIONS.md
@@ -17,7 +17,7 @@ Submissions can compete under two hyperparameter tuning rulesets (with separate
 - **Registration deadline to express non-binding intent to submit: February 28th, 2024**.\
   Please fill out the (mandatory but non-binding) [**registration form**](https://forms.gle/K7ty8MaYdi2AxJ4N8).
 - **Submission deadline: April 04th, 2024** *(moved by a week from the initial March 28th, 2024)*
-- [tentative] Announcement of all results: July 15th, 2024
+- [Announcement of all results](https://mlcommons.org/2024/08/mlc-algoperf-benchmark-competition/): August 1st, 2024
 
 For a detailed and up-to-date timeline see the [Competition Rules](/COMPETITION_RULES.md).
 
diff --git a/README.md b/README.md
index f778c94a6..5a1f10a33 100644
--- a/README.md
+++ b/README.md
@@ -27,10 +27,7 @@
 ---
 
 > [!IMPORTANT]
-> Submitters are no longer required to self-report results.
-> We are currently in the process of evaluating and scoring received submissions.
-> We are aiming to release results by July 15th 2024.
-> For other key dates please see [Call for Submissions](CALL_FOR_SUBMISSIONS.md).
+> The results of the inaugural AlgoPerf: Training Algorithms benchmark competition have been announced. See the [MLCommons blog post](https://mlcommons.org/2024/08/mlc-algoperf-benchmark-competition/) for an overview and the [results page](https://mlcommons.org/benchmarks/algorithms/) for more details on the results. We are currently preparing an in-depth analysis of the results in the form of a paper and plan the next iteration of the benchmark competition.
 
 ## Table of Contents
 
diff --git a/scoring/compute_speedups.py b/scoring/compute_speedups.py
new file mode 100644
index 000000000..5fb5f259d
--- /dev/null
+++ b/scoring/compute_speedups.py
@@ -0,0 +1,112 @@
+"""File to compute speedups (i.e. geometric means between runtimes)."""
+
+import pickle
+
+from absl import app
+from absl import flags
+import numpy as np
+import pandas as pd
+from performance_profile import BASE_WORKLOADS
+from performance_profile import get_workloads_time_to_target
+from scipy import stats
+
+flags.DEFINE_string('results_txt', None, 'Path to full scoring results file.')
+flags.DEFINE_string(
+    'base',
+    'prize_qualification_baseline',
+    'Base submission to compare to. Defaults to the `prize_qualification_baseline`.'
+)
+flags.DEFINE_string('comparison', None, 'Submission to compute the speedup of.')
+flags.DEFINE_boolean('self_tuning_ruleset',
+                     False,
+                     'Whether the self-tuning ruleset is being scored.')
+flags.DEFINE_boolean('save_results',
+                     False,
+                     'Whether to save the results to disk.')
+FLAGS = flags.FLAGS
+
+MAX_BUDGETS = {
+    'criteo1tb': 7703,
+    'fastmri': 8859,
+    'imagenet_resnet': 63_008,
+    'imagenet_vit': 77_520,
+    'librispeech_conformer': 61_068,
+    'librispeech_deepspeech': 55_506,
+    'ogbg': 18_477,
+    'wmt': 48_151,
+}
+
+
+def replace_inf(row):
+  """Replace infs with maximum runtime budget (+1 second).
+
+  Args:
+      row (pd.Series): The original row.
+
+  Returns:
+      pd.Series: The row with infs replaced.
+  """
+  workload_name = row.name
+  # Factor of 3 for self-tuning ruleset
+  factor = 3 if FLAGS.self_tuning_ruleset else 1
+  max_runtime_workload = factor * MAX_BUDGETS[workload_name]
+  row.replace(np.inf, max_runtime_workload + 1, inplace=True)
+  return row
+
+
+def compute_speedup():
+  """Compute speedup between two algorithms."""
+  # Load results from disk
+  with open(FLAGS.results_txt, 'rb') as f:
+    results = pickle.load(f)
+
+  # Compute median over runtimes for both training algorithms
+  base_results = get_workloads_time_to_target(
+      results[FLAGS.base],
+      FLAGS.base,
+      time_col="score",
+      self_tuning_ruleset=FLAGS.self_tuning_ruleset,
+  )
+  comparison_results = get_workloads_time_to_target(
+      results[FLAGS.comparison],
+      FLAGS.comparison,
+      time_col="score",
+      self_tuning_ruleset=FLAGS.self_tuning_ruleset,
+  )
+
+  # Merge results
+  merged_results = pd.concat([base_results, comparison_results]).transpose()
+
+  # Ignore workload variants (only consider base workloads) for speedup
+  merged_results = merged_results.loc[merged_results.index.isin(BASE_WORKLOADS)]
+
+  # Replace infs with maximum runtime budget (+1 second)
+  merged_results = merged_results.apply(replace_inf, axis=1)
+
+  # Compute speedup
+  merged_results['speedup'] = merged_results[
+      f'{FLAGS.comparison}'] / merged_results[f'{FLAGS.base}']
+  speedups = merged_results['speedup'].to_numpy()
+  mean_speedup = stats.gmean(speedups)  # Geometric mean over workload speedups
+
+  print(merged_results, end='\n\n')
+  print(
+      f"Average speedup of {FLAGS.comparison} compared to {FLAGS.base}: {mean_speedup} or roughly {(1-mean_speedup):.1%}"
+  )
+
+  if FLAGS.save_results:
+    # Optionally save results to disk
+    print("Saving results to disk...")
+    filename = f'{FLAGS.comparison}_vs_{FLAGS.base}_speedup_{(1-mean_speedup):.1%}.csv'
+    merged_results.to_csv(filename)
+
+
+def main(_):
+  """Main function to compute speedup between two algorithms."""
+  compute_speedup()
+
+
+if __name__ == '__main__':
+  flags.mark_flag_as_required('results_txt')
+  flags.mark_flag_as_required('comparison')
+  app.run(main)
diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py
index 8ee271804..32acae9ab 100644
--- a/scoring/performance_profile.py
+++ b/scoring/performance_profile.py
@@ -26,14 +26,17 @@ the dictionary of submissions.
 """
 import itertools
+import json
 import operator
 import os
 import re
 
 from absl import logging
+import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+from tabulate import tabulate
 
 from algorithmic_efficiency.workloads.workloads import get_base_workload_name
 import algorithmic_efficiency.workloads.workloads as workloads_registry
 
@@ -43,6 +46,10 @@
 BASE_WORKLOADS = workloads_registry.BASE_WORKLOADS
 WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)'
 BASE_WORKLOADS_DIR = 'algorithmic_efficiency/workloads/'
+# Open json file to read heldout workloads
+# TODO: This probably shouldn't be hardcoded but passed as an argument.
+with open("held_out_workloads_algoperf_v05.json", "r") as f:
+  HELDOUT_WORKLOADS = json.load(f)
 # These global variables have to be set according to the current set of
 # workloads and rules for the scoring to be correct.
 # We do not use the workload registry since it contains test and development
@@ -63,6 +70,37 @@
 
 MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu']
 
+#MPL params
+mpl.rcParams['figure.figsize'] = (16, 10)  # Width, height in inches
+mpl.rcParams['font.family'] = 'serif'
+mpl.rcParams['font.serif'] = [
+    'Times New Roman'
+] + mpl.rcParams['font.serif']  # Add Times New Roman as first choice
+mpl.rcParams['font.size'] = 22
+mpl.rcParams['savefig.dpi'] = 300  # Set resolution for saved figures
+
+# Plot Elements
+mpl.rcParams['lines.linewidth'] = 3  # Adjust line thickness if needed
+mpl.rcParams['lines.markersize'] = 6  # Adjust marker size if needed
+mpl.rcParams['axes.prop_cycle'] = mpl.cycler(
+    color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
+           "#9467bd"])  # Example color cycle (consider ColorBrewer or viridis)
+mpl.rcParams['axes.labelsize'] = 22  # Axis label font size
+mpl.rcParams['xtick.labelsize'] = 20  # Tick label font size
+mpl.rcParams['ytick.labelsize'] = 20
+
+# Legends and Gridlines
+mpl.rcParams['legend.fontsize'] = 20  # Legend font size
+mpl.rcParams[
+    'legend.loc'] = 'best'  # Let matplotlib decide the best legend location
+mpl.rcParams['axes.grid'] = True  # Enable grid
+mpl.rcParams['grid.alpha'] = 0.4  # Gridline transparency
+
+
+def print_dataframe(df):
+  tabulated_df = tabulate(df.T, headers='keys', tablefmt='psql')
+  logging.info(tabulated_df)
+
 
 def generate_eval_cols(metrics):
   splits = ['train', 'validation']
@@ -150,10 +188,10 @@ def get_workloads_time_to_target(submission,
     if strict:
       raise ValueError(
           f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads '
-          f'but found {num_workloads} workloads.')
+          f'but found {num_workloads} workloads for {submission_name}.')
     logging.warning(
         f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads '
-        f'but found {num_workloads} workloads.')
+        f'but found {num_workloads} workloads for {submission_name}.')
 
   # For each workload get submission time get the submission times to target.
   for workload, group in submission.groupby('workload'):
@@ -164,11 +202,13 @@ get_workloads_time_to_target(submission,
     num_studies = len(group.groupby('study'))
     if num_studies != NUM_STUDIES:
       if strict:
-        raise ValueError(f'Expecting {NUM_STUDIES} trials for workload '
-                         f'{workload} but found {num_studies} trials.')
+        raise ValueError(f'Expecting {NUM_STUDIES} studies for workload '
+                         f'{workload} but found {num_studies} studies '
+                         f'for {submission_name}.')
       else:
-        logging.warning(f'Expecting {NUM_STUDIES} trials for workload '
-                        f'{workload} but found {num_studies} trials.')
+        logging.warning(f'Expecting {NUM_STUDIES} studies for workload '
+                        f'{workload} but found {num_studies} studies '
+                        f'for {submission_name}.')
 
     # For each study check trials
     for study, group in group.groupby('study'):
@@ -177,11 +217,15 @@ get_workloads_time_to_target(submission,
       num_trials = len(group)
       if num_trials != NUM_TRIALS and not self_tuning_ruleset:
        if strict:
-          raise ValueError(f'Expecting {NUM_TRIALS} trials for workload '
-                           f'{workload} but found {num_trials} trials.')
+          raise ValueError(
+              f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
+              f'{workload} but found {num_trials} trials '
+              f'for {submission_name}.')
        else:
-          logging.warning(f'Expecting {NUM_TRIALS} trials for workload '
-                          f'{workload} but found {num_trials} trials.')
+          logging.warning(
+              f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
+              f'{workload} but found {num_trials} trials '
+              f'for {submission_name}.')
 
       # Get trial and time index that reaches target
       trial_idx, time_idx = get_best_trial_index(
@@ -194,13 +238,12 @@ get_workloads_time_to_target(submission,
 
     workloads.append({
         'submission': submission_name,
-        'workload': workload,
+        'workload': re.sub(r'_(jax|pytorch)$', '', workload),
        time_col: np.median(time_vals_per_study),
     })
 
   df = pd.DataFrame.from_records(workloads)
   df = df.pivot(index='submission', columns='workload', values=time_col)
-
   return df
 
@@ -210,6 +253,9 @@ def filter(x):
    try:
      if x[variant_workload] == np.inf:
        return np.inf
+      # Also check for nan values (e.g. OOMs)
+      elif np.isnan(x[variant_workload]):
+        return np.inf
      else:
        return x[base_workload]
    except KeyError as e:
@@ -268,27 +314,33 @@ def compute_performance_profiles(submissions,
                                      self_tuning_ruleset,
                                      strict))
   df = pd.concat(dfs)
 
+  # Restrict to base and sampled held-out workloads
+  # (ignore the additional workload variants of the baseline
+  # as they cause issues when checking for nans in workload variants).
+  df = df[BASE_WORKLOADS + HELDOUT_WORKLOADS]
+  # Sort workloads alphabetically (for better display)
+  df = df.reindex(sorted(df.columns), axis=1)
+
+  # For each held-out workload set to inf if the base workload is inf or nan
+  for workload in df.keys():
+    if workload not in BASE_WORKLOADS:
+      # If base does not have a finite score, set variant score to inf
+      base_workload = get_base_workload_name(workload)
+      df[workload] = df.apply(
+          variant_criteria_filter(workload, base_workload), axis=1)
 
   # Set score to inf if not within 4x of fastest submission
   best_scores = df.min(axis=0)
   df[df.apply(lambda x: x > 4 * best_scores, axis=1)] = np.inf
 
-  # For each held-out workload if variant target was not hit set submission to inf
-  framework = None
+  # For each base workload if variant target was not hit set submission to inf
   for workload in df.keys():
-    # Check if this is a variant
-    framework = workload.split('_')[-1]
-    workload_ = workload.split(f'_{framework}')[0]
-    if workload_ not in BASE_WORKLOADS:
+    if workload not in BASE_WORKLOADS:
       # If variants do not have finite score set base_workload score to inf
-      base_workload = get_base_workload_name(workload_)
+      base_workload = get_base_workload_name(workload)
       df[base_workload] = df.apply(
-          variant_criteria_filter(base_workload + f'_{framework}', workload),
-          axis=1)
-
-  base_workloads = [w + f'_{framework}' for w in BASE_WORKLOADS]
-  df = df[base_workloads]
-  print(df)
+          variant_criteria_filter(base_workload, workload), axis=1)
+  df = df[BASE_WORKLOADS]
 
   if verbosity > 0:
     logging.info('\n`{time_col}` to reach target:')
@@ -375,8 +427,7 @@ def plot_performance_profiles(perf_df,
                               df_col,
                               scale='linear',
                               save_dir=None,
-                              figsize=(30, 10),
-                              font_size=18):
+                              figsize=(30, 10)):
   """Plot performance profiles.
 
   Args:
@@ -396,12 +447,12 @@ def plot_performance_profiles(perf_df,
   Returns:
     None. If a valid save_dir is provided, save both the plot and perf_df.
   """
-  fig = perf_df.T.plot(figsize=figsize)
+  fig = perf_df.T.plot(figsize=figsize, alpha=0.7)
   df_col_display = f'log10({df_col})' if scale == 'log' else df_col
-  fig.set_xlabel(
-      f'Ratio of `{df_col_display}` to best submission', size=font_size)
-  fig.set_ylabel('Proportion of workloads', size=font_size)
-  fig.legend(prop={'size': font_size}, bbox_to_anchor=(1.0, 1.0))
+  fig.set_xlabel(f'Ratio of `{df_col_display}` to best submission')
+  fig.set_ylabel('Proportion of workloads')
+  fig.legend(bbox_to_anchor=(1.0, 1.0))
+  plt.tight_layout()
   maybe_save_figure(save_dir, f'performance_profile_by_{df_col_display}')
   maybe_save_df_to_csv(save_dir,
                        perf_df,
diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py
index 0b768855e..1fb39d193 100644
--- a/scoring/score_submissions.py
+++ b/scoring/score_submissions.py
@@ -14,24 +14,26 @@
 import operator
 import os
+import pickle
 
 from absl import app
 from absl import flags
 from absl import logging
 import numpy as np
 import pandas as pd
+import performance_profile
 import scoring_utils
 from tabulate import tabulate
 
-from scoring import performance_profile
-
 flags.DEFINE_string(
     'submission_directory',
     None,
     'Path to submission directory containing experiment directories.')
-flags.DEFINE_string('output_dir',
-                    'scoring_results',
-                    'Path to save performance profile table and plot.')
+flags.DEFINE_string(
+    'output_dir',
+    'scoring_results',
+    'Path to save performance profile artifacts, submission_summaries and results files.'
+)
 flags.DEFINE_boolean('compute_performance_profiles',
                      False,
                      'Whether or not to compute the performance profiles.')
@@ -45,6 +47,21 @@
     'self_tuning_ruleset',
     False,
     'Whether to score on self-tuning ruleset or externally tuned ruleset')
+flags.DEFINE_string(
+    'save_results_to_filename',
+    None,
+    'Filename to save the processed results that are fed into the performance profile functions.'
+)
+flags.DEFINE_string(
+    'load_results_from_filename',
+    None,
+    'Filename to load processed results from that are fed into performance profile functions.'
+)
+flags.DEFINE_string(
+    'exclude_submissions',
+    '',
+    'Optional comma separated list of names of submissions to exclude from scoring.'
+)
 
 FLAGS = flags.FLAGS
 
@@ -71,9 +88,15 @@ def get_summary_df(workload, workload_df, include_test_split=False):
   summary_df['time to best eval on val (s)'] = workload_df.apply(
       lambda x: x['accumulated_submission_time'][x['index best eval on val']],
       axis=1)
-  summary_df['time to target on val (s)'] = summary_df.apply(
-      lambda x: x['time to best eval on val (s)']
-      if x['val target reached'] else np.inf,
+  workload_df['val target reached'] = workload_df[validation_metric].apply(
+      lambda x: target_op(x, validation_target)).apply(np.any)
+  workload_df['index to target on val'] = workload_df.apply(
+      lambda x: np.argmax(target_op(x[validation_metric], validation_target))
+      if x['val target reached'] else np.nan,
+      axis=1)
+  summary_df['time to target on val (s)'] = workload_df.apply(
+      lambda x: x['accumulated_submission_time'][int(x[
+          'index to target on val'])] if x['val target reached'] else np.inf,
       axis=1)
 
   # test metrics
@@ -101,8 +124,13 @@ def get_summary_df(workload, workload_df, include_test_split=False):
   return summary_df
 
 
-def print_submission_summary(df, include_test_split=True):
+def get_submission_summary(df, include_test_split=True):
+  """Summarizes the submission results into metric and time tables
+  organized by workload.
+  """
+
+  dfs = []
+  print(df)
   for workload, group in df.groupby('workload'):
     summary_df = get_summary_df(
         workload, group, include_test_split=include_test_split)
@@ -113,17 +141,56 @@ def print_submission_summary(df, include_test_split=True):
   return df
 
 
+def compute_leaderboard_score(df, normalize=True):
+  """Compute leaderboard score by taking integral of performance profile.
+
+  Args:
+    df: pd.DataFrame returned from `compute_performance_profiles`.
+    normalize: divide by the range of the performance profile's tau.
+
+  Returns:
+    pd.DataFrame with one column of scores indexed by submission.
+  """
+  scores = np.trapz(df, x=df.columns)
+  if normalize:
+    scores /= df.columns.max() - df.columns.min()
+  return pd.DataFrame(scores, columns=['score'], index=df.index)
+
+
 def main(_):
   results = {}
-
-  for submission in os.listdir(FLAGS.submission_directory):
-    experiment_path = os.path.join(FLAGS.submission_directory, submission)
-    df = scoring_utils.get_experiment_df(experiment_path)
-    results[submission] = df
-    summary_df = print_submission_summary(df)
-    with open(os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
-              'w') as fout:
-      summary_df.to_csv(fout)
+  os.makedirs(FLAGS.output_dir, exist_ok=True)
+
+  # Optionally read results from filename
+  if FLAGS.load_results_from_filename:
+    with open(
+        os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename),
+        'rb') as f:
+      results = pickle.load(f)
+  else:
+    for team in os.listdir(FLAGS.submission_directory):
+      for submission in os.listdir(
+          os.path.join(FLAGS.submission_directory, team)):
+        print(submission)
+        if submission in FLAGS.exclude_submissions.split(','):
+          continue
+        experiment_path = os.path.join(FLAGS.submission_directory,
+                                       team,
+                                       submission)
+        df = scoring_utils.get_experiment_df(experiment_path)
+        results[submission] = df
+        summary_df = get_submission_summary(df)
+        with open(
+            os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
+            'w') as fout:
+          summary_df.to_csv(fout)
+
+  # Optionally save results to filename
+  if FLAGS.save_results_to_filename:
+    with open(
+        os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename),
+        'wb') as f:
+      pickle.dump(results, f)
 
   if not FLAGS.strict:
     logging.warning(
@@ -137,7 +204,7 @@ def main(_):
         results,
         time_col='score',
         min_tau=1.0,
-        max_tau=None,
+        max_tau=4.0,
        reference_submission_tag=None,
        num_points=100,
        scale='linear',
@@ -148,9 +215,13 @@ def main(_):
      os.mkdir(FLAGS.output_dir)
    performance_profile.plot_performance_profiles(
        performance_profile_df, 'score', save_dir=FLAGS.output_dir)
-    perf_df = tabulate(
+    performance_profile_str = tabulate(
        performance_profile_df.T, headers='keys', tablefmt='psql')
-    logging.info(f'Performance profile:\n {perf_df}')
+    logging.info(f'Performance profile:\n {performance_profile_str}')
+    scores = compute_leaderboard_score(performance_profile_df)
+    scores.to_csv(os.path.join(FLAGS.output_dir, 'scores.csv'))
+    scores_str = tabulate(scores, headers='keys', tablefmt='psql')
+    logging.info(f'Scores: \n {scores_str}')
 
 
 if __name__ == '__main__':
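As a quick illustration of the scoring math in the patch above, the standalone sketch below shows how `compute_speedups.py` aggregates per-workload runtime ratios with a geometric mean (`scipy.stats.gmean`) and how `compute_leaderboard_score` turns a performance profile into a single number by integrating it over tau with `np.trapz` and normalizing by the tau range. All runtimes and ratios here are hypothetical toy values, not benchmark results.

import numpy as np
from scipy import stats

# Geometric-mean speedup (cf. scoring/compute_speedups.py): take the ratio of
# runtimes-to-target per workload, then aggregate with a geometric mean.
base_runtimes = np.array([7000.0, 8000.0, 60000.0])        # hypothetical seconds
comparison_runtimes = np.array([5000.0, 6500.0, 45000.0])  # hypothetical seconds
ratios = comparison_runtimes / base_runtimes
mean_speedup = stats.gmean(ratios)
print(f'Geometric-mean runtime ratio: {mean_speedup:.3f}, '
      f'i.e. roughly {1 - mean_speedup:.1%} faster.')

# Leaderboard score (cf. compute_leaderboard_score): the performance profile
# gives, for each tau in [1, 4], the fraction of workloads a submission solves
# within tau times the fastest submission's runtime; the score is the
# normalized area under that curve.
taus = np.linspace(1.0, 4.0, 100)
runtime_ratios_to_best = np.array([1.0, 1.3, 2.5])  # hypothetical ratios
profile = np.array([(runtime_ratios_to_best <= tau).mean() for tau in taus])
score = np.trapz(profile, x=taus) / (taus.max() - taus.min())
print(f'Leaderboard score: {score:.3f}')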