-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlog.py
149 lines (116 loc) · 4.86 KB
/
log.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import csv
import os

import numpy as np
# Directory holding the training log-files.
log_dir = '/Users/songyanho/Developer/DeepTraffic/log'

# Each log needs its own file: reward lines carry 2 value columns and
# Q-value lines carry 4, so sharing one file would make both unreadable.
log_reward_path = os.path.join(log_dir, 'log_reward.txt')
log_q_values_path = os.path.join(log_dir, 'log_q_values.txt')
class Log:
    """
    Base-class for logging data to a text-file during training.
    It is possible to use TensorFlow / TensorBoard for this,
    but it is quite awkward to implement, as it was intended
    for logging variables and other aspects of the TensorFlow graph.
    We want to log the reward and Q-values which are not in that graph.

    Each line of the log-file is tab-separated:
        count_episodes <TAB> count_states <TAB> value_1 <TAB> value_2 ...
    """

    def __init__(self, file_path):
        """Set the path for the log-file. Nothing is saved or loaded yet."""
        # Path for the log-file.
        self.file_path = file_path

        # Data to be read from the log-file by the _read() function.
        self.count_episodes = None
        self.count_states = None
        self.data = None

    def _write(self, count_episodes, count_states, msg):
        """
        Append a line to the log-file. This is only called by sub-classes.

        :param count_episodes:
            Counter for the number of episodes processed during training.
        :param count_states:
            Counter for the number of states processed during training.
        :param msg:
            Tab-separated values written after the two counters.
        """
        # Create the parent directory if needed, so the first write does not
        # fail with FileNotFoundError when the log directory is missing.
        dir_name = os.path.dirname(self.file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        with open(file=self.file_path, mode='a', buffering=1) as file:
            msg_annotated = "{0}\t{1}\t{2}\n".format(count_episodes, count_states, msg)
            file.write(msg_annotated)

    def _read(self):
        """
        Read the log-file into memory so it can be plotted.

        Sets self.count_episodes and self.count_states (tuples of strings,
        one entry per logged line) and self.data (a float array with one
        row per logged value column).
        """
        # newline='' is what the csv module expects for files it parses.
        with open(self.file_path, newline='') as f:
            reader = csv.reader(f, delimiter="\t")
            # Skip blank lines: csv yields [] for them, and a single empty
            # row would make zip() below truncate every column to length 0.
            rows = [row for row in reader if row]

        # Transpose lines into columns. The first two columns are the
        # counters; all remaining columns are logged values (2 for rewards,
        # 4 for Q-values), so they must be collected with a starred target.
        self.count_episodes, self.count_states, *data = zip(*rows)

        # Convert the remaining log-data to a NumPy float-array.
        self.data = np.array(data, dtype='float')
class LogReward(Log):
    """Log the rewards obtained for episodes during training."""

    def __init__(self, file_path=None):
        """
        :param file_path:
            Path of the log-file. When omitted, the module-level
            log_reward_path is used.
        """
        # These will be set in read() below.
        self.episode = None
        self.mean = None

        # Resolve the default at call time (not definition time) so the
        # module-level path can still be changed before construction.
        if file_path is None:
            file_path = log_reward_path

        # Super-class init.
        Log.__init__(self, file_path=file_path)

    def write(self, count_episodes, count_states, reward_episode, reward_mean):
        """
        Write the episode and mean reward to file.

        :param count_episodes:
            Counter for the number of episodes processed during training.
        :param count_states:
            Counter for the number of states processed during training.
        :param reward_episode:
            Reward for one episode.
        :param reward_mean:
            Mean reward for the last e.g. 30 episodes.
        """
        msg = "{0:.1f}\t{1:.1f}".format(reward_episode, reward_mean)
        self._write(count_episodes=count_episodes, count_states=count_states, msg=msg)

    def read(self):
        """
        Read the log-file into memory so it can be plotted.
        It sets self.count_episodes, self.count_states, self.episode and self.mean.
        """
        # Read the log-file using the super-class.
        self._read()

        # First value column: the episode reward.
        self.episode = self.data[0]

        # Second value column: the mean reward.
        self.mean = self.data[1]
class LogQValues(Log):
    """Log the Q-Values during training."""

    def __init__(self, file_path=None):
        """
        :param file_path:
            Path of the log-file. When omitted, the module-level
            log_q_values_path is used.
        """
        # These will be set in read() below.
        self.min = None
        self.mean = None
        self.max = None
        self.std = None

        # Resolve the default at call time (not definition time) so the
        # module-level path can still be changed before construction.
        if file_path is None:
            file_path = log_q_values_path

        # Super-class init.
        Log.__init__(self, file_path=file_path)

    def write(self, count_episodes, count_states, q_values):
        """
        Write basic statistics for the Q-values to file.

        :param count_episodes:
            Counter for the number of episodes processed during training.
        :param count_states:
            Counter for the number of states processed during training.
        :param q_values:
            Numpy array with Q-values from the replay-memory.
        """
        msg = "{0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}".format(np.min(q_values),
                                                          np.mean(q_values),
                                                          np.max(q_values),
                                                          np.std(q_values))

        self._write(count_episodes=count_episodes,
                    count_states=count_states,
                    msg=msg)

    def read(self):
        """
        Read the log-file into memory so it can be plotted.
        It sets self.count_episodes, self.count_states, self.min / mean / max / std.
        """
        # Read the log-file using the super-class.
        self._read()

        # Value columns are written in the order min, mean, max, std.
        self.min = self.data[0]
        self.mean = self.data[1]
        self.max = self.data[2]
        self.std = self.data[3]