forked from mlcommons/logging
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathseed_checker.py
148 lines (122 loc) · 5.06 KB
/
seed_checker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import warnings
import os
import logging
from ..compliance_checker import mlp_parser
# File extensions (lower-case, including the leading dot) that mark a file as
# "source code"; is_source_file() compares a path's lower-cased extension
# against this set.
SOURCE_FILE_EXT = {
    # Python, shell and '.sub' scripts
    '.py', '.sh', '.sub',
    # C / C++ sources
    '.c', '.cc', '.cpp', '.cxx',
    # C / C++ headers
    '.h', '.hh', '.hpp', '.hxx',
    # CUDA sources and headers
    '.cu', '.cuh',
}
def _print_divider_bar():
logging.info('------------------------------')
def is_source_file(path):
    """ Tell whether a file counts as a "source file" based on its extension.

    Only the final extension is examined (as split by ``os.path.splitext``)
    and the comparison is case-insensitive. The recognised extensions are
    listed in SOURCE_FILE_EXT.

    Args:
        path: An absolute path, a relative path, or a bare file name.

    Returns:
        True if the lower-cased extension is in SOURCE_FILE_EXT.
    """
    _stem, extension = os.path.splitext(path)
    return extension.lower() in SOURCE_FILE_EXT
def find_source_files_under(path):
    """ Recursively collect every source file beneath a directory.

    Directories are traversed with ``os.walk`` (top-down), and a file is kept
    when ``is_source_file`` accepts its name.

    Args:
        path: The absolute or relative path of the directory to search.

    Returns:
        A list of paths, each formed by joining the walked directory with the
        file name, in ``os.walk`` traversal order.
    """
    return [
        os.path.join(directory, name)
        for directory, _subdirs, names in os.walk(path)
        for name in names
        if is_source_file(name)
    ]
class SeedChecker:
    """ Check if the seeds fit MLPerf submission requirements.

    Current requirements are:
    1. All seeds must be logged through mllog. Any seed logged via any other
       method will be discarded.
    2. All seeds must be valid integers (convertible via int()).
    3. We expect all runs to log at least one seed.
    4. If one run logs one seed on a certain line in a certain source file, no
       other run can log the same seed on the same line in the same file.

    Failing any of the above requirements results in check failure.

    A warning is raised for the following situations:
    1. Any run logs more than one seed.
    """

    def __init__(self, ruleset):
        """
        Args:
            ruleset: Ruleset identifier, forwarded unchanged to
                ``mlp_parser.parse_file`` when result files are parsed.
        """
        self._ruleset = ruleset

    def _get_seed_records(self, result_file):
        """ Parse one result file and extract the seeds it logs.

        Args:
            result_file: Path to a result (log) file.

        Returns:
            A list of ``(source_file, line_number, seed)`` tuples, one per log
            line whose key is ``'seed'``. ``source_file`` and ``line_number``
            are taken from the log line's ``metadata``; ``seed`` is the logged
            value converted with ``int()``.

        Raises:
            ValueError: If the parser reports any errors, or a seed value is
                not convertible via ``int()``.
            KeyError / TypeError: Propagated from indexing if a seed line lacks
                the expected ``metadata``/``value`` entries.
        """
        loglines, errors = mlp_parser.parse_file(
            result_file,
            ruleset=self._ruleset,
        )
        if errors:
            raise ValueError('\n'.join(
                ['Found parsing errors:'] +
                ['{}\n ^^ {}'.format(line, error)
                 for line, error in errors] +
                ['', 'Log lines had parsing errors.']))
        return [(
            line.value['metadata']['file'],
            line.value['metadata']['lineno'],
            int(line.value['value']),
        ) for line in loglines if line.key == 'seed']

    def _assert_unique_seed_per_run(self, result_files):
        """ Collect seed-requirement violations across a benchmark's runs.

        Args:
            result_files: An iterable of result-file paths, one per run.

        Returns:
            A list of human-readable error messages; empty if every requirement
            is satisfied. Runs logging more than one seed only trigger a
            ``warnings.warn`` and are not counted as errors.
        """
        error_messages = []
        # (source_file, line_number, seed) -> first result file that logged it.
        seed_to_result_file = {}
        for result_file in result_files:
            try:
                seed_records = self._get_seed_records(result_file)
            except Exception as e:
                # Any failure while extracting seeds (parse errors, non-int
                # seeds, missing metadata) is reported as a check failure for
                # this file instead of aborting the whole check.
                error_messages.append("Error found when querying seeds from "
                                      "{}: {}".format(result_file, e))
                continue
            if not seed_records:
                error_messages.append(
                    "Result file {} logs no seed.".format(result_file)
                )
            if len(seed_records) > 1:
                warnings.warn(
                    "Result file {} logs more than one seeds {}!".format(
                        result_file, seed_records))
            for f, ln, s in seed_records:
                key = (f, ln, s)
                if key not in seed_to_result_file:
                    seed_to_result_file[key] = result_file
                elif seed_to_result_file[key] != result_file:
                    # Only a *different* run reusing the same seed at the same
                    # location violates requirement 4; a run that logs the same
                    # seed on the same line repeatedly (e.g. from a loop) must
                    # not be reported as conflicting with itself.
                    error_messages.append(
                        "Result file {} logs seed {} on {}:{}. However, "
                        "result file {} already logs the same seed on the same "
                        "line.".format(
                            result_file,
                            s,
                            f,
                            ln,
                            seed_to_result_file[key],
                        ))
        return error_messages

    def _has_seed_keyword(self, source_file):
        """ Return True if any line of ``source_file`` contains "seed".

        The match is case-insensitive. The file is decoded as UTF-8 with
        undecodable bytes replaced, so the result does not depend on the
        platform's locale encoding and non-UTF-8 sources do not abort the
        scan. The ASCII keyword itself is unaffected by the replacement.
        """
        with open(source_file, 'r', encoding='utf-8',
                  errors='replace') as file_handle:
            return any('seed' in line.lower() for line in file_handle)

    def check_seeds(self, result_files, seed_checker_bypass=False):
        """ Check the seeds for a specific benchmark submission.

        Args:
            result_files: An iterable containing paths to all the result files
                for this benchmark.
            seed_checker_bypass: If True, skip all seed checks and report
                success.

        Returns:
            True if the check passed or was bypassed. False if any error was
            found; the errors are logged via ``logging.error``.
        """
        _print_divider_bar()
        logging.info(" Running Seed Checker")
        if seed_checker_bypass:
            logging.info("Bypassing Seed Checker")
            return True
        error_messages = self._assert_unique_seed_per_run(result_files)
        if error_messages:
            logging.error(
                " Seed checker failed and found the following errors: %s",
                '\n'.join(error_messages))
            return False
        return True