-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlayers.py
338 lines (276 loc) · 11.9 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
'''
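@Description: Custom tf.keras layers for a Transformer encoder (token embedding, sinusoidal position encoding, scaled dot-product and multi-head attention, position-wise feed-forward, layer normalization, add) plus an IMDB sentiment demo.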
@version:
@License: MIT
@Author: Wang Yao
@Date: 2020-03-22 17:48:05
@LastEditors: Wang Yao
@LastEditTime: 2020-03-26 18:35:10
'''
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Layer
class Embedding(Layer):
    """Token-embedding lookup scaled by sqrt(model_dim).

    Args:
        vocab_size: number of rows in the embedding table; input values are
            used directly as row indices.
        model_dim: width of each embedding vector.

    ``call`` returns embeddings of shape ``inputs.shape + (model_dim,)``.
    """

    def __init__(self, vocab_size, model_dim, **kwargs):
        self._vocab_size = vocab_size
        self._model_dim = model_dim
        super(Embedding, self).__init__(**kwargs)

    def build(self, input_shape):
        self.embeddings = self.add_weight(
            shape=(self._vocab_size, self._model_dim),
            initializer='glorot_uniform',
            name="embeddings")
        super(Embedding, self).build(input_shape)

    def call(self, inputs):
        # K.gather needs integer indices; inputs may arrive as floats.
        if K.dtype(inputs) != 'int32':
            inputs = K.cast(inputs, 'int32')
        embeddings = K.gather(self.embeddings, inputs)
        # Transformer convention: scale embeddings by sqrt(model_dim).
        embeddings *= self._model_dim ** 0.5
        return embeddings

    def compute_output_shape(self, input_shape):
        # tuple() accepts tuples, lists and TensorShapes alike.
        return tuple(input_shape) + (self._model_dim,)

    def get_config(self):
        """Return constructor arguments so the layer can be serialized."""
        config = super(Embedding, self).get_config()
        config.update({
            'vocab_size': self._vocab_size,
            'model_dim': self._model_dim,
        })
        return config
class PositionEncoding(Layer):
    """Sinusoidal position-encoding table sized to the input's sequence axis.

    Only ``inputs.shape[1]`` is read; the input values themselves are ignored.
    The result is a float32 tensor of shape ``(seq_length, model_dim)`` with
    no batch axis, meant to be added to embeddings via broadcasting.
    Column ``2i`` holds ``sin(pos / 10000**(2i / model_dim))`` and column
    ``2i+1`` the matching ``cos``.

    Args:
        model_dim: number of encoding features per position.
    """

    def __init__(self, model_dim, **kwargs):
        self._model_dim = model_dim
        super(PositionEncoding, self).__init__(**kwargs)

    def call(self, inputs):
        # Needs a static sequence length (inputs.shape[1] must not be None).
        seq_length = int(inputs.shape[1])
        positions = np.arange(seq_length)[:, np.newaxis]   # (seq_length, 1)
        dims = np.arange(self._model_dim)[np.newaxis, :]   # (1, model_dim)
        # Columns 2i and 2i+1 share the exponent 2i / model_dim.
        angles = positions / np.power(
            10000, (dims - dims % 2) / self._model_dim)
        angles[:, 0::2] = np.sin(angles[:, 0::2])  # 2i
        angles[:, 1::2] = np.cos(angles[:, 1::2])  # 2i+1
        return K.cast(angles, 'float32')

    def compute_output_shape(self, input_shape):
        # Matches what `call` actually returns: no batch axis.
        return (input_shape[1], self._model_dim)

    def get_config(self):
        """Return constructor arguments so the layer can be serialized."""
        config = super(PositionEncoding, self).get_config()
        config.update({'model_dim': self._model_dim})
        return config
class ScaledDotProductAttention(Layer):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        masking: if True, ``call`` expects ``[queries, keys, values, masks]``
            and adds ``masks * self._masking_num`` to the attention scores,
            so key positions where ``masks`` is non-zero are suppressed.
        future: if True, scores above the diagonal of each score matrix are
            replaced with ``self._masking_num`` (causal masking).
        dropout_rate: dropout rate applied to the softmax attention weights,
            in the training phase only.
    """

    def __init__(self, masking=True, future=False, dropout_rate=0., **kwargs):
        self._masking = masking
        self._future = future
        self._dropout_rate = dropout_rate
        # Large negative value: masked scores go to ~0 after the softmax.
        self._masking_num = -2**32+1
        super(ScaledDotProductAttention, self).__init__(**kwargs)

    def mask(self, inputs, masks):
        """Add ``masks * self._masking_num`` to the score tensor ``inputs``.

        ``masks`` is rank 2. It is tiled along axis 0 until it matches
        ``inputs``' leading dimension; MultiHeadAttention in this file stacks
        heads along that axis. It is then broadcast over the query axis.
        """
        masks = K.cast(masks, 'float32')
        masks = K.tile(masks, [K.shape(inputs)[0] // K.shape(masks)[0], 1])
        masks = K.expand_dims(masks, 1)
        outputs = inputs + masks * self._masking_num
        return outputs

    def future_mask(self, inputs):
        """Replace scores above the diagonal with ``self._masking_num``."""
        diag_vals = tf.ones_like(inputs[0, :, :])
        tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()
        future_masks = tf.tile(tf.expand_dims(tril, 0), [
            tf.shape(inputs)[0], 1, 1])
        paddings = tf.ones_like(future_masks) * self._masking_num
        outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs)
        return outputs

    def call(self, inputs, training=None):
        """Attend ``queries`` over ``keys`` and return weighted ``values``.

        ``training`` selects whether attention dropout is applied. When it is
        None, Keras fills it from the enclosing call or the learning phase.
        """
        if self._masking:
            assert len(
                inputs) == 4, "inputs should be set [queries, keys, values, masks]."
            queries, keys, values, masks = inputs
        else:
            assert len(
                inputs) == 3, "inputs should be set [queries, keys, values]."
            queries, keys, values = inputs
        if K.dtype(queries) != 'float32':
            queries = K.cast(queries, 'float32')
        if K.dtype(keys) != 'float32':
            keys = K.cast(keys, 'float32')
        if K.dtype(values) != 'float32':
            values = K.cast(values, 'float32')
        matmul = K.batch_dot(queries, tf.transpose(keys, [0, 2, 1]))  # MatMul
        scaled_matmul = matmul / int(queries.shape[-1]) ** 0.5  # Scale
        if self._masking:
            scaled_matmul = self.mask(scaled_matmul, masks)  # Mask(opt.)
        if self._future:
            scaled_matmul = self.future_mask(scaled_matmul)
        softmax_out = K.softmax(scaled_matmul)  # SoftMax
        # Dropout only in the training phase; evaluate/predict stay deterministic.
        out = softmax_out
        if self._dropout_rate > 0.:
            out = K.in_train_phase(
                lambda: K.dropout(softmax_out, self._dropout_rate),
                softmax_out,
                training=training)
        outputs = K.batch_dot(out, values)
        return outputs

    def compute_output_shape(self, input_shape):
        # Output: queries' shape with the last axis replaced by values' width.
        return tuple(input_shape[0][:-1]) + (input_shape[2][-1],)

    def get_config(self):
        """Return constructor arguments so the layer can be serialized."""
        config = super(ScaledDotProductAttention, self).get_config()
        config.update({
            'masking': self._masking,
            'future': self._future,
            'dropout_rate': self._dropout_rate,
        })
        return config
class MultiHeadAttention(Layer):
    """Multi-head attention built on ScaledDotProductAttention.

    Queries, keys and values are each projected (no bias) to
    ``n_heads * head_dim`` features. The projections are split into
    ``n_heads`` chunks on the feature axis and stacked along the batch axis.
    They are attended jointly, and the heads are then concatenated back on
    the feature axis.

    Args:
        n_heads: number of attention heads.
        head_dim: projection width per head.
        dropout_rate: dropout rate passed to ScaledDotProductAttention.
        masking: if True, ``call`` expects ``[queries, keys, values, masks]``.
        future: if True, apply causal (future) masking.
        trainable: whether the projection weights are trainable.
    """

    def __init__(self, n_heads, head_dim, dropout_rate=.1, masking=True,
                 future=False, trainable=True, **kwargs):
        self._n_heads = n_heads
        self._head_dim = head_dim
        self._dropout_rate = dropout_rate
        self._masking = masking
        self._future = future
        # `trainable` must go to Layer.__init__: it assigns self._trainable
        # itself, so a value stored here beforehand would be overwritten.
        super(MultiHeadAttention, self).__init__(trainable=trainable, **kwargs)
        # Weight-free sub-layer, created once instead of on every call.
        self._scaled_attention = ScaledDotProductAttention(
            masking=masking, future=future, dropout_rate=dropout_rate)

    def build(self, input_shape):
        projection_dim = self._n_heads * self._head_dim
        self._weights_queries = self.add_weight(
            shape=(input_shape[0][-1], projection_dim),
            initializer='glorot_uniform',
            trainable=self.trainable,
            name='weights_queries')
        self._weights_keys = self.add_weight(
            shape=(input_shape[1][-1], projection_dim),
            initializer='glorot_uniform',
            trainable=self.trainable,
            name='weights_keys')
        self._weights_values = self.add_weight(
            shape=(input_shape[2][-1], projection_dim),
            initializer='glorot_uniform',
            trainable=self.trainable,
            name='weights_values')
        super(MultiHeadAttention, self).build(input_shape)

    def _split_heads(self, tensor):
        """Split the feature axis (axis 2) into heads stacked on axis 0."""
        return tf.concat(tf.split(tensor, self._n_heads, axis=2), axis=0)

    def _merge_heads(self, tensor):
        """Inverse of _split_heads: move heads from axis 0 back to axis 2."""
        return tf.concat(tf.split(tensor, self._n_heads, axis=0), axis=2)

    def call(self, inputs):
        if self._masking:
            assert len(
                inputs) == 4, "inputs should be set [queries, keys, values, masks]."
            queries, keys, values, masks = inputs
        else:
            assert len(
                inputs) == 3, "inputs should be set [queries, keys, values]."
            queries, keys, values = inputs
        # Linear projections, e.g. (bs, 100, 256) x (256, 512) -> (bs, 100, 512).
        queries_linear = K.dot(queries, self._weights_queries)
        keys_linear = K.dot(keys, self._weights_keys)
        values_linear = K.dot(values, self._weights_values)
        att_inputs = [self._split_heads(queries_linear),
                      self._split_heads(keys_linear),
                      self._split_heads(values_linear)]
        if self._masking:
            att_inputs.append(masks)
        att_out = self._scaled_attention(att_inputs)
        return self._merge_heads(att_out)

    def compute_output_shape(self, input_shape):
        # Output: queries' shape with the last axis = n_heads * head_dim.
        return tuple(input_shape[0][:-1]) + (self._n_heads * self._head_dim,)

    def get_config(self):
        """Return constructor arguments so the layer can be serialized."""
        config = super(MultiHeadAttention, self).get_config()
        config.update({
            'n_heads': self._n_heads,
            'head_dim': self._head_dim,
            'dropout_rate': self._dropout_rate,
            'masking': self._masking,
            'future': self._future,
        })
        return config
class PositionWiseFeedForward(Layer):
    """Position-wise feed-forward network: relu(x W1 + b1) W2 + b2.

    K.dot contracts only the last axis of ``inputs``, so the same weights are
    applied independently at every position.

    Args:
        model_dim: output width.
        inner_dim: hidden (inner) width.
        trainable: whether the layer's weights are trainable.
    """

    def __init__(self, model_dim, inner_dim, trainable=True, **kwargs):
        self._model_dim = model_dim
        self._inner_dim = inner_dim
        # `trainable` must go to Layer.__init__: it assigns self._trainable
        # itself, so a value stored here beforehand would be overwritten.
        super(PositionWiseFeedForward, self).__init__(
            trainable=trainable, **kwargs)

    def build(self, input_shape):
        self.weights_inner = self.add_weight(
            shape=(input_shape[-1], self._inner_dim),
            initializer='glorot_uniform',
            trainable=self.trainable,
            name="weights_inner")
        self.weights_out = self.add_weight(
            shape=(self._inner_dim, self._model_dim),
            initializer='glorot_uniform',
            trainable=self.trainable,
            name="weights_out")
        # The "bais" spelling is kept: it is part of attribute and weight
        # names that other code or saved weights may refer to.
        self.bais_inner = self.add_weight(
            shape=(self._inner_dim,),
            initializer='uniform',
            trainable=self.trainable,
            name="bais_inner")
        self.bais_out = self.add_weight(
            shape=(self._model_dim,),
            initializer='uniform',
            trainable=self.trainable,
            name="bais_out")
        super(PositionWiseFeedForward, self).build(input_shape)

    def call(self, inputs):
        if K.dtype(inputs) != 'float32':
            inputs = K.cast(inputs, 'float32')
        inner_out = K.relu(K.dot(inputs, self.weights_inner) + self.bais_inner)
        outputs = K.dot(inner_out, self.weights_out) + self.bais_out
        return outputs

    def compute_output_shape(self, input_shape):
        # A full shape, not just the last-axis width.
        return tuple(input_shape[:-1]) + (self._model_dim,)

    def get_config(self):
        """Return constructor arguments so the layer can be serialized."""
        config = super(PositionWiseFeedForward, self).get_config()
        config.update({
            'model_dim': self._model_dim,
            'inner_dim': self._inner_dim,
        })
        return config
class LayerNormalization(Layer):
    """Layer normalization over the last axis with a learned scale and shift.

    outputs = gamma * (x - mean) / sqrt(variance + epsilon) + beta, where
    mean and variance are computed over the last axis of ``inputs``.

    Args:
        epsilon: added to the variance before the square root, for numerical
            stability.
    """

    def __init__(self, epsilon=1e-8, **kwargs):
        self._epsilon = epsilon
        super(LayerNormalization, self).__init__(**kwargs)

    def build(self, input_shape):
        # One beta/gamma entry per last-axis feature. With beta=0 and
        # gamma=1, the affine step starts out as the identity.
        self.beta = self.add_weight(
            shape=(input_shape[-1],),
            initializer='zero',
            name='beta')
        self.gamma = self.add_weight(
            shape=(input_shape[-1],),
            initializer='one',
            name='gamma')
        super(LayerNormalization, self).build(input_shape)

    def call(self, inputs):
        mean, variance = tf.nn.moments(inputs, [-1], keepdims=True)
        normalized = (inputs - mean) / ((variance + self._epsilon) ** 0.5)
        outputs = self.gamma * normalized + self.beta
        return outputs

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        """Return constructor arguments so the layer can be serialized."""
        config = super(LayerNormalization, self).get_config()
        config.update({'epsilon': self._epsilon})
        return config
class Add(Layer):
    """Sum exactly two inputs with ``+`` (TensorFlow broadcasting applies)."""

    def __init__(self, **kwargs):
        super(Add, self).__init__(**kwargs)

    def call(self, inputs):
        # Unpacking enforces exactly two inputs.
        left, right = inputs
        summed = left + right
        return summed

    def compute_output_shape(self, input_shape):
        # Reported shape is that of the first input.
        first_shape = input_shape[0]
        return first_shape
if __name__ == "__main__":
    # IMDB sentiment demo: embeddings + position encodings -> one masked
    # multi-head attention layer -> pooled dense classifier.
    from tensorflow.keras.models import Model
    from tensorflow.keras.layers import Input, Dense, Dropout, GlobalAveragePooling1D
    from tensorflow.keras.optimizers import Adam
    from tensorflow.keras.callbacks import EarlyStopping
    from tensorflow.keras.datasets import imdb
    from tensorflow.keras.preprocessing import sequence
    from tensorflow.keras.utils import to_categorical

    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    vocab_size = 5000
    max_len = 256
    model_dim = 512
    batch_size = 128
    epochs = 10

    print("Data downloading and pre-processing ... ")
    (x_train, y_train), (x_test, y_test) = imdb.load_data(
        maxlen=max_len, num_words=vocab_size)
    x_train = sequence.pad_sequences(x_train, maxlen=max_len)
    x_test = sequence.pad_sequences(x_test, maxlen=max_len)
    # 1.0 where the padded sequence holds 0 (pad_sequences' default fill
    # value), 0.0 elsewhere. These stay float32 NumPy arrays: they match the
    # float32 `masks` Input and are sliced by validation_split like x_train,
    # instead of becoming whole-dataset boolean tf.Tensors.
    x_train_masks = (x_train == 0).astype('float32')
    x_test_masks = (x_test == 0).astype('float32')
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)

    print('Model building ... ')
    inputs = Input(shape=(max_len,), name="inputs")
    masks = Input(shape=(max_len,), name='masks')
    embeddings = Embedding(vocab_size, model_dim)(inputs)
    encodings = PositionEncoding(model_dim)(embeddings)
    encodings = Add()([embeddings, encodings])
    x = MultiHeadAttention(8, 64)([encodings, encodings, encodings, masks])
    x = GlobalAveragePooling1D()(x)
    x = Dropout(0.2)(x)
    x = Dense(10, activation='relu')(x)
    outputs = Dense(2, activation='softmax')(x)

    model = Model(inputs=[inputs, masks], outputs=outputs)
    model.compile(optimizer=Adam(beta_1=0.9, beta_2=0.98, epsilon=1e-9),
                  loss='categorical_crossentropy', metrics=['accuracy'])

    print("Model Training ... ")
    es = EarlyStopping(patience=5)
    model.fit([x_train, x_train_masks], y_train,
              batch_size=batch_size, epochs=epochs, validation_split=0.2,
              callbacks=[es])

    test_metrics = model.evaluate(
        [x_test, x_test_masks], y_test, batch_size=batch_size, verbose=0)
    print("loss on Test: %.4f" % test_metrics[0])
    print("accu on Test: %.4f" % test_metrics[1])