forked from mlcommons/algorithmic-efficiency
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsubmission_runner_test.py
85 lines (74 loc) · 2.85 KB
/
submission_runner_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Tests for submission_runner.py.
This is an end-to-end test for MNIST in PyTorch and Jax that requires the
dataset to be available. For testing the workload and reference submission code
for all workloads, see reference_algorithm_tests.py.
"""
import copy
import os
import sys
from absl import flags
from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
from algorithmic_efficiency.profiler import PassThroughProfiler
import submission_runner
# absl's global flag registry; test_submission assigns FLAGS.framework.
FLAGS = flags.FLAGS
# Needed to avoid UnparsedFlagAccessError
# (see https://github.com/google/model_search/pull/8).
# Parse the command line eagerly at import time so flag values can be read and
# assigned by the tests before absltest.main() (at the bottom of this module)
# parses them again.
FLAGS(sys.argv)
# Directory holding the MNIST development submissions (one subdirectory per
# framework) and their shared tuning search space. This is a relative path;
# presumably the tests assume they are launched from the repository root —
# TODO confirm.
_MNIST_DEV_ALGO_DIR = 'reference_algorithms/development_algorithms/mnist'
class SubmissionRunnerTest(parameterized.TestCase):
  """Tests for reference submissions.

  Each parameterized case imports the MNIST workload for one framework and
  scores the matching development submission on it with a bounded step budget.
  """

  @parameterized.named_parameters(
      dict(
          testcase_name='mnist_jax',
          workload='mnist',
          framework='jax',
          submission_path=f'{_MNIST_DEV_ALGO_DIR}/mnist_jax/submission.py',
          tuning_search_space=(
              f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json'),
      ),
      dict(
          testcase_name='mnist_pytorch',
          workload='mnist',
          framework='pytorch',
          submission_path=(
              f'{_MNIST_DEV_ALGO_DIR}/mnist_pytorch/submission.py'),
          tuning_search_space=(
              f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json'),
      ),
  )
  def test_submission(self,
                      workload,
                      framework,
                      submission_path,
                      tuning_search_space):
    """Scores one development submission on its workload end to end.

    Args:
      workload: Key into `submission_runner.WORKLOADS` (e.g. 'mnist').
      framework: Framework name ('jax' or 'pytorch'); it is written to
        `FLAGS.framework` and used as the suffix of the workload directory.
      submission_path: Path to the submission module to score.
      tuning_search_space: Path to the tuning search-space JSON file.
    """
    # FLAGS is process-global. Remember the previous value so that this case
    # does not leak its framework into later parameterized cases or tests.
    previous_framework = FLAGS.framework
    FLAGS.framework = framework
    try:
      # Deep-copy so that rewriting 'workload_path' below does not mutate the
      # shared submission_runner.WORKLOADS registry.
      workload_metadata = copy.deepcopy(submission_runner.WORKLOADS[workload])
      workload_metadata['workload_path'] = os.path.join(
          submission_runner.BASE_WORKLOADS_DIR,
          f"{workload_metadata['workload_path']}_{framework}",
          'workload.py')
      workload_obj = submission_runner.import_workload(
          workload_path=workload_metadata['workload_path'],
          workload_class_name=workload_metadata['workload_class_name'],
          workload_init_kwargs={})
      score = submission_runner.score_submission_on_workload(
          workload_obj,
          workload,
          submission_path,
          # The default TFDS location. Expand '~' here rather than relying on
          # every downstream data loader to do it.
          data_dir=os.path.expanduser('~/tensorflow_datasets'),
          tuning_ruleset='external',
          tuning_search_space=tuning_search_space,
          num_tuning_trials=1,
          profiler=PassThroughProfiler(),
          # Presumably caps training length so the test stays short — confirm
          # against score_submission_on_workload.
          max_global_steps=500,
      )
      logging.info(score)
    finally:
      FLAGS.framework = previous_framework

  def test_convert_filepath_to_module(self):
    """Sample test for the `convert_filepath_to_module` function."""
    test_path = os.path.abspath(__file__)
    module_path = submission_runner.convert_filepath_to_module(test_path)
    # Check the type first so a non-string result fails with a clear message
    # instead of an error inside assertNotIn.
    self.assertIsInstance(module_path, str)
    self.assertNotIn('.py', module_path)
    self.assertNotIn('/', module_path)
if __name__ == '__main__':
  # Running this file directly hands control to absltest, which parses flags
  # and runs the test cases defined in this module.
  absltest.main()