-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparsing.py
119 lines (104 loc) · 4.36 KB
/
parsing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
'''
Helpers for parsing validation accuracies out of training log files
and plotting them for hyperparameter tuning
'''
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')
from plotting import run_fig_extras
# Line prefixes that identify the validation summaries inside a log file.
BEST_VAL_MARK = 'Best validation acc:'
LAST_VAL_MARK = 'Last validation acc:'
VAL_LIST_MARK = 'Validation list:'


def parse_validations(filename):
    '''
    Parse and read the best validation accuracy, last validation accuracy,
    and list of validation accuracies (for each epoch) from an output file

    Parameters
    - filename: path of the text file to scan line by line

    Returns a tuple (best_val_acc, last_val_acc, val_list)
    - best_val_acc / last_val_acc: float taken from the line starting with
      the matching marker, or -1 if no such line exists
    - val_list: list of floats from the validation list line, or [] if the
      line is missing or the logged list is empty

    If a marker occurs on several lines, the last occurrence wins.

    Raises ValueError if a marked value cannot be converted to float.
    '''
    val_list = []
    best_val_acc = -1
    last_val_acc = -1
    with open(filename, 'r') as f:
        for line in f:
            if line.startswith(VAL_LIST_MARK):
                val_list_str = line.split(':')[1].strip().lstrip('[').rstrip(']')
                # An empty list is logged as "[]", which leaves an empty
                # string here; skip blank entries instead of calling
                # float('') and crashing.
                val_list = [
                    float(acc)
                    for acc in val_list_str.split(',')
                    if acc.strip()
                ]
            elif line.startswith(BEST_VAL_MARK):
                best_val_acc = float(line.split(':')[1].strip())
            elif line.startswith(LAST_VAL_MARK):
                last_val_acc = float(line.split(':')[1].strip())
    return best_val_acc, last_val_acc, val_list
# Line prefixes recognised by parse_validations_table.
# Prefix of the line from which the hidden size and the seed are read
# (the text after 'size', split on 'with seed').
ADD_PARAMS_MSGS = 'Starting exp'
# Prefix of the line that makes the parser store the values gathered so far
# as one row of the output table.
NEXT_EXPERIMENT_MSG = 'Suggestion:'
# Prefix of the line from which the learning rate is read.
LR_MESSAGE = 'learning rate: '
def parse_accuracy(line):
    '''
    Extract the accuracy value that follows the first ':' of a log line

    Accepted value formats
    - a plain number, e.g. 'Best validation acc: 0.93'
    - a printed tensor, e.g. 'Best validation acc: tensor(0.93)' or
      "Best validation acc: tensor(0.93, device='cuda:0')"

    Parameters
    - line: one line of the log file (a trailing newline is fine)

    Returns the accuracy as a float.

    Raises ValueError if no number can be extracted from the line.
    '''
    # Keep everything after the first colon: a tensor printed with a device
    # contains further colons ('cuda:0') that must not cut the value short.
    value = line.partition(':')[2].strip()
    try:
        return float(value)
    except ValueError:
        pass
    prefix = 'tensor('
    if value.startswith(prefix):
        # The number is the first argument, ending at ',' (more arguments
        # follow) or at ')' (closing parenthesis).
        number = value[len(prefix):].split(',', 1)[0].split(')', 1)[0]
        try:
            return float(number.strip())
        except ValueError:
            pass
    raise ValueError('Cannot parse accuracy from line: {!r}'.format(line))
def parse_validations_table(filename):
    '''
    Parse a log file containing several runs into a table of results and
    keep, for every hidden size and seed, the run with the highest last
    validation accuracy

    While scanning, the most recent learning rate, best/last validation
    accuracy, hidden size and seed are remembered; each line starting with
    NEXT_EXPERIMENT_MSG stores them as one row
    [learning_rate, best_val_acc, last_val_acc, hidden_size, seed].
    Values that have not been seen yet are stored as -1.

    Parameters
    - filename: path of the log file to scan

    Returns a tuple (formatted_output, output, hidden_size)
    - formatted_output: array of shape (n_selected, 5) with one row per
      (hidden size, seed) pair, ordered by hidden size and then seed
    - output: array of shape (n_runs, 5) with one row for every run
    - hidden_size: the hidden size parsed last (-1 if none was found)

    Both arrays have shape (0, 5) when the file contains no complete run.

    NOTE(review): the selection uses column 2 (last validation accuracy),
    not column 1 (best validation accuracy) - confirm this is intended.
    '''
    output = []
    best_val_acc = -1
    last_val_acc = -1
    seed = -1
    hidden_size = -1
    learning_rate = -1
    with open(filename, 'r') as f:
        for line in f:
            if line.startswith(BEST_VAL_MARK):
                best_val_acc = parse_accuracy(line)
            elif line.startswith(LAST_VAL_MARK):
                last_val_acc = parse_accuracy(line)
            elif line.startswith(ADD_PARAMS_MSGS):
                hidden_size, seed = line.split('size')[1].strip().split('with seed')
                hidden_size, seed = float(hidden_size.strip()), float(seed.strip())
            elif line.startswith(LR_MESSAGE):
                learning_rate = float(line.split(':')[1].strip())
            elif line.startswith(NEXT_EXPERIMENT_MSG):
                output.append([learning_rate, best_val_acc, last_val_acc, hidden_size, seed])
    # Force a 2-D table so the column indexing below also works when no run
    # was recorded (np.array([]) alone would be 1-D).
    output = np.array(output).reshape(-1, 5)
    formatted_output = []
    for val in np.unique(output[:, 3]):
        # restrict to only networks of specific size
        new_out = output[output[:, 3] == val, :]
        # Sort by seed (primary key) and last validation accuracy
        # (secondary key): the last row of every seed group is its maximum.
        sidx = np.lexsort(new_out[:, [2, 4]].T)
        sorted_out = new_out[sidx]
        # Group boundaries must be located in the *sorted* rows; looking for
        # them in the unsorted rows is only correct when the log happens to
        # be ordered by seed already.
        group_ends = np.append(
            np.flatnonzero(sorted_out[1:, 4] != sorted_out[:-1, 4]),
            sorted_out.shape[0] - 1
        )
        formatted_output.extend(sorted_out[group_ends])
    return np.array(formatted_output).reshape(-1, 5), output, hidden_size
def plot_formatted_output(formatted_output):
    '''
    Draw one dot per row of the table: learning rate on the x axis, last
    validation accuracy on the y axis, labelled with the hidden size

    Parameters
    - formatted_output: iterable of 5-element rows
      (learning rate, best val acc, last val acc, hidden size, seed)

    The figure is then handed to run_fig_extras together with the axis
    labels, title, log x scale and the target file name.
    '''
    for row in formatted_output:
        # Best accuracy and seed are part of the row but are not plotted.
        rate, _best_acc, final_acc, size, _seed = row
        plt.plot(rate, final_acc, '.', label=str(size))
    figure_options = dict(
        xlabel='Learning rate',
        ylabel='Validation accuracy at end',
        title='Hyperparameter Tuning',
        filename='plots/tuning_cifar.pdf',
        xscale='log',
    )
    run_fig_extras(**figure_options)
if __name__ == '__main__':
    # Job ids of the SLURM output files that are merged into one plot.
    slurm_ids = [202408, 200631]
    formatted_outputs = None
    for job_id in slurm_ids:
        print('Parsing {}'.format(job_id))
        job_table, _, _ = parse_validations_table('slurm-{}.out'.format(job_id))
        if formatted_outputs is None:
            # First file: its table starts the accumulated result.
            formatted_outputs = job_table
        else:
            # Later files: stack their rows below the rows collected so far.
            formatted_outputs = np.append(formatted_outputs, job_table, axis=0)
    plot_formatted_output(formatted_outputs)