logger.py
import random
import time
import datetime
import sys
from torch.autograd import Variable
import torch
from visdom import Visdom
import numpy as np


def tensor2image(tensor):
    # Map a [-1, 1] normalized tensor (first image of the batch) to a uint8 image
    image = 127.5*(tensor[0].cpu().float().numpy() + 1.0)
    if image.shape[0] == 1:
        image = np.tile(image, (3, 1, 1))
    return image.astype(np.uint8)


class Logger():
    def __init__(self, n_epochs, batches_epoch):
        self.viz = Visdom()
        self.n_epochs = n_epochs
        self.batches_epoch = batches_epoch
        self.epoch = 1
        self.batch = 1
        self.prev_time = time.time()
        self.mean_period = 0
        self.losses = {}
        self.metrics = {}
        self.loss_windows = {}
        self.image_windows = {}
        self.metric_windows = {}

    def log(self, losses=None, metrics=None, images=None):
        self.mean_period += (time.time() - self.prev_time)
        self.prev_time = time.time()

        sys.stdout.write('\rEpoch %03d/%03d [%04d/%04d] -- ' % (self.epoch, self.n_epochs, self.batch, self.batches_epoch))

        # Accumulate losses and print their running averages for this epoch
        if losses is not None:
            for i, loss_name in enumerate(losses.keys()):
                if loss_name not in self.losses:
                    self.losses[loss_name] = losses[loss_name].data[0]
                else:
                    self.losses[loss_name] += losses[loss_name].data[0]

                if (i+1) == len(losses.keys()):
                    sys.stdout.write('%s: %.4f -- ' % (loss_name, self.losses[loss_name]/self.batch))
                else:
                    sys.stdout.write('%s: %.4f | ' % (loss_name, self.losses[loss_name]/self.batch))

        # Accumulate metrics
        if metrics is not None:
            for metric_name in metrics.keys():
                if metric_name not in self.metrics:
                    self.metrics[metric_name] = metrics[metric_name]
                else:
                    self.metrics[metric_name] += metrics[metric_name]

        batches_done = self.batches_epoch*(self.epoch - 1) + self.batch
        batches_left = self.batches_epoch*(self.n_epochs - self.epoch) + self.batches_epoch - self.batch
        sys.stdout.write('ETA: %s' % (datetime.timedelta(seconds=batches_left*self.mean_period/batches_done)))
        sys.stdout.write(' -- Time in epoch: %s' % (datetime.timedelta(seconds=self.mean_period)))

        # Draw images
        '''
        for image_name, tensor in images.items():
            if image_name not in self.image_windows:
                self.image_windows[image_name] = self.viz.image(tensor2image(tensor.data), opts={'title': image_name})
            else:
                self.viz.image(tensor2image(tensor.data), win=self.image_windows[image_name], opts={'title': image_name})
        '''

        # End of epoch
        if (self.batch % self.batches_epoch) == 0:
            # Plot losses
            for loss_name, loss in self.losses.items():
                if loss_name not in self.loss_windows:
                    self.loss_windows[loss_name] = self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]),
                                                                 opts={'xlabel': 'epochs', 'ylabel': loss_name, 'title': loss_name})
                else:
                    self.viz.line(X=np.array([self.epoch]), Y=np.array([loss/self.batch]), win=self.loss_windows[loss_name], update='append')
                # Reset losses for next epoch
                self.losses[loss_name] = 0.0

            # Plot metrics
            for metric_name, metric in self.metrics.items():
                if metric_name not in self.metric_windows:
                    self.metric_windows[metric_name] = self.viz.line(X=np.array([self.epoch]), Y=np.array([metric/self.batch]),
                                                                     opts={'xlabel': 'epochs', 'ylabel': metric_name, 'title': metric_name})
                else:
                    self.viz.line(X=np.array([self.epoch]), Y=np.array([metric/self.batch]), win=self.metric_windows[metric_name], update='append')
                # Reset metrics for next epoch
                self.metrics[metric_name] = 0.0

            self.epoch += 1
            self.batch = 1
            sys.stdout.write('\n')
        else:
            self.batch += 1
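
A minimal usage sketch (not part of this file), assuming a Visdom server is reachable and the old-style scalar access (`.data[0]`) that this logger uses; the loss and metric names below are hypothetical placeholders:

import torch
from torch.autograd import Variable

from logger import Logger  # the class defined above

# Hypothetical smoke test: drive the Logger with dummy values for two short
# "epochs". Assumes a Visdom server is running (python -m visdom.server).
n_epochs, batches_epoch = 2, 5
logger = Logger(n_epochs, batches_epoch)

for epoch in range(n_epochs):
    for batch in range(batches_epoch):
        dummy_loss = Variable(torch.rand(1))        # stands in for a real training loss
        logger.log(losses={'loss_G': dummy_loss},   # printed as a running average
                   metrics={'psnr': 30.0},          # plain float metric, plotted once per epoch
                   images={})                       # image drawing is commented out above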