# Copyright 2019 Stanislav Pidhorskyi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
from torch import nn
from torch.nn import functional as F


class VAE(nn.Module):
    def __init__(self, zsize, layer_count=3, channels=3):
        super(VAE, self).__init__()

        d = 128  # base channel width
        self.d = d
        self.zsize = zsize
        self.layer_count = layer_count

        # Encoder: layer_count strided conv blocks, each halving the spatial size
        # and doubling the channel count (d, 2d, 4d, ...).
        mul = 1
        inputs = channels
        for i in range(self.layer_count):
            setattr(self, "conv%d" % (i + 1), nn.Conv2d(inputs, d * mul, 4, 2, 1))
            setattr(self, "conv%d_bn" % (i + 1), nn.BatchNorm2d(d * mul))
            inputs = d * mul
            mul *= 2

        self.d_max = inputs

        # Heads producing the latent Gaussian parameters (mu and logvar).
        self.fc1 = nn.Linear(inputs * 4 * 4, zsize)
        self.fc2 = nn.Linear(inputs * 4 * 4, zsize)

        # Decoder: project the latent code back to a d_max x 4 x 4 feature map,
        # then mirror the encoder with transposed convolutions.
        self.d1 = nn.Linear(zsize, inputs * 4 * 4)

        mul = inputs // d // 2
        for i in range(1, self.layer_count):
            setattr(self, "deconv%d" % (i + 1), nn.ConvTranspose2d(inputs, d * mul, 4, 2, 1))
            setattr(self, "deconv%d_bn" % (i + 1), nn.BatchNorm2d(d * mul))
            inputs = d * mul
            mul //= 2

        setattr(self, "deconv%d" % (self.layer_count + 1), nn.ConvTranspose2d(inputs, channels, 4, 2, 1))
    def encode(self, x):
        # Run the convolutional stack, then flatten and map to mu and logvar.
        for i in range(self.layer_count):
            x = F.relu(getattr(self, "conv%d_bn" % (i + 1))(getattr(self, "conv%d" % (i + 1))(x)))

        x = x.view(x.shape[0], self.d_max * 4 * 4)
        h1 = self.fc1(x)  # mu
        h2 = self.fc2(x)  # logvar
        return h1, h2
    def reparameterize(self, mu, logvar):
        # Reparameterization trick: sample z = mu + std * eps during training,
        # return the mean deterministically at eval time.
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu
    def decode(self, x):
        x = x.view(x.shape[0], self.zsize)
        x = self.d1(x)
        x = x.view(x.shape[0], self.d_max, 4, 4)
        #x = self.deconv1_bn(x)
        x = F.leaky_relu(x, 0.2)

        for i in range(1, self.layer_count):
            x = F.leaky_relu(getattr(self, "deconv%d_bn" % (i + 1))(getattr(self, "deconv%d" % (i + 1))(x)), 0.2)

        # torch.tanh replaces the deprecated F.tanh; output lies in [-1, 1].
        x = torch.tanh(getattr(self, "deconv%d" % (self.layer_count + 1))(x))
        return x
    def forward(self, x):
        mu, logvar = self.encode(x)
        mu = mu.squeeze()
        logvar = logvar.squeeze()
        z = self.reparameterize(mu, logvar)
        return self.decode(z.view(-1, self.zsize, 1, 1)), mu, logvar
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)


def normal_init(m, mean, std):
    # Initialize conv / transposed-conv weights from N(mean, std), zero the biases.
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()
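

# ------------------------------------------------------------------------------
# Minimal usage sketch (not part of the original net.py): with the default
# layer_count=3 the encoder reduces the input to a 4x4 feature map, so 32x32
# inputs are assumed here; zsize=32, batch size 8 and std=0.02 are arbitrary
# illustrative choices.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    vae = VAE(zsize=32, layer_count=3, channels=3)
    vae.weight_init(mean=0.0, std=0.02)

    x = torch.randn(8, 3, 32, 32)       # batch of 8 RGB images, roughly in [-1, 1]
    recon, mu, logvar = vae(x)

    print(recon.shape)   # torch.Size([8, 3, 32, 32])
    print(mu.shape)      # torch.Size([8, 32])
    print(logvar.shape)  # torch.Size([8, 32])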