forked from atriumlts/subpixel
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkeras_subpixel.py
140 lines (119 loc) · 5.49 KB
/
keras_subpixel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from keras import backend as K
from keras.layers import Conv2D
"""
Subpixel Layer as a child class of Conv2D. This layer accepts all normal
arguments, with the exception of dilation_rate(). The argument r indicates
the upsampling factor, which is applied to the normal output of Conv2D.
The output of this layer will have the same number of channels as the
indicated filter field, and thus works for grayscale, color, or as a a
hidden layer.
Arguments:
*see Keras Docs for Conv2D args, noting that dilation_rate() is removed*
r: upscaling factor, which is applied to the output of normal Conv2D
A test is included, which performs super-resolution on the Cifar10 dataset.
Since these images are small, only a scale factor of 2 is used. Test images
are saved in the directory 'test_output/'. This test runs for 5 epochs,
which can be altered in line 132. You can run this test by using the
following commands:
mkdir test_output
python keras_subpixel.py
"""
class Subpixel(Conv2D):
def __init__(self,
filters,
kernel_size,
r,
padding='valid',
data_format=None,
strides=(1,1),
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(Subpixel, self).__init__(
filters=r*r*filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
activation=activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
activity_regularizer=activity_regularizer,
kernel_constraint=kernel_constraint,
bias_constraint=bias_constraint,
**kwargs)
self.r = r
def _phase_shift(self, I):
r = self.r
bsize, a, b, c = I.get_shape().as_list()
bsize = K.shape(I)[0] # Handling Dimension(None) type for undefined batch dim
X = K.reshape(I, [bsize, a, b, c/(r*r),r, r]) # bsize, a, b, c/(r*r), r, r
X = K.permute_dimensions(X, (0, 1, 2, 5, 4, 3)) # bsize, a, b, r, r, c/(r*r)
#Keras backend does not support tf.split, so in future versions this could be nicer
X = [X[:,i,:,:,:,:] for i in range(a)] # a, [bsize, b, r, r, c/(r*r)
X = K.concatenate(X, 2) # bsize, b, a*r, r, c/(r*r)
X = [X[:,i,:,:,:] for i in range(b)] # b, [bsize, r, r, c/(r*r)
X = K.concatenate(X, 2) # bsize, a*r, b*r, c/(r*r)
return X
def call(self, inputs):
return self._phase_shift(super(Subpixel, self).call(inputs))
def compute_output_shape(self, input_shape):
unshifted = super(Subpixel, self).compute_output_shape(input_shape)
return (unshifted[0], self.r*unshifted[1], self.r*unshifted[2], unshifted[3]/(self.r*self.r))
def get_config(self):
config = super(Conv2D, self).get_config()
config.pop('rank')
config.pop('dilation_rate')
config['filters']/=self.r*self.r
config['r'] = self.r
return config
if __name__ == "__main__":
import keras
from keras.models import Model
from keras.layers import Input, Conv2D, Activation
from keras.optimizers import Adam
from keras.losses import mean_squared_error
from keras.datasets import cifar10
import skimage.io
from skimage.transform import pyramid_reduce
import numpy as np
#Downloading dataset, downscaling, padding
(HDimages, ignore), (test_images, ignore) = cifar10.load_data()
downscaled=np.zeros((HDimages.shape[0], 16, 16, 3))
downscaled_test = np.zeros((test_images.shape[0], 16, 16, 3))
for i, image in enumerate(HDimages):
downscaled[i,:,:,:] = pyramid_reduce(image, 2)
for i, image in enumerate(test_images):
downscaled_test[i,:,:,:] = pyramid_reduce(image, 2)
pad = 3
padded = np.zeros((downscaled.shape[0], downscaled.shape[1]+2*pad, downscaled.shape[2]+2*pad, downscaled.shape[3]))
padded_test = np.zeros((downscaled_test.shape[0], downscaled_test.shape[1]+2*pad, downscaled_test.shape[2]+2*pad, downscaled_test.shape[3]))
for i, image in enumerate(downscaled):
padded[i, pad:-1*(pad), pad:-1*(pad), :] = image
for i, image in enumerate(downscaled_test):
padded_test[i, pad:-1*(pad), pad:-1*(pad), :] = image
#Nework architecture, including 2x subpixel layer at end
shape = padded.shape
inputs=Input(shape[1:4])
x = Conv2D(32, (3,3), activation='relu')(inputs)
x = Conv2D(32, (3,3), activation='relu')(x)
x = Subpixel(3, (3,3), 2, activation='relu')(x)
model = Model(inputs=inputs, outputs=x)
model.compile(optimizer = Adam(), loss = mean_squared_error)
model.fit(padded, HDimages, epochs=5, validation_split=.2)
result = model.predict(padded_test)
print "Upscaled from (%d, %d) to (%d, %d)"%(downscaled.shape[1], downscaled.shape[2], result.shape[1], result.shape[2])
for i in range(100):
skimage.io.imsave("test_output/%04d_downscaled.jpeg"%(i), downscaled_test[i])
skimage.io.imsave("test_output/%04d_prediction.jpeg"%(i), result[i].astype("uint8"))
skimage.io.imsave("test_output/%04d_original.jpeg"%(i), test_images[i].astype("uint8"))