"""Implementation of multiheaded FAVOR-attention & FAVOR-self-attention layers.
Prefix Sum Tensorflow implementation by Valerii Likhosherstov.
"""
import math
import numpy as np
import tensorflow as tf
import util
BIG_CONSTANT = 1e8
def create_projection_matrix(m, d, seed=0, scaling=0, struct_mode=False):
  r"""Constructs the matrix of random projections.

  Constructs a matrix of random orthogonal projections. Each projection vector
  has direction chosen uniformly at random and either deterministic length
  \sqrt{d} or length taken from the \chi(d) distribution (in the latter case
  marginal distributions of the projections are d-dimensional Gaussian vectors
  with associated identity covariance matrix).

  Args:
    m: number of random projections.
    d: dimensionality of each random projection.
    seed: random seed used to construct projections.
    scaling: 1 if all the random projections need to be renormalized to have
      length \sqrt{d}, 0 if the lengths of random projections should follow
      \chi(d) distribution.
    struct_mode: if True then products of Givens rotations will be used to
      construct random orthogonal matrix. This bypasses Gram-Schmidt
      orthogonalization.

  Returns:
    The matrix of random projections of the shape [m, d].
  """
  nb_full_blocks = int(m / d)
  block_list = []
  current_seed = seed
  for _ in range(nb_full_blocks):
    if struct_mode:
      q = create_products_of_givens_rotations(d, seed)
    else:
      unstructured_block = tf.random.normal((d, d), seed=current_seed)
      q, _ = tf.linalg.qr(unstructured_block)
      q = tf.transpose(q)
    block_list.append(q)
    current_seed += 1
  remaining_rows = m - nb_full_blocks * d
  if remaining_rows > 0:
    if struct_mode:
      q = create_products_of_givens_rotations(d, seed)
    else:
      unstructured_block = tf.random.normal((d, d), seed=current_seed)
      q, _ = tf.linalg.qr(unstructured_block)
      q = tf.transpose(q)
    block_list.append(q[0:remaining_rows])
  final_matrix = tf.experimental.numpy.vstack(block_list)
  current_seed += 1

  if scaling == 0:
    multiplier = tf.norm(tf.random.normal((m, d), seed=current_seed), axis=1)
  elif scaling == 1:
    multiplier = tf.math.sqrt(float(d)) * tf.ones((m))
  else:
    raise ValueError("Scaling must be one of {0, 1}. Was %s" % scaling)

  return tf.linalg.matmul(tf.linalg.diag(multiplier), final_matrix)
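
# Illustrative usage sketch (not part of the original file): drawing a
# 256 x 64 block-orthogonal projection matrix and checking that rows inside
# one d x d block stay pairwise orthogonal. The dimensions and seed below are
# arbitrary example values.
def _example_create_projection_matrix():
  m, d = 256, 64
  projection = create_projection_matrix(m, d, seed=42, scaling=0)
  # Rows of each d x d block come from a QR decomposition, so their Gram
  # matrix is (numerically) diagonal; off-diagonal entries should be ~0.
  block = projection[:d]
  gram = tf.linalg.matmul(block, block, transpose_b=True)
  off_diag = gram - tf.linalg.diag(tf.linalg.diag_part(gram))
  tf.debugging.assert_near(off_diag, tf.zeros_like(off_diag), atol=1e-2)
  return projection  # shape [m, d] = [256, 64]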
def create_products_of_givens_rotations(dim, seed):
  r"""Constructs a 2D-tensor which is a product of Givens random rotations.

  Constructs a 2D-tensor of the form G_1 * ... * G_k, where G_i is a Givens
  random rotation. The resulting tensor mimics a matrix taken uniformly at
  random from the orthogonal group.

  Args:
    dim: number of rows/columns of the resulting 2D-tensor.
    seed: random seed.

  Returns:
    The product of Givens random rotations.
  """
  nb_givens_rotations = dim * int(math.ceil(math.log(float(dim))))
  q = np.eye(dim, dim)
  np.random.seed(seed)
  for _ in range(nb_givens_rotations):
    random_angle = math.pi * np.random.uniform()
    random_indices = np.random.choice(dim, 2)
    index_i = min(random_indices[0], random_indices[1])
    index_j = max(random_indices[0], random_indices[1])
    slice_i = q[index_i]
    slice_j = q[index_j]
    new_slice_i = math.cos(random_angle) * slice_i + math.sin(
        random_angle) * slice_j
    new_slice_j = -math.sin(random_angle) * slice_i + math.cos(
        random_angle) * slice_j
    q[index_i] = new_slice_i
    q[index_j] = new_slice_j
  return tf.cast(tf.constant(q), dtype=tf.float32)
def relu_kernel_transformation(data,
                               is_query,
                               projection_matrix=None,
                               numerical_stabilizer=0.001):
  """Computes features for the ReLU-kernel.

  Computes random features for the ReLU kernel from
  https://arxiv.org/pdf/2009.14794.pdf.

  Args:
    data: input data tensor of the shape [B, L, H, D], where: B - batch
      dimension, L - attention dimensions, H - heads, D - features.
    is_query: indicates whether input data is a query or key tensor.
    projection_matrix: random Gaussian matrix of shape [M, D], where M stands
      for the number of random features and each D x D sub-block has pairwise
      orthogonal rows.
    numerical_stabilizer: small positive constant for numerical stability.

  Returns:
    Corresponding kernel feature map.
  """
  del is_query
  if projection_matrix is None:
    return tf.nn.relu(data) + numerical_stabilizer
  else:
    ratio = 1.0 / tf.math.sqrt(
        tf.dtypes.cast(projection_matrix.shape[0], tf.float32))
    data_dash = ratio * tf.einsum("blhd,md->blhm", data, projection_matrix)
    return tf.nn.relu(data_dash) + numerical_stabilizer
def softmax_kernel_transformation(data,
                                  is_query,
                                  projection_matrix=None,
                                  numerical_stabilizer=0.000001):
  """Computes random features for the softmax kernel using FAVOR+ mechanism.

  Computes random features for the softmax kernel using FAVOR+ mechanism from
  https://arxiv.org/pdf/2009.14794.pdf.

  Args:
    data: input data tensor of the shape [B, L, H, D], where: B - batch
      dimension, L - attention dimensions, H - heads, D - features.
    is_query: indicates whether input data is a query or key tensor.
    projection_matrix: random Gaussian matrix of shape [M, D], where M stands
      for the number of random features and each D x D sub-block has pairwise
      orthogonal rows.
    numerical_stabilizer: small positive constant for numerical stability.

  Returns:
    Corresponding kernel feature map.
  """
  data_normalizer = 1.0 / (
      tf.math.sqrt(tf.math.sqrt(tf.dtypes.cast(data.shape[-1], tf.float32))))
  data = data_normalizer * data
  ratio = 1.0 / tf.math.sqrt(
      tf.dtypes.cast(projection_matrix.shape[0], tf.float32))
  data_dash = tf.einsum("blhd,md->blhm", data, projection_matrix)
  diag_data = tf.math.square(data)
  diag_data = tf.math.reduce_sum(
      diag_data, axis=tf.keras.backend.ndim(data) - 1)
  diag_data = diag_data / 2.0
  diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1)
  last_dims_t = (len(data_dash.shape) - 1,)
  attention_dims_t = (len(data_dash.shape) - 3,)
  if is_query:
    data_dash = ratio * (
        tf.math.exp(data_dash - diag_data - tf.math.reduce_max(
            data_dash, axis=last_dims_t, keepdims=True)) + numerical_stabilizer)
  else:
    data_dash = ratio * (
        tf.math.exp(data_dash - diag_data - tf.math.reduce_max(
            data_dash, axis=last_dims_t + attention_dims_t, keepdims=True)) +
        numerical_stabilizer)
  return data_dash
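
# Illustrative sketch (not part of the original file): producing softmax-kernel
# random features for toy query/key tensors. With M random features, the inner
# products between query_prime and key_prime approximate the softmax kernel
# exp(q k^T / sqrt(D)) up to normalization, which is what lets favor_attention
# below avoid materializing the L x L attention matrix. All dimensions here are
# example values only.
def _example_softmax_kernel_features():
  batch, length, heads, dim, nb_features = 2, 16, 4, 64, 256
  query = tf.random.normal((batch, length, heads, dim))
  key = tf.random.normal((batch, length, heads, dim))
  projection = create_projection_matrix(nb_features, dim, seed=0)
  query_prime = softmax_kernel_transformation(query, True, projection)
  key_prime = softmax_kernel_transformation(key, False, projection)
  # Both feature maps have shape [B, L, H, M] = [2, 16, 4, 256].
  return query_prime, key_prime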
def noncausal_numerator(qs, ks, vs):
  """Computes not-normalized FAVOR noncausal attention AV.

  Args:
    qs: query_prime tensor of the shape [L,B,H,M].
    ks: key_prime tensor of the shape [L,B,H,M].
    vs: value tensor of the shape [L,B,H,D].

  Returns:
    Not-normalized FAVOR noncausal attention AV.
  """
  kvs = tf.einsum("lbhm,lbhd->bhmd", ks, vs)
  return tf.einsum("lbhm,bhmd->lbhd", qs, kvs)


def noncausal_denominator(qs, ks):
  """Computes FAVOR normalizer in noncausal attention.

  Args:
    qs: query_prime tensor of the shape [L,B,H,M].
    ks: key_prime tensor of the shape [L,B,H,M].

  Returns:
    FAVOR normalizer in noncausal attention.
  """
  all_ones = tf.ones([ks.shape[0]])
  ks_sum = tf.einsum("lbhm,l->bhm", ks, all_ones)
  return tf.einsum("lbhm,bhm->lbh", qs, ks_sum)
def causal_attention_mask(nd, ns, dtype):
  """1's in the lower triangle, counting from the lower right corner.

  Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns - nd), but doesn't
  produce garbage on TPUs.
  """
  i = tf.range(nd)[:, None]
  j = tf.range(ns)
  m = i >= j - ns + nd
  return tf.cast(m, dtype)
@tf.custom_gradient
def causal_numerator(qs, ks, vs):
  """Computes not-normalized FAVOR causal attention A_{masked}V.

  Args:
    qs: query_prime tensor of the shape [L,B,H,M].
    ks: key_prime tensor of the shape [L,B,H,M].
    vs: value tensor of the shape [L,B,H,D].

  Returns:
    Not-normalized FAVOR causal attention A_{masked}V.
  """
  result = []
  sums = tf.zeros_like(tf.einsum("ijk,ijl->ijkl", ks[0], vs[0]))
  for index in range(qs.shape[0]):
    sums = sums + tf.einsum("ijk,ijl->ijkl", ks[index], vs[index])
    result.append(tf.einsum("ijkl,ijk->ijl", sums, qs[index])[None, Ellipsis])
  result = tf.concat(result, axis=0)

  def grad(res_grad):
    grads = tf.zeros_like(tf.einsum("ijk,ijl->ijkl", ks[0], vs[0]))
    gr_sums = sums
    q_grads = []
    k_grads = []
    v_grads = []
    for index in range(qs.shape[0] - 1, -1, -1):
      q_grads.append(
          tf.einsum("ijkl,ijl->ijk", gr_sums, res_grad[index])[None, Ellipsis])
      grads = grads + tf.einsum("ijk,ijl->ijkl", qs[index], res_grad[index])
      k_grads.append(
          tf.einsum("ijkl,ijl->ijk", grads, vs[index])[None, Ellipsis])
      v_grads.append(
          tf.einsum("ijkl,ijk->ijl", grads, ks[index])[None, Ellipsis])
      gr_sums = gr_sums - tf.einsum("ijk,ijl->ijkl", ks[index], vs[index])
    q_grads = tf.concat(q_grads[::-1], axis=0)
    k_grads = tf.concat(k_grads[::-1], axis=0)
    v_grads = tf.concat(v_grads[::-1], axis=0)
    return q_grads, k_grads, v_grads

  return result, grad
@tf.custom_gradient
def causal_denominator(qs, ks):
  """Computes FAVOR normalizer in causal attention.

  Args:
    qs: query_prime tensor of the shape [L,B,H,M].
    ks: key_prime tensor of the shape [L,B,H,M].

  Returns:
    FAVOR normalizer in causal attention.
  """
  result = []
  sums = tf.zeros_like(ks[0])
  for index in range(qs.shape[0]):
    sums = sums + ks[index]
    result.append(tf.reduce_sum(qs[index] * sums, axis=2)[None, Ellipsis])
  result = tf.concat(result, axis=0)

  def grad(res_grad):
    k_grad = tf.zeros_like(ks[0])
    gr_sums = sums
    q_grads = []
    k_grads = []
    for index in range(qs.shape[0] - 1, -1, -1):
      q_grads.append(
          tf.einsum("ijk,ij->ijk", gr_sums, res_grad[index])[None, Ellipsis])
      k_grad = k_grad + tf.einsum("ijk,ij->ijk", qs[index], res_grad[index])
      k_grads.append(k_grad[None, Ellipsis])
      gr_sums = gr_sums - ks[index]
    q_grads = tf.concat(q_grads[::-1], axis=0)
    k_grads = tf.concat(k_grads[::-1], axis=0)
    return q_grads, k_grads

  return result, grad
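
# Illustrative check (not part of the original file): on a tiny example, the
# linear-time prefix-sum path above should agree with an explicit quadratic
# computation that materializes the masked L x L kernel-attention matrix,
# since the prefix sums only reorganize the same arithmetic. All sizes below
# are toy example values.
def _example_check_causal_prefix_sums():
  length, batch, heads, m, d = 5, 1, 1, 3, 2
  qs = tf.nn.relu(tf.random.normal((length, batch, heads, m))) + 1e-3
  ks = tf.nn.relu(tf.random.normal((length, batch, heads, m))) + 1e-3
  vs = tf.random.normal((length, batch, heads, d))
  # Linear-time path: running prefix sums of k'_s v_s^T and k'_s.
  fast = causal_numerator(qs, ks, vs) / tf.expand_dims(
      causal_denominator(qs, ks), -1)
  # Explicit path: build the full score matrix and mask the upper triangle.
  scores = tf.einsum("lbhm,sbhm->bhls", qs, ks)  # [B, H, L, L]
  scores = scores * causal_attention_mask(length, length, scores.dtype)
  numerator = tf.einsum("bhls,sbhd->lbhd", scores, vs)
  denominator = tf.transpose(tf.reduce_sum(scores, axis=-1), [2, 0, 1])
  exact = numerator / tf.expand_dims(denominator, -1)
  tf.debugging.assert_near(fast, exact, rtol=1e-4, atol=1e-4)
  return fast  # [L, B, H, D] = [5, 1, 1, 2]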
def favor_attention(query,
                    key,
                    value,
                    kernel_transformation,
                    causal,
                    projection_matrix=None):
  """Computes FAVOR normalized attention.

  Args:
    query: query tensor.
    key: key tensor.
    value: value tensor.
    kernel_transformation: transformation used to get finite kernel features.
    causal: whether attention is causal or not.
    projection_matrix: projection matrix to be used.

  Returns:
    FAVOR normalized attention.
  """
  query_prime = kernel_transformation(query, True,
                                      projection_matrix)  # [B,L,H,M]
  key_prime = kernel_transformation(key, False, projection_matrix)  # [B,L,H,M]
  query_prime = tf.transpose(query_prime, [1, 0, 2, 3])  # [L,B,H,M]
  key_prime = tf.transpose(key_prime, [1, 0, 2, 3])  # [L,B,H,M]
  value = tf.transpose(value, [1, 0, 2, 3])  # [L,B,H,D]

  if causal:
    av_attention = causal_numerator(query_prime, key_prime, value)
    attention_normalizer = causal_denominator(query_prime, key_prime)
  else:
    av_attention = noncausal_numerator(query_prime, key_prime, value)
    attention_normalizer = noncausal_denominator(query_prime, key_prime)
  # TODO(kchoro): Add more comments.
  av_attention = tf.transpose(av_attention, [1, 0, 2, 3])
  attention_normalizer = tf.transpose(attention_normalizer, [1, 0, 2])
  attention_normalizer = tf.expand_dims(attention_normalizer,
                                        len(attention_normalizer.shape))
  return av_attention / attention_normalizer
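
# Illustrative usage sketch (not part of the original file): calling
# favor_attention directly on toy [B, L, H, D] tensors with the softmax kernel
# and a random projection. All dimensions and the seed are example values.
def _example_favor_attention():
  batch, length, heads, dim, nb_features = 2, 16, 4, 64, 128
  query = tf.random.normal((batch, length, heads, dim))
  key = tf.random.normal((batch, length, heads, dim))
  value = tf.random.normal((batch, length, heads, dim))
  projection = create_projection_matrix(nb_features, dim, seed=7)
  output = favor_attention(query, key, value, softmax_kernel_transformation,
                           causal=False, projection_matrix=projection)
  return output  # shape [B, L, H, D] = [2, 16, 4, 64]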
class Attention(tf.keras.layers.Layer):
  """Multi-headed attention layer."""

  def __init__(self,
               hidden_size,
               num_heads,
               attention_dropout,
               kernel_transformation=relu_kernel_transformation,
               numerical_stabilizer=0.001,
               causal=False,
               projection_matrix_type=None,
               nb_random_features=0,
               **kwargs):
    """Initialize Attention.

    Args:
      hidden_size: int, output dim of hidden layer.
      num_heads: int, number of heads to repeat the same attention structure.
      attention_dropout: float, dropout rate inside attention for training.
      kernel_transformation: transformation used to produce kernel features for
        attention.
      numerical_stabilizer: used to bound away from zero kernel values.
      causal: whether attention is causal or not.
      projection_matrix_type: None if Identity should be used, otherwise random
        projection matrix will be applied.
      nb_random_features: number of random features to be used (relevant only
        if projection_matrix_type is not None).
    """
    if hidden_size % num_heads:
      raise ValueError(
          "Hidden size ({}) must be divisible by the number of heads ({})."
          .format(hidden_size, num_heads))
    super().__init__(**kwargs)
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.attention_dropout = attention_dropout
    self.kernel_transformation = kernel_transformation
    self.numerical_stabilizer = numerical_stabilizer
    self.causal = causal
    self.projection_matrix_type = projection_matrix_type
    self.nb_random_features = nb_random_features

    # def build(self, input_shape):
    #   """Builds the layer."""
    # Layers for linearly projecting the queries, keys, and values.
    size_per_head = self.hidden_size // self.num_heads

    def _glorot_initializer(fan_in, fan_out):
      limit = math.sqrt(6.0 / (fan_in + fan_out))
      return tf.keras.initializers.RandomUniform(minval=-limit, maxval=limit)

    attention_initializer = _glorot_initializer(hidden_size, self.hidden_size)
    self.query_dense_layer = util.DenseEinsum(
        output_shape=(self.num_heads, size_per_head),
        kernel_initializer=attention_initializer,
        use_bias=True,
        name="query")
    self.key_dense_layer = util.DenseEinsum(
        output_shape=(self.num_heads, size_per_head),
        kernel_initializer=attention_initializer,
        use_bias=True,
        name="key")
    self.value_dense_layer = util.DenseEinsum(
        output_shape=(self.num_heads, size_per_head),
        kernel_initializer=attention_initializer,
        use_bias=True,
        name="value")
    output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
    self.output_dense_layer = util.DenseEinsum(
        output_shape=self.hidden_size,
        num_summed_dimensions=2,
        kernel_initializer=output_initializer,
        use_bias=True,
        name="output_transform")
    # super(Attention, self).build(input_shape)

  def get_config(self):
    return {
        "hidden_size": self.hidden_size,
        "num_heads": self.num_heads,
        "attention_dropout": self.attention_dropout,
    }
  def call(self,
           query_input,
           source_input,
           bias=None,
           training=None,
           cache=None,
           decode_loop_step=None):
    """Apply attention mechanism to query_input and source_input.

    Args:
      query_input: A tensor with shape [batch_size, length_query, hidden_size].
      source_input: A tensor with shape [batch_size, length_source,
        hidden_size].
      bias: A tensor with shape [batch_size, 1, length_query, length_source],
        the attention bias that will be added to the result of the dot product.
      training: A bool, whether in training mode or not.
      cache: (Used during prediction) A dictionary with tensors containing
        results of previous attentions. The dictionary must have the items:
            {"k": tensor with shape [batch_size, i, heads, dim_per_head],
             "v": tensor with shape [batch_size, i, heads, dim_per_head]}
        where i is the current decoded length for non-padded decode, or max
        sequence length for padded decode.
      decode_loop_step: An integer, step number of the decoding loop. Used only
        for autoregressive inference on TPU.

    Returns:
      Attention layer output with shape [batch_size, length_query, hidden_size]
    """
    # Linearly project the query, key and value using different learned
    # projections. Splitting heads is automatically done during the linear
    # projections --> [batch_size, length, num_heads, dim_per_head].
    query = self.query_dense_layer(query_input)
    key = self.key_dense_layer(source_input)
    value = self.value_dense_layer(source_input)

    if self.projection_matrix_type is None:
      projection_matrix = None
    else:
      dim = query.shape[-1]
      seed = 0
      projection_matrix = create_projection_matrix(
          self.nb_random_features, dim, seed=seed)

    if cache is not None:
      # Combine cached keys and values with new keys and values.
      if decode_loop_step is not None:
        cache_k_shape = cache["k"].shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
            [1, cache_k_shape[1], 1, 1])
        key = cache["k"] + key * indices
        cache_v_shape = cache["v"].shape.as_list()
        indices = tf.reshape(
            tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
            [1, cache_v_shape[1], 1, 1])
        value = cache["v"] + value * indices
      else:
        key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
        value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)

      # Update cache.
      cache["k"] = key
      cache["v"] = value

    attention_output = favor_attention(query, key, value,
                                       self.kernel_transformation, self.causal,
                                       projection_matrix)
    attention_output = self.output_dense_layer(attention_output)
    return attention_output
class SelfAttention(Attention):
  """Multiheaded self-attention layer."""

  def call(self,
           query_input,
           bias=None,
           training=None,
           cache=None,
           decode_loop_step=None):
    return super(SelfAttention, self).call(query_input, query_input, bias,
                                           training, cache, decode_loop_step)
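
# Illustrative usage sketch (not part of the original file): constructing the
# SelfAttention layer on toy inputs. This assumes the companion `util` module
# providing DenseEinsum (imported at the top of this file) is available;
# hidden_size, num_heads, sequence length, and the feature count below are
# example values only.
def _example_self_attention_layer():
  layer = SelfAttention(
      hidden_size=256,
      num_heads=4,
      attention_dropout=0.1,
      kernel_transformation=softmax_kernel_transformation,
      causal=True,
      projection_matrix_type=True,
      nb_random_features=128)
  inputs = tf.random.normal((2, 32, 256))  # [batch_size, length, hidden_size]
  outputs = layer(inputs, training=False)
  return outputs  # shape [batch_size, length, hidden_size] = [2, 32, 256]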