enforce torch 2.4 and use latest sdpa for attn
lucidrains committed Oct 8, 2024
1 parent d88d7a2 commit 52095e3
Showing 2 changed files with 36 additions and 13 deletions.
nGPT_pytorch/nGPT.py (47 changes: 35 additions & 12 deletions)
@@ -7,6 +7,7 @@
 import torch.nn.utils.parametrize as parametrize

 from einops import einsum, rearrange
+from einops.layers.torch import Rearrange

 # functions
@@ -60,21 +61,42 @@ def __init__(
         dim,
         *,
         dim_head = 64,
-        heads = 8
+        heads = 8,
+        norm_qk = False
     ):
         super().__init__()
         dim_inner = dim_head * heads
-        self.query_weights = NormLinear(dim, dim_inner, norm_dim = 0)
-        self.key_weights = NormLinear(dim, dim_inner, norm_dim = 0)
-        self.value_weights = NormLinear(dim, dim_inner, norm_dim = 0)
+        self.to_q = NormLinear(dim, dim_inner, norm_dim = 0)
+        self.to_k = NormLinear(dim, dim_inner, norm_dim = 0)
+        self.to_v = NormLinear(dim, dim_inner, norm_dim = 0)

+        self.norm_qk = norm_qk
+        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')
+
-        self.out_weights = NormLinear(dim_inner, dim)
+        self.to_out = NormLinear(dim_inner, dim)

     def forward(
         self,
         x
     ):
-        return x
+        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)
+
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        if self.norm_qk:
+            q, k = map(l2norm, (q, k))
+
+        # scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
+
+        out = F.scaled_dot_product_attention(
+            q, k, v,
+            is_causal = True,
+            scale = 1.
+        )
+
+        out = self.merge_heads(out)
+        return self.to_out(out)

 class FeedForward(Module):
     def __init__(
@@ -85,14 +107,14 @@
     ):
         super().__init__()
         dim_inner = int(dim * expand_factor * 2 / 3)
-        self.proj_in = NormLinear(dim, dim_inner, norm_dim = 0)
-        self.gate = NormLinear(dim, dim_inner, norm_dim = 0)
-        self.proj_out = NormLinear(dim_inner, dim)
+        self.to_hidden = NormLinear(dim, dim_inner, norm_dim = 0)
+        self.to_gate = NormLinear(dim, dim_inner, norm_dim = 0)
+        self.to_out = NormLinear(dim_inner, dim)

     def forward(self, x):
-        x, gate = self.proj_in(x), self.gate(x)
+        x, gate = self.to_hidden(x), self.to_gate(x)
         x = F.silu(gate) * x
-        return self.proj_out(x)
+        return self.to_out(x)

 # classes
@@ -105,6 +127,7 @@ def __init__(
         depth,
         dim_head = 64,
         heads = 8,
+        attn_norm_qk = False, # they say the query/key normalization is optional
         ff_expand_factor = 4.,
         ce_ignore_index = -1,
         residual_lerp_scale_init = None
@@ -120,7 +143,7 @@ def __init__(

         for _ in range(depth):
            self.layers.append(ModuleList([
-                Attention(dim, dim_head = dim_head, heads = heads),
+                Attention(dim, dim_head = dim_head, heads = heads, norm_qk = attn_norm_qk),
                 FeedForward(dim, expand_factor = ff_expand_factor),
             ]))
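For reference (an aside, not part of the diff): with scale = 1., the new F.scaled_dot_product_attention call in Attention.forward computes plain causal attention with no 1/sqrt(d_head) factor; per the in-code comment, that scaling is instead folded into s_qk (d_k ^ 0.25, the eq. 16 it references). A minimal equivalence sketch, assuming small illustrative shapes and using F.normalize in place of the repo's l2norm helper:

import torch
import torch.nn.functional as F

# hypothetical shapes for illustration only
b, h, n, d = 1, 8, 16, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
q, k = (F.normalize(t, dim = -1) for t in (q, k))      # as in the norm_qk = True branch

# new path: sdpa with the built-in scaling explicitly set to 1.
sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal = True, scale = 1.)

# manual reference: unscaled causal attention
i = torch.arange(n)
causal = i[:, None] >= i[None, :]                      # query i may attend to keys j <= i
sim = q @ k.transpose(-2, -1)                          # note: no 1 / sqrt(d) here
sim = sim.masked_fill(~causal, float('-inf'))
manual_out = sim.softmax(dim = -1) @ v

assert torch.allclose(sdpa_out, manual_out, atol = 1e-5)

The same norm_qk flag is what the transformer constructor plumbs through as attn_norm_qk in the hunks above.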
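Similarly for FeedForward (again an aside, not repo code): the renamed to_hidden / to_gate projections feed a SiLU gate, and dim_inner carries the usual 2/3 correction so the three-matrix gated feedforward stays roughly parameter-matched to an un-gated two-matrix one at the full expansion factor. A small sketch of the shape arithmetic, with assumed dim = 512 and expand_factor = 4.:

import torch
import torch.nn.functional as F

dim, expand_factor = 512, 4.
dim_inner = int(dim * expand_factor * 2 / 3)      # int(512 * 4 * 2 / 3) = 1365

hidden = torch.randn(2, 16, dim_inner)            # stand-in for to_hidden(x), shape (batch, seq, dim_inner)
gate = torch.randn(2, 16, dim_inner)              # stand-in for to_gate(x)

out = F.silu(gate) * hidden                       # the gating in FeedForward.forward
print(dim_inner, out.shape)                       # 1365 torch.Size([2, 16, 1365])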
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -27,7 +27,7 @@ classifiers=[
 dependencies = [
     "einops>=0.8.0",
     "rotary_embedding_torch",
-    "torch>=2.0",
+    "torch>=2.4",
 ]

 [project.urls]
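The raised torch floor backs the commit's "use latest sdpa for attn" intent. A hypothetical runtime guard (not from the repo) that mirrors the declared constraint:

import torch
from packaging import version

# fail fast if the installed torch predates the torch>=2.4 floor declared in pyproject.toml
if version.parse(torch.__version__) < version.parse('2.4'):
    raise RuntimeError(f'nGPT-pytorch now expects torch>=2.4, found {torch.__version__}')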
