Skip to content

Commit

Permalink
fix the multi-headed qk rmsnorm scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 11, 2024
1 parent e508997 commit 5a41fbd
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
14 changes: 8 additions & 6 deletions nGPT_pytorch/nGPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(
super().__init__()
self.linear = nn.Linear(dim, dim_out, bias = False)

self.scale = groups ** -1
self.parametrize = parametrize
self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0, norm_eps = norm_eps, groups = groups)

Expand Down Expand Up @@ -134,7 +135,7 @@ def weight(self):
return self.linear.weight

def forward(self, x):
return self.linear(x)
return self.linear(x) * self.scale

# attention

Expand All @@ -159,6 +160,7 @@ def __init__(
num_hyperspheres = 1
):
super().__init__()
self.heads = heads
self.causal = causal

NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps, groups = num_hyperspheres)
Expand Down Expand Up @@ -200,11 +202,6 @@ def forward(
):
q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

# scaling queries and keys - this would line up with the popular use of qk rmsnorm from google deepmind and now black forest labs

q = q * self.q_scale()
k = k * self.k_scale()

# split heads

q, k, v = map(self.split_heads, (q, k, v))
Expand All @@ -214,6 +211,11 @@ def forward(
if self.norm_qk:
q, k = map(self.l2norm, (q, k))

# scaling queries and keys - this would line up with the popular use of qk rmsnorm from google deepmind and now black forest labs - will use multihead rmsnorm

q = q * rearrange(self.q_scale(), '(h d) -> h 1 d', h = self.heads)
k = k * rearrange(self.k_scale(), '(h d) -> h 1 d', h = self.heads)

# rotary positions

q = self.rotary_emb.rotate_queries_or_keys(q)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "nGPT-pytorch"
version = "0.1.2"
version = "0.1.4"
description = "nGPT"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 5a41fbd

Please sign in to comment.