
Commit

end to end in readme
lucidrains committed Oct 8, 2024
1 parent 52095e3 commit b857957
Showing 2 changed files with 45 additions and 2 deletions.
27 changes: 27 additions & 0 deletions README.md
@@ -2,6 +2,33 @@

Quick implementation of <a href="https://arxiv.org/abs/2410.01131">nGPT</a> from Nvidia AI, which learns entirely on the hypersphere. The question is whether there is any loss of expressivity they swept under the rug, but I'll take it in good faith.
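A rough sketch of the central idea, rather than this repository's exact code: parameters and activations are kept on the unit hypersphere by l2-normalizing along the feature dimension.

```python
import torch
import torch.nn.functional as F

def l2norm(t, dim = -1):
    # project onto the unit hypersphere along `dim`
    return F.normalize(t, p = 2, dim = dim)

# toy example: a weight matrix whose rows are unit norm
w = l2norm(torch.randn(512, 512))
assert torch.allclose(w.norm(dim = -1), torch.ones(512), atol = 1e-6)
```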

## Install

```bash
$ pip install nGPT-pytorch
```

## Usage

```python
import torch
from nGPT_pytorch import nGPT

model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    attn_norm_qk = True
)

x = torch.randint(0, 256, (2, 2048))

loss = model(x, return_loss = True)
loss.backward()

logits = model(x) # (2, 2048, 256)
```
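
A minimal greedy decoding sketch using only the logits returned above; the repository may offer its own sampling helpers, which are not assumed here.

```python
# take the most probable next token for each sequence in the batch
next_token = logits[:, -1].argmax(dim = -1)        # (2,)

# append it and feed the extended sequence back through the model
x = torch.cat((x, next_token[:, None]), dim = -1)  # (2, 2049)
```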

## Citations

```bibtex
...
```
20 changes: 18 additions & 2 deletions nGPT_pytorch/nGPT.py
@@ -9,6 +9,8 @@
from einops import einsum, rearrange
from einops.layers.torch import Rearrange

from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

# functions

def exists(v):
@@ -62,14 +64,17 @@ def __init__(
        *,
        dim_head = 64,
        heads = 8,
-        norm_qk = False
+        norm_qk = True
    ):
        super().__init__()
        dim_inner = dim_head * heads
        self.to_q = NormLinear(dim, dim_inner, norm_dim = 0)
        self.to_k = NormLinear(dim, dim_inner, norm_dim = 0)
        self.to_v = NormLinear(dim, dim_inner, norm_dim = 0)

        self.rotary_emb = RotaryEmbedding(dim_head)
        self.qk_scale = nn.Parameter(torch.ones(dim_head) * (dim_head ** -0.25))

        self.norm_qk = norm_qk
        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')
@@ -84,9 +89,20 @@ def forward(

        q, k, v = map(self.split_heads, (q, k, v))

        # maybe query key norm

        if self.norm_qk:
            q, k = map(l2norm, (q, k))

        # scaling queries and keys - this would line up with the popular use of qk rmsnorm from google deepmind and now black forest labs

        q, k = (q * self.qk_scale), (k * self.qk_scale)

        # rotary positions

        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)

        # scale is 1., as scaling factor is moved to s_qk (dk ^ 0.25) - eq. 16
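        # at initialization, (q * dk^-0.25) · (k * dk^-0.25) equals (q · k) / sqrt(dk), the usual attention scaling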

        out = F.scaled_dot_product_attention(
@@ -127,7 +143,7 @@ def __init__(
        depth,
        dim_head = 64,
        heads = 8,
-        attn_norm_qk = False, # they say the query/key normalization is optional
+        attn_norm_qk = True,  # they say the query/key normalization is optional
        ff_expand_factor = 4.,
        ce_ignore_index = -1,
        residual_lerp_scale_init = None
