-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
49 lines (41 loc) · 1.62 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import tensorflow as tf
from mask_conv import MaskConv2D
from residual_block import ResidualBlock
class PixelCNN(tf.keras.Model):
    """PixelCNN stack: masked 7x7 conv, residual blocks, two masked 1x1 convs.

    The final layer emits ``input_shape[-1] * 3 * n_mixtures`` channels per
    pixel. These are intended as mixture-of-logistics parameters
    (pi_i, mu_i, s_i): three per input channel per mixture component.

    Args:
        input_shape: Shape of one input image. Only the last entry (the
            channel count) is used here, to size the output layer.
        hidden_size: Base width ``h``. The first conv emits ``2 * h``
            filters and the penultimate conv emits ``4 * h``.
        n_residual_blocks: Number of ``ResidualBlock`` layers applied
            between the first conv and the output head.
        color_conditioning: Forwarded unchanged to every ``MaskConv2D`` and
            ``ResidualBlock``. Presumably toggles channel-wise (R -> G -> B)
            masking; confirm against ``mask_conv.py``.
        n_mixtures: Number of logistic mixture components per channel.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self,
                 input_shape=(32, 32, 3),
                 hidden_size=120,
                 n_residual_blocks=10,
                 color_conditioning=True,
                 n_mixtures=10,
                 **kwargs):
        # Forward standard Keras options (name, dtype, ...) rather than
        # silently discarding them.
        super().__init__(**kwargs)
        n_channels = input_shape[-1]

        # Mask type 'A' on the first layer. Presumably this excludes the
        # current pixel, so the model never sees the value it predicts;
        # verify in MaskConv2D.
        self.convA = MaskConv2D(mask_type='A',
                                color_conditioning=color_conditioning,
                                filters=2 * hidden_size,
                                kernel_size=(7, 7),
                                activation='relu')
        self.res_blocks = [
            ResidualBlock(hidden_size, color_conditioning)
            for _ in range(n_residual_blocks)
        ]
        # Output head. Activations are applied explicitly in call(), so
        # both 1x1 convs are linear.
        self.convB_1 = MaskConv2D(mask_type='B',
                                  color_conditioning=color_conditioning,
                                  filters=4 * hidden_size,
                                  kernel_size=(1, 1),
                                  activation=None)
        # One (pi_i, mu_i, s_i) triple per input channel per mixture
        # component.
        self.convB_2 = MaskConv2D(mask_type='B',
                                  color_conditioning=color_conditioning,
                                  filters=n_channels * 3 * n_mixtures,
                                  kernel_size=(1, 1),
                                  activation=None)

    def call(self, inputs):
        """Run the network and return raw (unactivated) mixture parameters.

        Pipeline: convA (ReLU) -> residual blocks -> ReLU -> convB_1 ->
        ReLU -> convB_2. The output's last axis has
        ``input_shape[-1] * 3 * n_mixtures`` entries. Spatial dimensions
        depend on MaskConv2D's padding, which is not visible here; presumably
        'same', so the output matches ``inputs`` (TODO confirm).
        """
        x = self.convA(inputs)
        for block in self.res_blocks:
            x = block(x)
        x = self.convB_1(tf.nn.relu(x))
        return self.convB_2(tf.nn.relu(x))