This repository has been archived by the owner on Sep 13, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathplot_stats.py
66 lines (54 loc) · 2.44 KB
/
plot_stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
def _load_pickle(path):
    """Unpickle and return the single object stored in the file at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file; only point this script at experiment folders you produced yourself.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _split_losses(records):
    """Split loss records into (generator x100, D-real, D-fake) lists.

    Each record is indexed as ``record[0]`` (generator loss), ``record[1]``
    (discriminator loss on real samples) and ``record[2]`` (discriminator loss
    on fake samples), matching the legend labels used when plotting. The
    generator loss is multiplied by 100 so it is visible on the same axes as
    the discriminator losses (hence the "(x100)" label).
    """
    g_loss = [record[0] * 100 for record in records]
    d_loss_real = [record[1] for record in records]
    d_loss_fake = [record[2] for record in records]
    return g_loss, d_loss_real, d_loss_fake


def _epoch_axis(count, interval):
    """Return x positions for ``count`` evaluations spaced ``interval`` epochs apart (1-based)."""
    return [interval * i + 1 for i in range(count)]


def _plot_loss_panel(ax, title, records, epoch_interval):
    """Draw one loss panel (generator x100, D-real, D-fake) with a legend onto ``ax``."""
    g_loss, d_loss_real, d_loss_fake = _split_losses(records)
    # x positions are derived from this series' own length, so train and val
    # histories of different lengths no longer break ``plot``.
    epochs = _epoch_axis(len(records), epoch_interval)

    ax.set_title(title)
    ax.plot(epochs, g_loss, label='Custom RRIN (x100)')
    ax.plot(epochs, d_loss_real, label='D - real')
    ax.plot(epochs, d_loss_fake, label='D - fake')

    # Ticks every 5 epochs across the plotted epoch range (not the number of
    # evaluation points), so they stay aligned with the data when
    # eval_every > 1. If no tick would fall in range, keep matplotlib's default
    # ticks instead of setting an empty list, which would hide the axis ticks.
    if epochs:
        ticks = np.arange(5, epochs[-1] + 1, 5)
        if ticks.size:
            ax.set_xticks(ticks)
    ax.legend()


def plot_stats(exp_dir, save):
    """
    Plots the training progress graph.

    Reads ``stats.pickle`` (keys ``'train_loss'`` and ``'val_loss'``) and
    ``hyperparams.pickle`` (keys ``'eval_every'``, ``'lr'``, ``'batch_size'``)
    from ``exp_dir``. Prints the hyperparameters, then shows a two-panel
    figure: train losses on the left, validation losses on the right.

    Args:
        exp_dir (string): Folder containing the pickle files for both the training statistics and hyperparameter settings.
        save (bool): Whether to save the graph automatically or not. When True the
            figure is written to ``./plot_<lr>_<batch_size>.png`` in the current
            working directory before being shown.
    """
    stats = _load_pickle(os.path.join(exp_dir, 'stats.pickle'))
    hyperparams = _load_pickle(os.path.join(exp_dir, 'hyperparams.pickle'))

    # prints the experiment's hyperparameter settings
    print("Experiment settings:\n{}".format(hyperparams))

    # Number of epochs between consecutive recorded evaluations.
    epoch_interval = hyperparams['eval_every']

    _, axes = plt.subplots(1, 2)
    _plot_loss_panel(axes[0], 'Train Loss vs Epoch', stats['train_loss'], epoch_interval)
    _plot_loss_panel(axes[1], 'Val Loss vs Epoch', stats['val_loss'], epoch_interval)

    if save:
        # Must run before show(): closing the interactive window discards the figure.
        plt.savefig('./plot_{}_{}.png'.format(hyperparams['lr'], hyperparams['batch_size']))
    plt.show()
if __name__ == '__main__':
    # Command-line entry point: parse the experiment folder and the optional
    # save flag, then hand both straight to plot_stats.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--exp_dir',
        required=True,
        help="path to any experiment folder containing the 2 pickle files from the training process",
    )
    cli.add_argument(
        '--save',
        action='store_true',
        help="specify if you want to autosave the graph",
    )
    options = cli.parse_args()
    plot_stats(options.exp_dir, options.save)