test_baselines.py
"""Tests for submission.py for baselines.
This is an end-to-end test for all baselines on MNIST in PyTorch and Jax that
requires the dataset to be available.
"""
import copy
import os
import sys

from absl import flags
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized

from algorithmic_efficiency.profiler import PassThroughProfiler
from algorithmic_efficiency.workloads import workloads

import submission_runner

FLAGS = flags.FLAGS
# Needed to avoid UnparsedFlagAccessError
# (see https://github.com/google/model_search/pull/8).
FLAGS(sys.argv)
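
# Cap each training run at a handful of steps so the end-to-end test stays
# fast; the score itself is not checked.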
MAX_GLOBAL_STEPS = 5
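
# Baseline optimizers that have a reference submission for each framework.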
baselines = {
    'jax': [
        'adafactor',
        'adamw',
        'lamb',
        'momentum',
        'nadamw',
        'nesterov',
        'sam',
        'shampoo',
    ],
    'pytorch': [
        'adamw',
        'momentum',
        'nadamw',
        'nesterov',
    ],
}

frameworks = [
    'pytorch',
    'jax',
]

baseline_path = 'reference_algorithms/paper_baselines'
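
# Build one named test case per (baseline, framework) pair, e.g. 'adamw_jax'.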
named_parameters = []
for f in frameworks:
  for b in baselines[f]:
    named_parameters.append(
        dict(
            testcase_name=f'{b}_{f}',
            workload='mnist',
            framework=f'{f}',
            submission_path=f'{baseline_path}/{b}/{f}/submission.py',
            tuning_search_space=f'{baseline_path}/{b}/tuning_search_space.json')
    )


class BaselineTest(parameterized.TestCase):
  """Tests for reference submissions."""

  @parameterized.named_parameters(*named_parameters)
  def test_baseline_submission(self,
                               workload,
                               framework,
                               submission_path,
                               tuning_search_space):
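    # Point the global framework flag at the framework under test and resolve
    # the path to its framework-specific workload module ('..._jax' or
    # '..._pytorch').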
    FLAGS.framework = framework
    workload_metadata = copy.deepcopy(workloads.WORKLOADS[workload])
    workload_metadata['workload_path'] = os.path.join(
        workloads.BASE_WORKLOADS_DIR,
        workload_metadata['workload_path'] + '_' + framework,
        'workload.py')
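    # Dynamically import the workload class from that module and construct it.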
    workload_obj = workloads.import_workload(
        workload_path=workload_metadata['workload_path'],
        workload_class_name=workload_metadata['workload_class_name'],
        workload_init_kwargs={})
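    # Score the baseline submission for a few steps with a single tuning
    # trial; the test only verifies that the pipeline runs end to end.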
    score = submission_runner.score_submission_on_workload(
        workload_obj,
        workload,
        submission_path,
        data_dir='~/tensorflow_datasets',  # The default in TFDS.
        tuning_ruleset='external',
        tuning_search_space=tuning_search_space,
        num_tuning_trials=1,
        profiler=PassThroughProfiler(),
        max_global_steps=MAX_GLOBAL_STEPS,
    )
    logging.info(score)


if __name__ == '__main__':
  absltest.main()