-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvisualizer.py
103 lines (84 loc) · 3.86 KB
/
visualizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from itertools import cycle

import numpy
from matplotlib import pyplot
from scipy import ndimage
from skimage import filters
class GANDemoVisualizer:
    """Plot 2-D GAN samples next to smoothed density estimates.

    Each call to :meth:`draw` creates a new figure with three panels:
    a scatter of real vs. generated samples, then Gaussian-smoothed 2-D
    histograms of the real and of the generated samples over the open unit
    square (0, 1) x (0, 1).

    Args:
        title: Window title used for every figure this object creates.
        l_kde: Number of histogram bins per axis; the unit square is split
            into an ``l_kde x l_kde`` grid.
        bw_kde: Standard deviation, measured in bins, of the Gaussian used
            to smooth the histogram.
    """

    def __init__(self, title, l_kde=100, bw_kde=5):
        self.title = title
        self.l_kde = l_kde
        self.resolution = 1. / self.l_kde  # width of one histogram bin
        self.bw_kde_ = bw_kde

    def draw(self, real_samples, gen_samples, msg=None, cmap='Blues', pause_time=0.05, max_sample_size=500, show=True):
        """Draw a fresh three-panel figure for the given samples.

        Args:
            real_samples: Array of real points; columns 0 and 1 are x and y.
            gen_samples: Array of generated points with the same layout.
            msg: Optional figure super-title.
            cmap: Matplotlib colormap for the two density panels.
            pause_time: Unused here; kept so the signature matches
                existing callers and the conditional subclass.
            max_sample_size: How many leading points to show in the
                scatter panel.
            show: If true, display the figure with ``pyplot.show()`` and
                close it afterwards. Otherwise the figure stays open (e.g.
                for :meth:`savefig`); note that it is then never closed by
                this object.
        """
        self.fig, self.axes = pyplot.subplots(ncols=3, figsize=(13.5, 4))
        # FigureCanvasBase.set_window_title was deprecated in Matplotlib 3.4
        # and removed in 3.6; the window title belongs to the figure manager.
        manager = self.fig.canvas.manager
        if manager is not None:
            manager.set_window_title(self.title)
        if msg:
            self.fig.suptitle(msg)
        ax0, ax1, ax2 = self.axes
        self.draw_samples(ax0, 'real and generated samples', real_samples, gen_samples, max_sample_size)
        self.draw_density_estimation(ax1, 'density: real samples', real_samples, cmap)
        self.draw_density_estimation(ax2, 'density: generated samples', gen_samples, cmap)
        if show:
            pyplot.show()
            # Close exactly the figure created above, not whatever is current.
            pyplot.close(self.fig)

    @staticmethod
    def draw_samples(axis, title, real_samples, generated_samples, max_sample_size):
        """Scatter up to ``max_sample_size`` generated (dots) and real (black x) points on ``axis``."""
        axis.clear()
        axis.set_xlabel(title)
        axis.plot(generated_samples[:max_sample_size, 0], generated_samples[:max_sample_size, 1], '.')
        axis.plot(real_samples[:max_sample_size, 0], real_samples[:max_sample_size, 1], 'kx')
        axis.axis('equal')
        axis.axis([0, 1, 0, 1])

    def draw_density_estimation(self, axis, title, samples, cmap):
        """Show the smoothed 2-D histogram of ``samples`` on ``axis`` without tick marks."""
        axis.clear()
        axis.set_xlabel(title)
        axis.imshow(self._density_grid(samples), cmap=cmap)
        axis.xaxis.set_major_locator(pyplot.NullLocator())
        axis.yaxis.set_major_locator(pyplot.NullLocator())

    def _density_grid(self, samples):
        """Return the Gaussian-smoothed ``(l_kde, l_kde)`` histogram of ``samples``.

        Only points strictly inside the unit square are counted; points on
        or outside its boundary (and NaNs) are ignored. Row 0 corresponds to
        y close to 1, so ``imshow`` (origin at the top) shows the grid the
        right way up. Columns 0 and 1 of ``samples`` are read as x and y.
        """
        points = numpy.asarray(samples, dtype=float)
        grid = numpy.zeros((self.l_kde, self.l_kde))
        if points.size:
            xs, ys = points[:, 0], points[:, 1]
            inside = (0 < xs) & (xs < 1) & (0 < ys) & (ys < 1)
            last = self.l_kde - 1
            # Clamp the indices: for a tiny y, 1 - y rounds to exactly 1.0 and
            # the bin index would be l_kde, one past the end of the grid.
            rows = numpy.minimum(((1 - ys[inside]) / self.resolution).astype(int), last)
            cols = numpy.minimum((xs[inside] / self.resolution).astype(int), last)
            # add.at accumulates correctly when several points share a bin.
            numpy.add.at(grid, (rows, cols), 1)
        # skimage.filters.gaussian_filter (removed in scikit-image 0.14) wrapped
        # exactly this call with mode='nearest', so results are unchanged.
        return ndimage.gaussian_filter(grid, self.bw_kde_, mode='nearest')

    def savefig(self, filepath):
        """Save the most recently drawn figure to ``filepath``."""
        self.fig.savefig(filepath)

    @staticmethod
    def show():
        """Display all open matplotlib figures."""
        pyplot.show()
class CGANDemoVisualizer(GANDemoVisualizer):
    """Live-updating visualizer for conditional GAN samples.

    Sample rows are laid out as ``[c_0, ..., c_k, x, y]``: each leading
    column is treated as a class indicator (a point is drawn in class
    ``i``'s colour when column ``i`` is positive), and the last two columns
    are the 2-D coordinates. Unlike the base class, the figure is created
    once, on the first :meth:`draw`, and redrawn in place afterwards.
    """

    def __init__(self, title, l_kde=100, bw_kde=5):
        GANDemoVisualizer.__init__(self, title, l_kde, bw_kde)

    def _ensure_figure(self):
        """Create the three-panel figure on first use; ``__init__`` does not create it."""
        if getattr(self, 'fig', None) is None:
            self.fig, self.axes = pyplot.subplots(ncols=3, figsize=(13.5, 4))
            manager = self.fig.canvas.manager
            if manager is not None:
                manager.set_window_title(self.title)

    def draw(self, real_samples, gen_samples, msg=None, cmap='hot', pause_time=0.05, max_sample_size=500, show=True):
        """Redraw all three panels for the given labelled samples.

        Args:
            real_samples: 2-D array of real rows ``[labels..., x, y]``.
            gen_samples: 2-D array of generated rows with the same layout.
            msg: Optional figure super-title.
            cmap: Matplotlib colormap for the two density panels.
            pause_time: Seconds passed to ``pyplot.pause`` so the GUI can
                refresh when ``show`` is true.
            max_sample_size: Maximum number of randomly chosen points per
                set to scatter.
            show: If true, request a redraw and pause so the update is
                displayed without blocking.
        """
        # Previously self.fig/self.axes were never created for this class,
        # so the first draw() raised AttributeError.
        self._ensure_figure()
        if msg:
            self.fig.suptitle(msg)
        ax0, ax1, ax2 = self.axes
        self.draw_samples(ax0, 'real and generated samples', real_samples, gen_samples, max_sample_size)
        # The density panels only need the (x, y) coordinates in the last two columns.
        self.draw_density_estimation(ax1, 'density: real samples', real_samples[:, -2:], cmap)
        self.draw_density_estimation(ax2, 'density: generated samples', gen_samples[:, -2:], cmap)
        if show:
            # Redraw this visualizer's figure rather than whichever is current.
            self.fig.canvas.draw_idle()
            pyplot.pause(pause_time)

    def draw_samples(self, axis, title, real_samples, generated_samples, max_sample_size):
        """Scatter a random subset of samples, coloured by class indicator column.

        Generated points are drawn as dots and real points as crosses, both
        in the colour of each indicator column. Subsets are chosen by
        shuffling copies with numpy's global RNG, so each call advances the
        global random state.
        """
        axis.clear()
        axis.set_xlabel(title)
        g_samples = numpy.copy(generated_samples)
        r_samples = numpy.copy(real_samples)
        numpy.random.shuffle(g_samples)
        numpy.random.shuffle(r_samples)
        g_samples = g_samples[:max_sample_size, :]
        r_samples = r_samples[:max_sample_size, :]
        color_iter = cycle('bgrcmy')
        # One colour per indicator column (every column except the last two).
        for i in range(g_samples.shape[1] - 2):
            c = next(color_iter)
            samples = g_samples[g_samples[:, i] > 0, :][:, -2:]
            axis.plot(samples[:, 0], samples[:, 1], c + '.', markersize=5)
            samples = r_samples[r_samples[:, i] > 0, :][:, -2:]
            axis.plot(samples[:, 0], samples[:, 1], c + 'x', markersize=5)
        axis.axis('equal')
        axis.axis([0, 1, 0, 1])