-
Notifications
You must be signed in to change notification settings - Fork 67
/
Copy pathttt.py
1649 lines (1409 loc) · 68.8 KB
/
ttt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils._pytree import tree_map
from transformers import PretrainedConfig
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging
from transformers.utils.import_utils import is_causal_conv1d_available
# Optional fused CUDA kernels for the causal depthwise convolutions used by `Conv` and by the
# shared-QK path of `TTTBase.get_qkv_projections`. When the package is unavailable both names
# are set to None and those call sites fall back to a pure-PyTorch `nn.Conv1d` path.
if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_update, causal_conv1d_fn = None, None
logger = logging.get_logger(__name__)
# Reference model sizes, keyed by approximate parameter count. Each entry only lists the
# width/depth fields; presumably it is passed as overrides to `TTTConfig` (confirm at call site).
TTT_STANDARD_CONFIGS = {
    size_name: dict(
        hidden_size=hidden,
        intermediate_size=ffn,
        num_hidden_layers=layers,
        num_attention_heads=heads,
    )
    for size_name, hidden, ffn, layers, heads in (
        ("125m", 768, 2048, 12, 12),
        ("350m", 1024, 2736, 24, 16),
        ("760m", 1536, 4096, 24, 16),
        ("1b", 2048, 5504, 24, 32),
    )
}
class TTTConfig(PretrainedConfig):
    r"""
    Configuration class for a [`TTTModel`]. It stores the arguments that define the model architecture.
    Instantiating a configuration with the defaults yields a configuration similar to TTT-1B (the same
    width/depth as the `"1b"` entry of `TTT_STANDARD_CONFIGS`).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information. Extra keyword arguments are forwarded to
    [`PretrainedConfig`].

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the TTT model, i.e. the number of different tokens that can be represented by the
            `input_ids` passed to the model.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 5504):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of heads in each TTT layer. The per-head dimension is `hidden_size // num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) used in the MLP blocks.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `False`):
            Whether or not the model should use and return a decoding cache (see `TTTCache`).
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 1):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            End of stream token id.
        pretraining_tp (`int`, *optional*, defaults to 1):
            Experimental feature. Tensor parallelism rank used during pretraining. When greater than 1 the MLP
            projections are evaluated slice by slice. Please refer to [this
            document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
            understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
            results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie the input and output word embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        use_gate (`bool`, *optional*, defaults to `False`):
            Whether to gate the TTT output with a learned projection followed by a tanh-approximate GELU, as in
            the Mamba backbone.
        share_qk (`bool`, *optional*, defaults to `False`):
            Whether to share the Q/K projection matrix. Q and K are then obtained from the shared projection
            through two separate causal depthwise convolutions.
        ttt_layer_type (`str`, *optional*, defaults to `"linear"`):
            TTT block type, `"linear"` or `"mlp"`, standing for TTT-Linear and TTT-MLP.
        ttt_base_lr (`float`, *optional*, defaults to 1.0):
            Base learning rate for the TTT learner; it multiplies the learned, token-dependent learning rate.
        mini_batch_size (`int`, *optional*, defaults to 16):
            Number of tokens in each TTT mini-batch. It is also used as `max_position_embeddings` of the
            per-layer rotary embedding.
        pre_conv (`bool`, *optional*, defaults to `False`):
            Whether to apply a causal convolution block before the TTT layer.
        conv_kernel (`int`, *optional*, defaults to 4):
            Kernel size of the causal convolution layers.
        scan_checkpoint_group_size (`int`, *optional*, defaults to 0):
            Gradient-checkpointing setting for the scan over mini-batches along the sequence dimension.
            0 means no checkpointing. The JAX implementation sets it to 4, grouping mini-batches into
            gradient checkpoints to save memory. NOTE(review): `scan` splits its steps into segments of
            `num_items // checkpoint_group` steps; confirm which semantics the caller intends.

    ```python
    >>> from . import TTTModel, TTTConfig

    >>> # Initializing a TTT ttt-1b style configuration
    >>> configuration = TTTConfig()

    >>> # Initializing a model from the ttt-1b style configuration
    >>> model = TTTModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "ttt"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=2048,
        intermediate_size=5504,
        num_hidden_layers=24,
        num_attention_heads=32,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=False,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        use_gate=False,
        share_qk=False,
        ttt_layer_type="linear",
        ttt_base_lr=1.0,
        mini_batch_size=16,
        pre_conv=False,
        conv_kernel=4,
        scan_checkpoint_group_size=0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta

        self.use_gate = use_gate
        self.share_qk = share_qk
        self.ttt_layer_type = ttt_layer_type
        self.ttt_base_lr = ttt_base_lr
        self.mini_batch_size = mini_batch_size

        self.pre_conv = pre_conv
        self.conv_kernel = conv_kernel
        self.scan_checkpoint_group_size = scan_checkpoint_group_size

        # Token ids and embedding tying are handled by PretrainedConfig itself.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
########################
### Backbone Modules ###
########################
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the (former) second half.

    For a last dimension ``[a, b]`` split at ``size // 2`` this returns ``[-b, a]``.
    """
    half = x.shape[-1] // 2
    lower, upper = x[..., :half], x[..., half:]
    return torch.cat((upper.neg(), lower), dim=-1)
def permute_qk(q, k):
    """Reorder head_dim of ``q`` and ``k`` from interleaved pairs to two contiguous halves.

    EasyLM and transformers lay out rotary dimensions differently; this manual reordering matches
    the JAX implementation (it may not be optimal for speed). Both tensors are reshaped using
    ``q``'s ``[bsz, num_head, seq_len, head_dim]`` shape. Reference:
    https://github.com/young-geng/EasyLM/blob/981a2ed9630f44258a94b6f44dff2b7bd203ae8d/EasyLM/models/llama/convert_hf_to_easylm.py#L33
    """
    bsz, num_head, seq_len, head_dim = q.shape
    as_pairs = (bsz, num_head, seq_len, head_dim // 2, 2)
    flat = (bsz, num_head, seq_len, head_dim)
    q_out, k_out = (t.reshape(as_pairs).transpose(3, 4).reshape(flat) for t in (q, k))
    return q_out, k_out
def undo_permute_qk(q, k):
    """Inverse of ``permute_qk``: restore interleaved pairs from the two-halves head_dim layout.

    EasyLM and transformers lay out rotary dimensions differently; this manual reordering matches
    the JAX implementation (it may not be optimal for speed). Both tensors are reshaped using
    ``q``'s ``[bsz, num_head, seq_len, head_dim]`` shape. Reference:
    https://github.com/young-geng/EasyLM/blob/981a2ed9630f44258a94b6f44dff2b7bd203ae8d/EasyLM/models/llama/convert_hf_to_easylm.py#L33
    """
    bsz, num_head, seq_len, head_dim = q.shape
    as_halves = (bsz, num_head, seq_len, 2, head_dim // 2)
    flat = (bsz, num_head, seq_len, head_dim)
    q_out, k_out = (t.reshape(as_halves).transpose(3, 4).reshape(flat) for t in (q, k))
    return q_out, k_out
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with precomputed RoPE cos/sin tables.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos` and `sin` so they broadcast against `q` and `k`. With cos/sin of shape
            [batch_size, seq_len, head_dim], use 1 for q/k shaped [batch_size, heads, seq_len, head_dim] and
            2 for q/k shaped [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, in that order.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    q_embed = _rotate(q)
    k_embed = _rotate(k)
    return q_embed, k_embed
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable per-feature scale.

    The statistics are computed in float32. The normalized values are cast back to the input dtype
    before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(as_fp32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        normalized = (as_fp32 * inv_rms).to(original_dtype)
        return self.weight * normalized
class SwiGluMLP(nn.Module):
    """SwiGLU feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    When ``config.pretraining_tp > 1`` the projections are evaluated slice by slice along the
    intermediate dimension, and the partial ``down_proj`` outputs are summed. This mirrors the
    tensor-parallel pretraining numerics (see ``TTTConfig.pretraining_tp``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the MLP to ``x`` (any leading shape, last dim ``hidden_size``)."""
        tp = self.config.pretraining_tp
        if tp > 1:
            slice_size = self.intermediate_size // tp
            gate_proj_slices = self.gate_proj.weight.split(slice_size, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice_size, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice_size, dim=1)

            gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(tp)], dim=-1)
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(tp)], dim=-1)

            # Split on the same (last) dim the slices were concatenated on. The previous
            # hard-coded dim=2 was only correct for 3-D [B, L, C] inputs.
            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice_size, dim=-1)
            down_proj = sum(F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(tp))
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
class RotaryEmbedding(nn.Module):
    """Produce RoPE cos/sin tables for the given positions.

    ``inv_freq[i] = 1 / base ** (2i / dim)``. ``forward`` returns ``(cos, sin)`` of shape
    ``[batch, seq_len, dim]``, cast to ``x``'s dtype. ``scaling_factor`` and
    ``max_position_embeddings`` are stored but not used by ``forward``.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=16,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
        self.register_buffer("inv_freq", 1.0 / (self.base ** exponents), persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]; only its device and dtype are read.
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        positions = position_ids[:, None, :].float()

        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            angles = torch.matmul(inv_freq.float(), positions.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos, sin = emb.cos(), emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Conv(nn.Module):
    """RMSNorm followed by a causal depthwise 1-D convolution along the sequence axis.

    Input and output are ``[B, L, C]`` with ``C == config.hidden_size``. When ``cache_params`` is
    given, the last ``conv_kernel`` input steps are stored in
    ``cache_params.conv_states_dic["pre_conv"][layer_idx]``, so that steps with
    ``seqlen_offset > 0`` can be computed incrementally from that window. The fused
    ``causal_conv1d`` kernels are used when available; otherwise a pure-PyTorch path runs.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Depthwise (groups == channels). Padding of K-1 is applied on both sides; forward keeps
        # only the first seq_len outputs, which makes the convolution causal.
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            bias=True,
            kernel_size=config.conv_kernel,
            groups=config.hidden_size,
            padding=config.conv_kernel - 1,
        )

    def forward(self, hidden_states, cache_params=None):
        """Normalize and convolve ``hidden_states``.

        Implemented as ``forward`` (previously an overridden ``__call__``), so that
        ``nn.Module.__call__`` dispatches here and module hooks keep working. Call sites
        (``conv(hidden_states, cache_params=...)``) are unchanged.

        Args:
            hidden_states: ``[B, L, C]`` input.
            cache_params: optional ``TTTCache``. When its ``seqlen_offset > 0`` only the first
                position of the input is consumed (single-token decode step).

        Returns:
            ``[B, L, C]`` tensor.
        """
        seq_len = hidden_states.shape[1]
        hidden_states = self.norm(hidden_states)
        # [B, C, L]
        hidden_states = hidden_states.transpose(1, 2)

        if causal_conv1d_fn is None:
            if cache_params is not None:
                if cache_params.seqlen_offset > 0:
                    # Decode: shift the window left by one and append the newest step.
                    conv_state = cache_params.conv_states_dic["pre_conv"][self.layer_idx]
                    conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
                    conv_state[:, :, -1] = hidden_states[:, :, 0]
                    cache_params.conv_states_dic["pre_conv"][self.layer_idx].copy_(conv_state)
                    # Depthwise conv over the window == per-channel dot product with the kernel.
                    hidden_states = torch.sum(conv_state * self.conv.weight[:, 0, :], dim=-1)
                    hidden_states += self.conv.bias
                    hidden_states = hidden_states.unsqueeze(-1)
                else:
                    # Prefill: cache the last conv_kernel steps. They are left zero-padded when the
                    # input is shorter; a negative pad amount crops when it is longer.
                    conv_state = nn.functional.pad(
                        hidden_states,
                        (self.config.conv_kernel - hidden_states.shape[-1], 0),
                    )
                    cache_params.conv_states_dic["pre_conv"][self.layer_idx].copy_(conv_state)
                    hidden_states = self.conv(hidden_states)[..., :seq_len]
            else:
                hidden_states = self.conv(hidden_states)[..., :seq_len]
        else:
            # Fused kernels take the depthwise weight as [C, K].
            conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2))
            if cache_params is not None and cache_params.seqlen_offset > 0:
                hidden_states = causal_conv1d_update(
                    hidden_states.squeeze(-1),
                    cache_params.conv_states_dic["pre_conv"][self.layer_idx],
                    conv_weights,
                    self.conv.bias,
                    None,
                )
                hidden_states = hidden_states.unsqueeze(-1)
            else:
                if cache_params is not None:
                    conv_states = nn.functional.pad(
                        hidden_states,
                        (self.config.conv_kernel - hidden_states.shape[-1], 0),
                    )
                    cache_params.conv_states_dic["pre_conv"][self.layer_idx].copy_(conv_states)
                hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv.bias, activation=None)

        # [B, L, C]
        hidden_states = hidden_states.transpose(1, 2)

        return hidden_states
#########################
### TTT Layer Modules ###
#########################
def scan(f, init, xs, out, checkpoint_group=0):
    """Mimic ``jax.lax.scan``: thread a carry through ``f`` over the leading axis of ``xs``.

    Args:
        f: callable ``(carry, x) -> (carry, y)``.
        init: initial carry.
        xs: either a dict mapping names to sequences, or a sequence of sequences. All are indexed
            along their first axis. The per-step input is ``{key: value[i]}`` for a dict, and
            ``[value[i] for value in xs]`` otherwise.
        out: pre-allocated mutable sequence; ``out[i]`` receives the ``y`` of step ``i``.
        checkpoint_group: if > 0, the steps are run in consecutive segments of
            ``num_items // checkpoint_group`` steps (at least 1 step), each segment under
            ``torch.utils.checkpoint.checkpoint``. 0 disables checkpointing.

    Returns:
        ``(carry, out)``: the final carry and the filled ``out``.
    """
    carry = init
    is_dict = isinstance(xs, dict)
    if is_dict:
        num_items = len(next(iter(xs.values())))
    else:
        num_items = len(xs[0])

    def scan_fn(carry, i_start, i_end):
        for i in range(i_start, i_end):
            if is_dict:
                x = {key: tensor[i] for key, tensor in xs.items()}
            else:
                x = [seq[i] for seq in xs]
            carry, y = f(carry, x)
            out[i] = y
        return carry

    if checkpoint_group > 0:
        # Clamp to 1: with fewer items than checkpoint groups the integer division gives 0,
        # and range() rejects a zero step with ValueError.
        ckpt_every_n = max(num_items // checkpoint_group, 1)
        for k in range(0, num_items, ckpt_every_n):
            carry = torch.utils.checkpoint.checkpoint(
                scan_fn, carry, k, min(k + ckpt_every_n, num_items), use_reentrant=False
            )
    else:
        carry = scan_fn(carry, 0, num_items)

    return carry, out
def ln_fwd(x, gamma, beta, eps=1e-6):
    """Batch forward for LayerNorm over the last dim: ``gamma * (x - mean) / sqrt(var + eps) + beta``.

    Uses the biased (population) variance, like ``torch.nn.LayerNorm``.
    """
    centered = x - x.mean(dim=-1, keepdim=True)
    std = (x.var(dim=-1, keepdim=True, unbiased=False) + eps).sqrt()
    normalized = centered / std
    return gamma * normalized + beta
def ln_fused_l2_bwd(x, l2_target, gamma, beta, eps=1e-6):
    """Batch backward for LayerNorm fused with an L2 loss.

    Recomputes ``y = gamma * LN(x) + beta`` (the same math as ``ln_fwd``) and returns the gradient
    w.r.t. ``x`` of ``0.5 * ((y - l2_target) ** 2).sum(-1)``. Each row along the last dimension is
    handled independently; no reduction is taken over the other dimensions.

    Args:
        x: LayerNorm input, normalized over its last dim (size ``D``).
        l2_target: regression target for ``y``.
        gamma: LayerNorm scale.
        beta: LayerNorm shift (only needed to recompute ``y``).
        eps: variance epsilon.

    Returns:
        Gradient w.r.t. ``x``; shaped like ``x`` assuming gamma/beta/l2_target only broadcast against it.
    """
    D = x.shape[-1]

    # Mean and variance computation
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)

    # Normalization
    std = torch.sqrt(var + eps)
    x_hat = (x - mu) / std

    # Scale and shift
    y = gamma * x_hat + beta

    # d/dy of 0.5 * ||y - target||^2
    grad_output = y - l2_target
    # chain rule through the affine scale: dL/dx_hat
    grad_x_hat = grad_output * gamma
    # Closed-form LayerNorm backward:
    # dx = (D * g - sum(g) - x_hat * sum(g * x_hat)) / (D * std), with g = dL/dx_hat
    z = (
        (1.0 / D)
        * (
            D * grad_x_hat
            - grad_x_hat.sum(dim=-1, keepdim=True)
            - x_hat * (grad_x_hat * x_hat).sum(dim=-1, keepdim=True)
        )
        / std
    )

    return z
# Modified from https://github.com/NVIDIA/Megatron-LM/blob/e33c8f78a35765d5aa37475a144da60e8a2349d1/megatron/core/fusions/fused_bias_gelu.py#L26
def gelu_bwd(x):
    """Derivative of the tanh-approximate GELU, evaluated element-wise at ``x``.

    With ``t = tanh(sqrt(2/pi) * x * (1 + 0.044715 x^2))`` this is
    ``0.5 * x * (1 - t^2) * (sqrt(2/pi) + 3 * 0.044715 * sqrt(2/pi) * x^2) + 0.5 * (1 + t)``.
    """
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    sech_sq = 1 - tanh_out * tanh_out
    inner_slope = 0.79788456 + 0.1070322243 * x * x
    return 0.5 * x * (sech_sq * inner_slope) + 0.5 * (1 + tanh_out)
class TTTCache:
    """
    Decoding cache for the TTT layers: per-layer inner-loop parameters, their pending gradients,
    and the rolling windows of the causal convolutions.

    Arguments:
        model: TTTModel. Its `config`, `device`, `dtype` and
            `layers[i].seq_modeling_block.<param>` are read.
        batch_size: int

    Attributes:
        seqlen_offset: int. Number of tokens already processed. Starts at 0 and is never advanced
            by this class (presumably the caller updates it after each step).
        mini_batch_size: int
        ttt_param_names: List[str]. ["W1", "b1"] for linear layers, ["W1", "b1", "W2", "b2"] for mlp.
        ttt_params_dict: Dict[str, Dict[int, torch.Tensor]]. Keys are "{name}_states" and
            "{name}_grad"; each maps layer_idx -> [batch_size, ...].
        conv_states_dic: Dict[str, Dict[int, torch.Tensor]]. Keys are "pre_conv", "ttt_conv_q" and
            "ttt_conv_k"; each maps layer_idx -> [batch_size, hidden_size, conv_kernel].
    """

    def __init__(self, model, batch_size: int):
        config = model.config
        self.seqlen_offset = 0
        self.mini_batch_size = config.mini_batch_size

        self.ttt_params_dict = defaultdict(dict)
        if "linear" in config.ttt_layer_type:
            self.ttt_param_names = ["W1", "b1"]
        elif "mlp" in config.ttt_layer_type:
            self.ttt_param_names = ["W1", "b1", "W2", "b2"]
        else:
            raise ValueError(f"TTT Layer Type {config.ttt_layer_type} not supported yet")

        self.conv_states_dic = defaultdict(dict)
        logger.info("Creating cache of size: %s", batch_size)

        def new_conv_state():
            # Allocate in the model's dtype, like the tiled TTT params below. float32 windows
            # break bf16/fp16 decoding: causal_conv1d_update expects a matching dtype, and the
            # fallback dot product would silently promote activations to float32.
            return torch.zeros(
                batch_size,
                config.hidden_size,
                config.conv_kernel,
                device=model.device,
                dtype=model.dtype,
            )

        for layer_idx in range(config.num_hidden_layers):
            for name in self.ttt_param_names:
                weight = getattr(model.layers[layer_idx].seq_modeling_block, name)
                # One copy of the initial inner-loop parameters per batch element.
                tiled_weight = torch.tile(weight.unsqueeze(0), (batch_size,) + (1,) * weight.dim()).to(model.device)
                self.ttt_params_dict[f"{name}_states"][layer_idx] = tiled_weight
                # for decoding, we need to store the gradients as well
                self.ttt_params_dict[f"{name}_grad"][layer_idx] = torch.zeros_like(tiled_weight)

            if config.pre_conv:
                self.conv_states_dic["pre_conv"][layer_idx] = new_conv_state()
            if config.share_qk:
                self.conv_states_dic["ttt_conv_q"][layer_idx] = new_conv_state()
                self.conv_states_dic["ttt_conv_k"][layer_idx] = new_conv_state()

    def update(self, py_tree, layer_idx, seq_len):
        """Persist the inner-loop state for `layer_idx` after processing `seq_len` new tokens.

        There are two cases:
        - The chunk ends on a mini-batch boundary: `seq_len` is a multiple of `mini_batch_size`,
          or a short chunk completes the current mini-batch. The final `{name}_states` from
          `py_tree` are copied and the pending `{name}_grad` are cleared.
        - A short chunk ends mid-mini-batch: only `{name}_grad` from `py_tree` are copied, so
          the next step can continue accumulating within the mini-batch.

        Raises:
            ValueError: if a multi-token chunk shorter than a mini-batch starts mid-mini-batch.
            ValueError: if a chunk longer than a mini-batch is not a multiple of it.
        """
        if seq_len % self.mini_batch_size == 0:
            # copy last mini-batch states, clear gradients
            for name in self.ttt_param_names:
                self.ttt_params_dict[f"{name}_states"][layer_idx].copy_(py_tree[f"{name}_states"])
                self.ttt_params_dict[f"{name}_grad"][layer_idx].zero_()
        elif seq_len < self.mini_batch_size:
            if seq_len != 1 and self.seqlen_offset > 0 and self.seqlen_offset % self.mini_batch_size != 0:
                raise ValueError("fractional update not supported yet.")
            if (seq_len + self.seqlen_offset) % self.mini_batch_size == 0:
                # copy last mini-batch states, clear gradients
                for name in self.ttt_param_names:
                    self.ttt_params_dict[f"{name}_states"][layer_idx].copy_(py_tree[f"{name}_states"])
                    self.ttt_params_dict[f"{name}_grad"][layer_idx].zero_()
            else:
                # copy gradients for the next update
                for name in self.ttt_param_names:
                    self.ttt_params_dict[f"{name}_grad"][layer_idx].copy_(py_tree[f"{name}_grad"])
        else:
            raise ValueError(f"seq_len {seq_len} is a partial update not supported yet")

    def ttt_params_to_dict(self, layer_idx):
        """Return `{"{name}_states" / "{name}_grad": tensor}` for a single layer."""
        return {name: self.ttt_params_dict[name][layer_idx] for name in self.ttt_params_dict}
class TTTBase(nn.Module):
def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None):
    """Set up the components shared by all TTT layer variants.

    These are the Q/K/V/O projections, RoPE, the learnable inner-loop learning-rate gate, the
    per-head LayerNorm, and an optional output gate.

    Args:
        config: model configuration. Reads hidden_size, num_attention_heads, mini_batch_size,
            share_qk, conv_kernel and use_gate here (rope_theta is read in `_init_rope`).
        layer_idx: index of this layer, used to address per-layer entries of `TTTCache`.
    """
    super().__init__()
    self.config = config
    self.layer_idx = layer_idx
    if layer_idx is None:
        logger.warning_once(
            f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
            "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
            "when creating this class."
        )

    self.width = config.hidden_size
    self.hidden_size = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.width // self.num_heads
    self.mini_batch_size = config.mini_batch_size

    # token_idx is a scale factor that scales the summation in Eqn. 4: [1, 1/2, ..., 1/mini_batch_size]
    token_idx = 1.0 / torch.arange(1, self.mini_batch_size + 1)
    # non-persistent: rebuilt from the config rather than stored in the state_dict
    self.register_buffer("token_idx", token_idx, persistent=False)
    # make the scale factor learnable (get_eta adds this offset to token_idx; initialised to 0)
    self.learnable_token_idx = nn.Parameter(torch.zeros((self.mini_batch_size,)))

    self.share_qk = config.share_qk
    self.conv_kernel = config.conv_kernel
    # NOTE: the order of the _init_* calls fixes the parameter registration order and the
    # sequence of RNG draws (nn.Linear init, torch.normal), so reordering changes initialization.
    self._init_qkvo_proj()
    self._init_rope()
    # Learnable eta in Sec. 2.7
    self._init_ttt_lr_gate()
    self._init_ttt_ln()

    # use gating as in Mamba backbone
    self.use_gate = config.use_gate
    if self.use_gate:
        self.g_proj = nn.Linear(self.width, self.width, bias=False)

    self.post_norm = nn.LayerNorm(self.width, eps=1e-6)
def _init_qkvo_proj(self):
    """Create the Q/K/V/O projections and, with `share_qk`, the separate Q and K convolutions.

    With `share_qk` no `k_proj` is created. Q and K both come from `q_proj` and are
    differentiated by the causal depthwise convolutions `conv_q` and `conv_k`.
    """
    self.q_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
    # we share Q/K projection when using Mamba backbone
    if not self.share_qk:
        self.k_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
    self.v_proj = nn.Linear(self.width, self.num_heads * self.head_dim, bias=False)
    # o_proj maps the concatenated heads back to the model width, i.e. in = num_heads * head_dim,
    # out = width. The arguments were previously swapped. The shapes, and hence checkpoints and the
    # RNG draws of the default init, are identical whenever width is divisible by num_heads.
    self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.width, bias=False)

    # after share Q/K projection, we use different conv layers for Q and K
    if self.share_qk:
        self.conv_q = nn.Conv1d(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            kernel_size=self.conv_kernel,
            groups=self.hidden_size,
            padding=self.conv_kernel - 1,
        )
        self.conv_k = nn.Conv1d(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            kernel_size=self.conv_kernel,
            groups=self.hidden_size,
            padding=self.conv_kernel - 1,
        )
def _init_rope(self):
    """Create the per-layer rotary embedding over `head_dim`.

    It is built with `max_position_embeddings=mini_batch_size`; presumably the caller uses
    positions within a mini-batch (confirm in `forward`).
    """
    theta = self.config.rope_theta
    self.rope_theta = theta
    self.rotary_emb = RotaryEmbedding(
        self.head_dim,
        base=self.rope_theta,
        max_position_embeddings=self.mini_batch_size,
    )
def _init_ttt_lr_gate(self):
    """Create the per-head linear gate that predicts the token-wise inner-loop learning rate.

    `get_eta` applies it as `sigmoid(einsum("bnkc,hdc->bhnkd", X, weight) + bias)`.
    """
    # nn.Linear(width, 1).weight has shape [1, width]; only its shape is used below.
    # NOTE: building this throwaway nn.Linear still runs its default initializer, which draws
    # from the global RNG; removing it would shift every later random initialization.
    linear_weight_data = nn.Linear(self.width, 1, bias=True).weight.data
    # one N(0, 0.02) matrix per head, stacked on a new leading head dim -> [num_heads, 1, width]
    self.learnable_ttt_lr_weight = nn.Parameter(
        torch.stack(
            [torch.normal(0, 0.02, size=linear_weight_data.shape) for _ in range(self.num_heads)],
            dim=0,
        )
    )
    # nn.Linear(width, 1).bias has shape [1]; again only the shape (and dtype) is used.
    linear_bias_data = nn.Linear(self.width, 1, bias=True).bias.data
    # init bias to 0 following original JAX impl.
    # [num_heads, 1]
    self.learnable_ttt_lr_bias = nn.Parameter(
        torch.stack(
            [torch.zeros_like(linear_bias_data) for _ in range(self.num_heads)],
            dim=0,
        )
    )
def _init_ttt_ln(self):
    """Create per-head LayerNorm affine parameters for the inner-loop learner.

    They are initialised exactly like `nn.LayerNorm(head_dim)` (weight 1, bias 0), and each is
    repeated over heads to shape [num_heads, head_dim].
    """
    reference_ln = nn.LayerNorm(self.head_dim)
    self.ttt_norm_weight = nn.Parameter(reference_ln.weight.data.unsqueeze(0).repeat(self.num_heads, 1))
    self.ttt_norm_bias = nn.Parameter(reference_ln.bias.data.unsqueeze(0).repeat(self.num_heads, 1))
def get_qkv_projections(self, hidden_states, cache_params: Optional[TTTCache] = None):
    """Project `hidden_states` to the TTT queries, keys and values.

    Without `share_qk`, XQ/XK/XV are plain linear projections. With `share_qk`, the single
    `q_proj` output is passed through two separate causal depthwise convolutions (`conv_q`,
    `conv_k`) to obtain XQ and XK. When a cache is given, their rolling windows live in
    `cache_params.conv_states_dic["ttt_conv_q" / "ttt_conv_k"][self.layer_idx]`.

    Args:
        hidden_states: layer input. The code treats `shape[1]` as the sequence length; assumed
            [B, L, width] (TODO confirm).
        cache_params: optional `TTTCache`. When `seqlen_offset > 0` the conv path runs one decode
            step and reads only the first position of the projected input.

    Returns:
        `(XQ, XK, XV)`, each [B, L, num_heads * head_dim] under the [B, L, width] assumption.
    """
    if self.share_qk:
        xq, XV = self.q_proj(hidden_states), self.v_proj(hidden_states)
        seq_len = xq.shape[1]
        # [B, L, C] -> [B, C, L] for the channel-wise convolutions
        xq = xq.transpose(1, 2)
        if causal_conv1d_fn is None:
            # pure-PyTorch path
            if cache_params is not None:
                if cache_params.seqlen_offset > 0:
                    # Decode: shift each window left by one, append the newest step, then take the
                    # per-channel dot product with the kernel (depthwise conv over the window).
                    conv_q_state = cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx]
                    conv_q_state = torch.roll(conv_q_state, shifts=-1, dims=-1)
                    conv_q_state[:, :, -1] = xq[:, :, 0]
                    cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_state)
                    XQ = torch.sum(conv_q_state * self.conv_q.weight[:, 0, :], dim=-1)
                    XQ += self.conv_q.bias
                    XQ = XQ.unsqueeze(-1)

                    conv_k_state = cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx]
                    conv_k_state = torch.roll(conv_k_state, shifts=-1, dims=-1)
                    conv_k_state[:, :, -1] = xq[:, :, 0]
                    cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_state)
                    XK = torch.sum(conv_k_state * self.conv_k.weight[:, 0, :], dim=-1)
                    XK += self.conv_k.bias
                    XK = XK.unsqueeze(-1)
                else:
                    # Prefill: cache the last conv_kernel steps. They are left zero-padded when the
                    # input is shorter; a negative pad amount crops when it is longer.
                    conv_q_state = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
                    cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_state)
                    XQ = self.conv_q(xq)[..., :seq_len]
                    conv_k_state = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
                    cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_state)
                    XK = self.conv_k(xq)[..., :seq_len]
            else:
                # No cache. The convs are padded by conv_kernel - 1 on both sides; keeping the
                # first seq_len outputs makes them causal.
                XQ = self.conv_q(xq)[..., :seq_len]
                XK = self.conv_k(xq)[..., :seq_len]
        else:
            # fused causal_conv1d kernels take the depthwise weights as [C, K]
            conv_q_weights = self.conv_q.weight.view(self.conv_q.weight.size(0), self.conv_q.weight.size(2))
            conv_k_weights = self.conv_k.weight.view(self.conv_k.weight.size(0), self.conv_k.weight.size(2))
            if cache_params is not None and cache_params.seqlen_offset > 0:
                # decode step; causal_conv1d_update is given the cached window tensor directly
                XQ = causal_conv1d_update(
                    xq.squeeze(-1),
                    cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx],
                    conv_q_weights,
                    self.conv_q.bias,
                    None,
                )
                XQ = XQ.unsqueeze(-1)
                XK = causal_conv1d_update(
                    xq.squeeze(-1),
                    cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx],
                    conv_k_weights,
                    self.conv_k.bias,
                    None,
                )
                XK = XK.unsqueeze(-1)
            else:
                if cache_params is not None:
                    # prefill: store the last conv_kernel steps for later decode steps
                    conv_q_states = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
                    cache_params.conv_states_dic["ttt_conv_q"][self.layer_idx].copy_(conv_q_states)
                    conv_k_states = nn.functional.pad(xq, (self.config.conv_kernel - xq.shape[-1], 0))
                    cache_params.conv_states_dic["ttt_conv_k"][self.layer_idx].copy_(conv_k_states)
                XQ = causal_conv1d_fn(xq, conv_q_weights, self.conv_q.bias, activation=None)
                XK = causal_conv1d_fn(xq, conv_k_weights, self.conv_k.bias, activation=None)

        # [B, C, L] -> [B, L, C]
        XQ = XQ.transpose(1, 2)
        XK = XK.transpose(1, 2)
    else:
        XQ, XK, XV = (
            self.q_proj(hidden_states),
            self.k_proj(hidden_states),
            self.v_proj(hidden_states),
        )
    return XQ, XK, XV
def _split_heads(self, hidden_states):
    """Append a (num_heads, head_dim) split after the first two dims.

    Presumably maps [B, L, num_heads * head_dim] to [B, L, num_heads, head_dim].
    """
    leading = tuple(hidden_states.shape[:2])
    target_shape = (*leading, self.num_heads, self.head_dim)
    return hidden_states.reshape(target_shape)
def get_eta(self, X, mini_batch_step_offset, mini_batch_size):
    """Compute the two factors of the per-token inner-loop learning rate.

    Args:
        X: [B, num_mini_batch, mini_batch_size, width] layer input split into mini-batches.
        mini_batch_step_offset: index of the first token of `X` inside its mini-batch. It is 0
            unless a cache is in use (see `get_ttt_inputs`).
        mini_batch_size: number of tokens per (possibly partial) mini-batch in `X`.

    Returns:
        `(token_eta, ttt_lr_eta)`, shaped [B, num_heads, num_mini_batch, mini_batch_size, 1] and
        [B, num_heads, num_mini_batch, 1, mini_batch_size] respectively. Their product is `eta`.
    """
    batch_size, num_mini_batch = X.shape[0], X.shape[1]

    # Input-dependent learning rate: a per-head linear gate squashed into (0, 1),
    # then scaled by ttt_base_lr / head_dim.
    gate_logits = torch.einsum("bnkc,hdc->bhnkd", X, self.learnable_ttt_lr_weight)
    gate_logits = gate_logits + self.learnable_ttt_lr_bias.reshape(1, -1, 1, 1, 1)
    gate = torch.sigmoid(gate_logits).permute(0, 1, 2, 4, 3)
    ttt_lr_eta = self.config.ttt_base_lr * gate / self.head_dim

    # Position-dependent scale 1/t plus a learnable offset, restricted to the positions covered by X.
    start = mini_batch_step_offset
    position_scale = (self.token_idx + self.learnable_token_idx)[start : start + mini_batch_size]
    # token idx should be greater than 0
    position_scale = position_scale.clamp_min(0.0)
    token_eta = position_scale.reshape(1, 1, 1, mini_batch_size, 1).expand(
        batch_size, self.num_heads, num_mini_batch, mini_batch_size, 1
    )

    return token_eta, ttt_lr_eta
def apply_gate(self, hidden_states, ttt_output):
    """Scale ``ttt_output`` elementwise by a GELU gate computed from ``hidden_states``.

    The gate is ``gelu(g_proj(hidden_states))`` using the tanh approximation,
    chosen to match the JAX implementation.
    """
    gate_logits = self.g_proj(hidden_states)
    return F.gelu(gate_logits, approximate="tanh") * ttt_output
def get_ttt_inputs(self, inputs, mini_batch_size, cache_params):
    """Chunk Q/K/V into mini-batches and attach the step-size tensors.

    Args:
        inputs: dict with ``"XQ"``, ``"XK"``, ``"XV"`` (reshaped here to
            ``[B, num_heads, num_mini_batch, mini_batch_size, head_dim]``) and
            ``"X"`` of shape ``[B, L, C]``.
        mini_batch_size: tokens per mini-batch; ``L`` is expected to be a multiple of it.
        cache_params: optional cache; when present, its ``seqlen_offset`` modulo
            ``self.mini_batch_size`` selects the starting token position for ``get_eta``.

    Returns:
        dict with keys ``XQ``, ``XK``, ``XV``, ``eta`` (``token_eta * ttt_lr_eta``),
        ``token_eta`` and ``ttt_lr_eta``. The two factors are also returned
        separately so decoding can use them independently.
    """
    q_in, k_in, v_in, X = (inputs[key] for key in ("XQ", "XK", "XV", "X"))
    batch, seq_len, _ = X.shape
    n_chunks = seq_len // mini_batch_size

    # [B, L, C] -> [B, num_mini_batch, mini_batch_size, C]
    X = X.reshape(batch, n_chunks, mini_batch_size, self.width)
    head_shape = (batch, self.num_heads, n_chunks, mini_batch_size, self.head_dim)
    XQ, XK, XV = (t.reshape(head_shape) for t in (q_in, k_in, v_in))

    step_offset = 0 if cache_params is None else cache_params.seqlen_offset % self.mini_batch_size
    token_eta, ttt_lr_eta = self.get_eta(X, step_offset, mini_batch_size)

    return {
        "XQ": XQ,
        "XK": XK,
        "XV": XV,
        "eta": token_eta * ttt_lr_eta,
        "token_eta": token_eta,
        "ttt_lr_eta": ttt_lr_eta,
    }
def ttt(
    self,
    inputs,
    mini_batch_size,
    last_mini_batch_params_dict,
    cache_params: Optional[TTTCache] = None,
):
    """Abstract hook that runs test-time training over pre-chunked inputs.

    Subclasses must override this. ``forward`` unpacks its result as
    ``(output, last_mini_batch_params_dict)``.

    Raises:
        NotImplementedError: always, on the base class.
    """
    message = "ttt method must be implemented in TTTBase subclasses."
    raise NotImplementedError(message)
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    cache_params: Optional[TTTCache] = None,
):
    """Run the TTT layer over ``hidden_states`` and return the output projection.

    The sequence is processed as all full mini-batches of ``self.mini_batch_size``
    tokens in one ``self.ttt`` call. Any trailing remainder follows in a second call
    that continues from the parameters returned by the first.

    Args:
        hidden_states: input of shape ``[B, L, C]``.
        attention_mask: accepted for interface compatibility; not used here.
        position_ids: token positions, reduced modulo ``self.mini_batch_size`` before
            the rotary embedding. When ``None``, positions ``offset .. offset + L - 1``
            are used, where ``offset`` is ``cache_params.seqlen_offset`` (0 without a cache).
        cache_params: optional decoding cache, forwarded to the projections and ``self.ttt``.

    Returns:
        The tensor produced by ``self.o_proj``.
    """
    B, L = hidden_states.shape[:2]
    remainder_len = L % self.mini_batch_size
    num_mini_batch = L // self.mini_batch_size
    last_mini_batch_params_dict = None

    XQ, XK, XV = self.get_qkv_projections(hidden_states, cache_params=cache_params)

    # [B, L, C] -> [B, L, num_heads, head_dim] -> [B, num_heads, L, head_dim]
    XQ = XQ.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
    XK = XK.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)
    XV = XV.reshape(B, L, self.num_heads, self.head_dim).transpose(1, 2)

    if position_ids is None:
        # position_ids defaults to None but is used arithmetically below. Build consecutive
        # positions continuing from the cache offset (0 when there is no cache).
        # NOTE(review): assumes self.rotary_emb accepts a [1, L] position tensor.
        seqlen_offset = cache_params.seqlen_offset if cache_params is not None else 0
        position_ids = torch.arange(
            seqlen_offset,
            seqlen_offset + L,
            dtype=torch.long,
            device=hidden_states.device,
        ).unsqueeze(0)

    # Rotary positions restart at every mini-batch boundary.
    cos, sin = self.rotary_emb(XV, position_ids % self.mini_batch_size)

    # permute_qk and undo_permute_qk is just for aligning pytorch with jax pre-training
    XQ, XK = permute_qk(XQ, XK)
    XQ, XK = apply_rotary_pos_emb(XQ, XK, cos, sin)
    XQ, XK = undo_permute_qk(XQ, XK)

    output_hidden_states = []
    # When the input length is not a multiple of mini_batch_size, the full mini-batches
    # and the remainder are computed separately; the remainder call needs
    # last_mini_batch_params_dict to continue TTT learning from where the first call stopped.
    if num_mini_batch > 0:
        full_len = num_mini_batch * self.mini_batch_size
        inputs = {
            "XQ": XQ[:, :, :full_len],
            "XK": XK[:, :, :full_len],
            "XV": XV[:, :, :full_len],
            "X": hidden_states[:, :full_len],
        }
        output_mod, last_mini_batch_params_dict = self.ttt(
            self.get_ttt_inputs(inputs, self.mini_batch_size, cache_params),
            mini_batch_size=self.mini_batch_size,
            last_mini_batch_params_dict=last_mini_batch_params_dict,
            cache_params=cache_params,
        )
        output_hidden_states.append(output_mod)
    if remainder_len > 0:
        inputs = {
            "XQ": XQ[:, :, -remainder_len:],
            "XK": XK[:, :, -remainder_len:],
            "XV": XV[:, :, -remainder_len:],
            "X": hidden_states[:, -remainder_len:],
        }
        output_remainder, _ = self.ttt(
            self.get_ttt_inputs(inputs, remainder_len, cache_params),
            mini_batch_size=remainder_len,
            last_mini_batch_params_dict=last_mini_batch_params_dict,
            cache_params=cache_params,
        )
        output_hidden_states.append(output_remainder)

    output_hidden_states = torch.cat(output_hidden_states, dim=1)
    output_hidden_states = self.post_norm(output_hidden_states)
    if self.use_gate:
        output_hidden_states = self.apply_gate(hidden_states, output_hidden_states)
    output_hidden_states = self.o_proj(output_hidden_states)

    return output_hidden_states
class TTTLinear(TTTBase):
def __init__(self, config: TTTConfig, layer_idx: Optional[int] = None):
    """Build a TTT-Linear layer.

    On top of the base layer, this creates the per-head initial weights of the
    inner linear model: ``W1`` of shape ``[num_heads, head_dim, head_dim]``, drawn
    from N(0, 0.02), and ``b1`` of shape ``[num_heads, 1, head_dim]``, set to zeros.
    """
    super().__init__(config, layer_idx)
    weight_shape = (self.num_heads, self.head_dim, self.head_dim)
    bias_shape = (self.num_heads, 1, self.head_dim)
    self.W1 = nn.Parameter(torch.normal(0, 0.02, size=weight_shape))
    self.b1 = nn.Parameter(torch.zeros(*bias_shape))
def ttt(
self,
inputs,
mini_batch_size,
last_mini_batch_params_dict,
cache_params: Optional[TTTCache] = None,
):
if mini_batch_size is None:
mini_batch_size = self.mini_batch_size
# in this case, we are decoding
if last_mini_batch_params_dict is None and cache_params is not None:
last_mini_batch_params_dict = cache_params.ttt_params_to_dict(self.layer_idx)
# [B, num_heads, num_mini_batch, mini_batch_size, head_dim]
B = inputs["XV"].shape[0]
num_mini_batch = inputs["XV"].shape[2]
L = inputs["XV"].shape[2] * inputs["XV"].shape[3]
device = inputs["XV"].device
dtype = inputs["XV"].dtype
# NOTE:
# for prefilling, we will always use dual form for faster computation
# we need to use primal form if mini_batch_size is not a multiple of self.mini_batch_size
# since we need store the gradient for the next mini-batch computation
use_dual_form = cache_params is None or mini_batch_size % self.mini_batch_size == 0
def compute_mini_batch(params_dict, inputs):
# [B, nh, f, f], nh=num_heads, f=head_dim
W1_init = params_dict["W1_states"]
# [B, nh, 1, f]
b1_init = params_dict["b1_states"]
# [B,nh,K,f], K=mini_batch_size
XQ_mini_batch = inputs["XQ"]
XV_mini_batch = inputs["XV"]
XK_mini_batch = inputs["XK"]
# [B, nh, K, 1]
eta_mini_batch = inputs["eta"]
token_eta_mini_batch = inputs["token_eta"]
ttt_lr_eta_mini_batch = inputs["ttt_lr_eta"]
X1 = XK_mini_batch
# [B,nh,K,f] @ [B,nh,f,f] -> [B,nh,K,f]
Z1 = X1 @ W1_init + b1_init
reconstruction_target = XV_mini_batch - XK_mini_batch
ln_weight = self.ttt_norm_weight.reshape(self.num_heads, 1, self.head_dim)
ln_bias = self.ttt_norm_bias.reshape(self.num_heads, 1, self.head_dim)
# [B,nh,K,f]
grad_l_wrt_Z1 = ln_fused_l2_bwd(Z1, reconstruction_target, ln_weight, ln_bias)
if use_dual_form:
# [B,nh,K,K]
Attn1 = torch.tril(XQ_mini_batch @ X1.transpose(-2, -1))
# [B,nh,1,f] - [B,nh,K,K] @ [B,nh,K,f] -> [B,nh,K,f]
b1_bar = b1_init - torch.tril(eta_mini_batch) @ grad_l_wrt_Z1
# [B,nh,K,f] @ [B,nh,f,f] - ([B,nh,K,1] * [B,nh,K,K]) @ [B,nh,K,f] + [B,nh,K,f]
Z1_bar = XQ_mini_batch @ W1_init - (eta_mini_batch * Attn1) @ grad_l_wrt_Z1 + b1_bar
last_eta_mini_batch = eta_mini_batch[:, :, -1, :, None]
# [B,nh,f,f] - [B,nh,f,K] @ [B,nh,K,f]
W1_last = W1_init - (last_eta_mini_batch * X1).transpose(-1, -2) @ grad_l_wrt_Z1
# [B,nh,1,f]
b1_last = b1_init - torch.sum(last_eta_mini_batch * grad_l_wrt_Z1, dim=-2, keepdim=True)
grad_W1_last = torch.zeros_like(W1_last)
grad_b1_last = torch.zeros_like(b1_last)
else:
ttt_lr_eta_mini_batch = torch.broadcast_to(
ttt_lr_eta_mini_batch,
(
*ttt_lr_eta_mini_batch.shape[:2],
mini_batch_size,
mini_batch_size,
),
)