-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_loader.py
49 lines (41 loc) · 1.42 KB
/
plot_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines.results_plotter import load_results, ts2xy
def moving_average(values, window):
    """
    Smooth values by doing a simple (unweighted) moving average.

    Only positions where the whole window fits inside ``values`` are kept
    (``np.convolve`` in 'valid' mode), so for ``len(values) >= window`` the
    result has ``len(values) - window + 1`` elements.

    :param values: (array-like) 1-D sequence of numbers to smooth
    :param window: (int) number of consecutive samples averaged per output point
    :return: (numpy array) the smoothed values; empty when ``values`` has
        fewer than ``window`` elements (no full window fits)
    :raises ValueError: if ``window`` is smaller than 1
    """
    if window < 1:
        raise ValueError("window must be >= 1, got {}".format(window))
    values = np.asarray(values, dtype=float)
    if values.size < window:
        # np.convolve swaps its operands when the kernel is longer than the
        # signal, which would yield window - len(values) + 1 bogus points
        # rather than an average. No full window fits, so return nothing.
        return np.empty(0)
    weights = np.repeat(1.0, window) / window
    return np.convolve(values, weights, 'valid')
def plot_results(log_folder, model_name, plt_dir, title='Learning Curve', window=10):
    """
    Plot the smoothed learning curve of a training run and save it to disk.

    Results are read from ``log_folder`` with ``load_results`` and converted
    to (timesteps, reward) series with ``ts2xy``. The rewards are smoothed
    with a moving average, and the curve is written to
    ``<plt_dir>/<model_name>.png`` and ``<plt_dir>/<model_name>.eps``.

    :param log_folder: (str) the save location of the results to plot
    :param model_name: (str) base file name (without extension) for the saved plots
    :param plt_dir: (str) directory the plot files are written into
    :param title: (str) the title of the task to plot
    :param window: (int) moving-average window; clamped to the number of
        data points so that short runs still produce a plot
    :raises ValueError: if no results are found in ``log_folder``
    """
    save_name = os.path.join(plt_dir, model_name)
    x, y = ts2xy(load_results(log_folder), 'timesteps')
    if len(y) == 0:
        raise ValueError("no results found in {!r}".format(log_folder))
    # A window longer than the data would make the smoothed series longer
    # than x and break the plot, so never exceed the number of points.
    y = moving_average(y, window=min(window, len(y)))
    # The 'valid' moving average drops the first (window - 1) points; align x.
    x = x[len(x) - len(y):]
    # plt.figure(title) would re-activate an existing figure with the same
    # label and draw this curve on top of an older one, so always use a fresh
    # figure and close it once saved to avoid accumulating open figures.
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.plot(x, y)
        ax.set_xlabel('Number of Timesteps')
        ax.set_ylabel('Rewards')
        ax.set_title(title + " Smoothed")
        print('Saving plot at:', save_name)
        fig.savefig(save_name + ".png")
        fig.savefig(save_name + ".eps")
    finally:
        plt.close(fig)
    print("plots saved...")
if __name__ == '__main__':
    # Command-line entry point:
    #   python plot_loader.py LOG_FOLDER MODEL_NAME PLT_DIR
    # Any extra arguments are ignored.
    if len(sys.argv) < 4:
        sys.exit("usage: {} LOG_FOLDER MODEL_NAME PLT_DIR".format(
            os.path.basename(sys.argv[0])))
    plot_results(sys.argv[1], sys.argv[2], sys.argv[3])