-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathae_basic.py
executable file
·166 lines (135 loc) · 5.78 KB
/
ae_basic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os

import numpy as np
from matplotlib import pyplot as plt

from autoencoder import AutoEncoder
from stacked_mnist import DataMode, StackedMNISTData
from verification_net import VerificationNet
class AEBasic:
    """
    Autoencoder basic task: reconstructing images.

    Trains an ``AutoEncoder`` on single-channel StackedMNIST images, then
    reconstructs the test set (one channel at a time for colour data),
    scores the reconstructions with ``VerificationNet`` and plots a sample.
    """

    def __init__(self,
                 latent_dim=4,
                 three_colors=False,
                 save_weigths=False,
                 save_image=False) -> None:
        """
        Args:
            latent_dim: Latent dimension passed to ``AutoEncoder``.
            three_colors: Use colour (stacked) data instead of mono data.
            save_weigths: (sic) Forwarded to ``AutoEncoder.train`` as its
                ``save_weights`` argument.
            save_image: Save the result figure under ``./results/`` in
                ``show_figure``.
        """
        self.autoencoder = AutoEncoder(latent_dim)
        self.three_colors = three_colors
        self.save_weigths = save_weigths
        self.save_image = save_image
        self.gen = self.get_generator(self.three_colors)
        self.ver_net = VerificationNet()

    def get_generator(self, three_colors):
        """
        Return the StackedMNISTData generator for the requested colour mode.

        ``three_colors`` selects the stacked (colour) data set; otherwise
        standard single-channel (mono) MNIST is used.  Both use a default
        batch size of 2048.
        """
        mode = (DataMode.COLOR_BINARY_COMPLETE if three_colors
                else DataMode.MONO_BINARY_COMPLETE)
        return StackedMNISTData(mode=mode, default_batch_size=2048)

    def get_train_test(self, gen):
        """
        Return ``(x_train, y_train, x_test, y_test)`` loaded from ``gen``.
        """
        x_train, y_train = gen.get_full_data_set(training=True)
        x_test, y_test = gen.get_full_data_set(training=False)
        return x_train, y_train, x_test, y_test

    def train_autoencoder(self, data=None, epochs=20, batch_size=64):
        """
        Train the autoencoder to reproduce single-channel images.

        Only channel 0 is used for training; ``run`` later applies the same
        network to every colour channel separately.

        Args:
            data: Optional ``(x_train, y_train, x_test, y_test)`` tuple as
                returned by ``get_train_test``.  Loaded from ``self.gen``
                when omitted.
            epochs: Number of training epochs.
            batch_size: Training batch size.
        """
        if data is None:
            data = self.get_train_test(self.gen)
        x_train, _, x_test, _ = data
        # Index with the list [0] (not the int 0) so the channel axis is kept.
        x_train = x_train[:, :, :, [0]]
        x_test = x_test[:, :, :, [0]]
        # Inputs double as targets: the AE learns to reproduce its input.
        self.autoencoder.train(x_train,
                               x_train,
                               batch_size=batch_size,
                               epochs=epochs,
                               shuffle=True,
                               validation_data=(x_test, x_test),
                               verbose=True,
                               save_weights=self.save_weigths)

    def _reconstruct(self, images):
        """
        Pass each channel of ``images`` through the autoencoder separately.

        Returns the per-channel reconstructions concatenated along the last
        axis, so the result has one channel per input channel.
        """
        channels = []
        for i in range(images.shape[-1]):
            channel = images[:, :, :, [i]]
            encoded = self.autoencoder.encoder(channel).numpy()
            decoded = self.autoencoder.decoder(encoded).numpy()
            # Assumes one output value per input pixel (the AE is trained
            # with its input as target).  Reshaping, unlike np.squeeze,
            # keeps the batch axis even when there is a single image.
            channels.append(decoded.reshape(channel.shape))
        return np.concatenate(channels, axis=-1)

    def run(self):
        """
        Train the autoencoder, reconstruct the test set and show the results.

        Reconstructions are scored with VerificationNet (colour data is
        checked with ``tolerance=0.5``), the scores are printed, and the
        first 10 test images are plotted via ``show_figure``.
        """
        # Load the data once and share it with training.
        data = self.get_train_test(self.gen)
        self.train_autoencoder(data)
        _, _, x_test, y_test = data

        if self.three_colors:
            originals = x_test
            reconstructed = self._reconstruct(x_test[:, :, :, :3])
            pred, acc = self.ver_net.check_predictability(reconstructed,
                                                          y_test,
                                                          tolerance=0.5)
        else:
            originals = x_test[:, :, :, [0]]
            reconstructed = self._reconstruct(originals)
            pred, acc = self.ver_net.check_predictability(reconstructed,
                                                          y_test)
        # NOTE: rounding the reconstructions with np.around() was suggested
        # in an earlier revision as a way to improve accuracy.
        print(f"Predictability: {pred}, accuracy: {acc}")
        self.show_figure(10, originals, reconstructed, y_test, pred, acc)

    def show_figure(self, n, original, reconstructed, y_test, predictability,
                    accuracy):
        """
        Plot ``n`` original images above their reconstructions.

        Args:
            n: Number of image pairs to show.  Clamped to the number of
                images available.
            original: Original images, indexed along the first axis.
            reconstructed: Reconstructions aligned with ``original``.
            y_test: Labels, shown as titles above the originals.
            predictability: Score shown in the figure title.
            accuracy: Score shown in the figure title.

        Raises:
            ValueError: If there are no images to plot.

        The figure is saved under ``./results/`` when ``self.save_image`` is
        set, and is then shown.
        """
        n = min(n, len(original), len(reconstructed))
        if n < 1:
            raise ValueError("no images to show")

        # squeeze=False keeps a 2-D axes grid even when n == 1.
        fig, axes = plt.subplots(2, n, figsize=(20, 4), squeeze=False)
        for col in range(n):
            top, bottom = axes[0, col], axes[1, col]
            # Float cast: matplotlib reads integer RGB data on a 0-255 scale
            # and float data on 0-1.
            top.imshow(original[col].astype(np.float64), cmap="gray")
            top.set_title(f"Class {y_test[col]}")
            bottom.imshow(reconstructed[col], cmap="gray")
            bottom.set_title("Reconstruct")
            for ax in (top, bottom):
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
        fig.suptitle(f"{n} images reconstructed"
                     f" (Predictability: {predictability}"
                     f", Accuracy: {accuracy})",
                     fontsize="x-large")

        if self.save_image:
            path = ("./results/ae-basic-color" if self.three_colors
                    else "./results/ae-basic-mono")
            # savefig does not create missing directories.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            fig.savefig(path)
        plt.show()
if __name__ == "__main__":
    # Script entry: mono data, nothing written to disk.
    ae_basic = AEBasic(
        three_colors=False,
        save_weigths=False,
        save_image=False,
    )
    ae_basic.run()