From b8579579f90c70031b3c0c8649192426e72ce3ba Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Tue, 8 Oct 2024 09:37:39 -0700
Subject: [PATCH] end to end in readme

---
 README.md            | 27 +++++++++++++++++++++++++++
 nGPT_pytorch/nGPT.py | 20 ++++++++++++++++++--
 2 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 4fa414a..0d7aca2 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,33 @@
 
 Quick implementation of nGPT, learning entirely on the hypersphere, from NvidiaAI. The question is whether there is any loss of expressivity they swept under the rug, but I'll take it with good faith.
 
+## Install
+
+```bash
+$ pip install nGPT-pytorch
+```
+
+## Usage
+
+```python
+import torch
+from nGPT_pytorch import nGPT
+
+model = nGPT(
+    num_tokens = 256,
+    dim = 512,
+    depth = 4,
+    attn_norm_qk = True
+)
+
+x = torch.randint(0, 256, (2, 2048))
+
+loss = model(x, return_loss = True)
+loss.backward()
+
+logits = model(x) # (2, 2048, 256)
+```
+
 ## Citations
 
 ```bibtex
diff --git a/nGPT_pytorch/nGPT.py b/nGPT_pytorch/nGPT.py
index f02605c..36c5ed7 100644
--- a/nGPT_pytorch/nGPT.py
+++ b/nGPT_pytorch/nGPT.py
@@ -9,6 +9,8 @@
 from einops import einsum, rearrange
 from einops.layers.torch import Rearrange
 
+from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
+
 # functions
 
 def exists(v):
@@ -62,7 +64,7 @@ def __init__(
         *,
         dim_head = 64,
         heads = 8,
-        norm_qk = False
+        norm_qk = True
     ):
         super().__init__()
         dim_inner = dim_head * heads
@@ -70,6 +72,9 @@ def __init__(
         self.to_k = NormLinear(dim, dim_inner, norm_dim = 0)
         self.to_v = NormLinear(dim, dim_inner, norm_dim = 0)
 
+        self.rotary_emb = RotaryEmbedding(dim_head)
+        self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** -0.25))
+
         self.norm_qk = norm_qk
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
@@ -84,9 +89,20 @@ def forward(
 
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # maybe query key norm
+
         if self.norm_qk:
             q, k = map(l2norm, (q, k))
 
+        # scaling queries and keys - this would line up with the popular use of qk rmsnorm from google deepmind and now black forest labs
+
+        q, k = (q * self.qk_scale), (k * self.qk_scale)
+
+        # rotary positions
+
+        q = self.rotary_emb.rotate_queries_or_keys(q)
+        k = self.rotary_emb.rotate_queries_or_keys(k)
+
         # scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
 
         out = F.scaled_dot_product_attention(
@@ -127,7 +143,7 @@ def __init__(
         depth,
         dim_head = 64,
         heads = 8,
-        attn_norm_qk = False, # they say the query/key normalization is optional
+        attn_norm_qk = True, # they say the query/key normalization is optional
         ff_expand_factor = 4.,
         ce_ignore_index = -1,
         residual_lerp_scale_init = None
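
Taken together, the attention-side changes in this patch do three things: queries and keys are (optionally) l2-normalized, multiplied by a learned per-dimension scale initialized to dim_head ** -0.25 so the usual 1 / sqrt(d_k) softmax temperature is absorbed into s_qk (eq. 16), rotated with rotary embeddings, and then passed to scaled_dot_product_attention with scale = 1. Below is a minimal standalone sketch of just that path on hypothetical toy tensors; it skips the module's NormLinear projections and is illustrative only, not the package's API.

```python
# Sketch of the q/k path added by this patch, on toy tensors (assumed shapes).
# Both q and k carry a dim_head ** -0.25 factor, so their dot product already
# includes 1 / sqrt(d_k), which is why SDPA can run with scale = 1.

import torch
import torch.nn.functional as F
from rotary_embedding_torch import RotaryEmbedding

batch, heads, seq_len, dim_head = 2, 8, 128, 64

q = torch.randn(batch, heads, seq_len, dim_head)
k = torch.randn(batch, heads, seq_len, dim_head)
v = torch.randn(batch, heads, seq_len, dim_head)

# learned per-dimension scale s_qk, initialized to dim_head ** -0.25
qk_scale = torch.nn.Parameter(torch.ones(dim_head) * dim_head ** -0.25)

# optional query / key l2 norm (norm_qk = True)
q, k = map(lambda t: F.normalize(t, dim = -1), (q, k))

# fold the attention temperature into the queries and keys
q = q * qk_scale
k = k * qk_scale

# rotary positions
rotary = RotaryEmbedding(dim_head)
q = rotary.rotate_queries_or_keys(q)
k = rotary.rotate_queries_or_keys(k)

# scale = 1. since dk ** -0.5 is already carried by qk_scale on both sides
out = F.scaled_dot_product_attention(q, k, v, scale = 1.)

print(out.shape)  # torch.Size([2, 8, 128, 64])
```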