Skip to content

Commit

Permalink
Merge pull request #770 from mlcommons/dev
Browse files Browse the repository at this point in the history
Dev -> main
  • Loading branch information
priyakasimbeg authored Jul 2, 2024
2 parents 4b01ee6 + 2db611f commit 898ee8b
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 35 deletions.
74 changes: 51 additions & 23 deletions scoring/performance_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@
the dictionary of submissions.
"""
import itertools
import logging
import operator
import os
import re

from absl import logging
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tabulate import tabulate

from algorithmic_efficiency.workloads.workloads import get_base_workload_name
import algorithmic_efficiency.workloads.workloads as workloads_registry
Expand Down Expand Up @@ -63,6 +66,37 @@

MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu']

#MPL params
mpl.rcParams['figure.figsize'] = (16, 10) # Width, height in inches
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = [
'Times New Roman'
] + mpl.rcParams['font.serif'] # Add Times New Roman as first choice
mpl.rcParams['font.size'] = 22
mpl.rcParams['savefig.dpi'] = 300 # Set resolution for saved figures

# Plot Elements
mpl.rcParams['lines.linewidth'] = 3 # Adjust line thickness if needed
mpl.rcParams['lines.markersize'] = 6 # Adjust marker size if needed
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(
color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
"#9467bd"]) # Example color cycle (consider ColorBrewer or viridis)
mpl.rcParams['axes.labelsize'] = 22 # Axis label font size
mpl.rcParams['xtick.labelsize'] = 20 # Tick label font size
mpl.rcParams['ytick.labelsize'] = 20

# Legends and Gridlines
mpl.rcParams['legend.fontsize'] = 20 # Legend font size
mpl.rcParams[
'legend.loc'] = 'best' # Let matplotlib decide the best legend location
mpl.rcParams['axes.grid'] = True # Enable grid
mpl.rcParams['grid.alpha'] = 0.4 # Gridline transparency


def print_dataframe(df):
tabulated_df = tabulate(df.T, headers='keys', tablefmt='psql')
logging.info(tabulated_df)


def generate_eval_cols(metrics):
splits = ['train', 'validation']
Expand Down Expand Up @@ -177,11 +211,13 @@ def get_workloads_time_to_target(submission,
num_trials = len(group)
if num_trials != NUM_TRIALS and not self_tuning_ruleset:
if strict:
raise ValueError(f'Expecting {NUM_TRIALS} trials for workload '
f'{workload} but found {num_trials} trials.')
raise ValueError(
f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
f'{workload} but found {num_trials} trials.')
else:
logging.warning(f'Expecting {NUM_TRIALS} trials for workload '
f'{workload} but found {num_trials} trials.')
logging.warning(
f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
f'{workload} but found {num_trials} trials.')

# Get trial and time index that reaches target
trial_idx, time_idx = get_best_trial_index(
Expand All @@ -194,13 +230,12 @@ def get_workloads_time_to_target(submission,

workloads.append({
'submission': submission_name,
'workload': workload,
'workload': re.sub(r'_(jax|pytorch)$', '', workload),
time_col: np.median(time_vals_per_study),
})

df = pd.DataFrame.from_records(workloads)
df = df.pivot(index='submission', columns='workload', values=time_col)

return df


Expand Down Expand Up @@ -276,19 +311,13 @@ def compute_performance_profiles(submissions,
# For each held-out workload if variant target was not hit set submission to inf
framework = None
for workload in df.keys():
# Check if this is a variant
framework = workload.split('_')[-1]
workload_ = workload.split(f'_{framework}')[0]
if workload_ not in BASE_WORKLOADS:
if workload not in BASE_WORKLOADS:
# If variants do not have finite score set base_workload score to inf
base_workload = get_base_workload_name(workload_)
base_workload = get_base_workload_name(workload)
df[base_workload] = df.apply(
variant_criteria_filter(base_workload + f'_{framework}', workload),
axis=1)
variant_criteria_filter(base_workload, workload), axis=1)

base_workloads = [w + f'_{framework}' for w in BASE_WORKLOADS]
df = df[base_workloads]
print(df)
df = df[BASE_WORKLOADS]

if verbosity > 0:
logging.info('\n`{time_col}` to reach target:')
Expand Down Expand Up @@ -375,8 +404,7 @@ def plot_performance_profiles(perf_df,
df_col,
scale='linear',
save_dir=None,
figsize=(30, 10),
font_size=18):
figsize=(30, 10)):
"""Plot performance profiles.
Args:
Expand All @@ -396,12 +424,12 @@ def plot_performance_profiles(perf_df,
Returns:
None. If a valid save_dir is provided, save both the plot and perf_df.
"""
fig = perf_df.T.plot(figsize=figsize)
fig = perf_df.T.plot(figsize=figsize, alpha=0.7)
df_col_display = f'log10({df_col})' if scale == 'log' else df_col
fig.set_xlabel(
f'Ratio of `{df_col_display}` to best submission', size=font_size)
fig.set_ylabel('Proportion of workloads', size=font_size)
fig.legend(prop={'size': font_size}, bbox_to_anchor=(1.0, 1.0))
fig.set_xlabel(f'Ratio of `{df_col_display}` to best submission')
fig.set_ylabel('Proportion of workloads')
fig.legend(bbox_to_anchor=(1.0, 1.0))
plt.tight_layout()
maybe_save_figure(save_dir, f'performance_profile_by_{df_col_display}')
maybe_save_df_to_csv(save_dir,
perf_df,
Expand Down
61 changes: 49 additions & 12 deletions scoring/score_submissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@
--compute_performance_profiles
"""

import json
import operator
import os
import pickle

from absl import app
from absl import flags
from absl import logging
import numpy as np
import pandas as pd
import performance_profile
import scoring_utils
from tabulate import tabulate

from scoring import performance_profile

flags.DEFINE_string(
'submission_directory',
None,
Expand All @@ -45,6 +46,16 @@
'self_tuning_ruleset',
False,
'Whether to score on self-tuning ruleset or externally tuned ruleset')
flags.DEFINE_string(
'save_results_to_filename',
None,
'Filename to save the processed results that are fed into the performance profile functions.'
)
flags.DEFINE_boolean(
'load_results_from_filename',
None,
'Filename to load processed results from that are fed into performance profile functions'
)
FLAGS = flags.FLAGS


Expand Down Expand Up @@ -101,8 +112,13 @@ def get_summary_df(workload, workload_df, include_test_split=False):
return summary_df


def print_submission_summary(df, include_test_split=True):
def get_submission_summary(df, include_test_split=True):
"""Summarizes the submission results into metric and time tables
organized by workload.
"""

dfs = []
print(df)
for workload, group in df.groupby('workload'):
summary_df = get_summary_df(
workload, group, include_test_split=include_test_split)
Expand All @@ -115,15 +131,36 @@ def print_submission_summary(df, include_test_split=True):

def main(_):
results = {}

for submission in os.listdir(FLAGS.submission_directory):
experiment_path = os.path.join(FLAGS.submission_directory, submission)
df = scoring_utils.get_experiment_df(experiment_path)
results[submission] = df
summary_df = print_submission_summary(df)
with open(os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
'w') as fout:
summary_df.to_csv(fout)
os.makedirs(FLAGS.output_dir, exist_ok=True)

# Optionally read results to filename
if FLAGS.load_results_from_filename:
with open(
os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename),
'rb') as f:
results = pickle.load(f)
else:
for team in os.listdir(FLAGS.submission_directory):
for submission in os.listdir(
os.path.join(FLAGS.submission_directory, team)):
print(submission)
experiment_path = os.path.join(FLAGS.submission_directory,
team,
submission)
df = scoring_utils.get_experiment_df(experiment_path)
results[submission] = df
summary_df = get_submission_summary(df)
with open(
os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
'w') as fout:
summary_df.to_csv(fout)

# Optionally save results to filename
if FLAGS.save_results_to_filename:
with open(
os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename),
'wb') as f:
pickle.dump(results, f)

if not FLAGS.strict:
logging.warning(
Expand Down

0 comments on commit 898ee8b

Please sign in to comment.