From 5a41fbddc1bfe3f5f2cbb48fdfe0b135be7816c9 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 10 Oct 2024 19:12:45 -0700 Subject: [PATCH] fix the multi-headed qk rmsnorm scaling --- nGPT_pytorch/nGPT.py | 14 ++++++++------ pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/nGPT_pytorch/nGPT.py b/nGPT_pytorch/nGPT.py index a684c02..3a2a740 100644 --- a/nGPT_pytorch/nGPT.py +++ b/nGPT_pytorch/nGPT.py @@ -107,6 +107,7 @@ def __init__( super().__init__() self.linear = nn.Linear(dim, dim_out, bias = False) + self.scale = groups ** -1 self.parametrize = parametrize self.l2norm = L2Norm(dim = -1 if norm_dim_in else 0, norm_eps = norm_eps, groups = groups) @@ -134,7 +135,7 @@ def weight(self): return self.linear.weight def forward(self, x): - return self.linear(x) + return self.linear(x) * self.scale # attention @@ -159,6 +160,7 @@ def __init__( num_hyperspheres = 1 ): super().__init__() + self.heads = heads self.causal = causal NormLinear_ = partial(NormLinear, parametrize = not manual_norm_weights, norm_eps = norm_eps, groups = num_hyperspheres) @@ -200,11 +202,6 @@ def forward( ): q, k, v = self.to_q(x), self.to_k(x), self.to_v(x) - # scaling queries and keys - this would line up with the popular use of qk rmsnorm from google deepmind and now black forest labs - - q = q * self.q_scale() - k = k * self.k_scale() - # split heads q, k, v = map(self.split_heads, (q, k, v)) @@ -214,6 +211,11 @@ def forward( if self.norm_qk: q, k = map(self.l2norm, (q, k)) + # scaling queries and keys - this would line up with the popular use of qk rmsnorm from google deepmind and now black forest labs - will use multihead rmsnorm + + q = q * rearrange(self.q_scale(), '(h d) -> h 1 d', h = self.heads) + k = k * rearrange(self.k_scale(), '(h d) -> h 1 d', h = self.heads) + # rotary positions q = self.rotary_emb.rotate_queries_or_keys(q) diff --git a/pyproject.toml b/pyproject.toml index f9ccb75..bcb1a3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "nGPT-pytorch" -version = "0.1.2" +version = "0.1.4" description = "nGPT" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }