address #11 with a hparam
lucidrains committed Nov 2, 2024
1 parent 6129dee commit fcc4b16
Showing 3 changed files with 20 additions and 4 deletions.
nGPT_pytorch/nGPT.py — 9 additions & 1 deletion
@@ -191,7 +191,8 @@ def __init__(
         enable_mem_efficient = True
     ),
     norm_eps = 0.,
-    num_hyperspheres = 1
+    num_hyperspheres = 1,
+    mask_value: float | None = None
 ):
     super().__init__()
     self.heads = heads
@@ -214,6 +215,8 @@ def __init__(
     sdpa_backends = [SDP_BACKEND_MAP[enable_str] for enable_str, enable in flash_kwargs.items() if enable]
     self.sdpa_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)

+    self.mask_value = mask_value
+
     # qk rmsnorm + scale

     self.norm_qk = norm_qk
@@ -263,6 +266,9 @@ def forward(
     if exists(mask):
         mask = rearrange(mask, 'b j -> b 1 1 j')

+        if exists(self.mask_value):
+            mask = mask * self.mask_value
+
     # scale is sqrt(dk)

     with self.sdpa_context_manager():
@@ -339,6 +345,7 @@ def __init__(
     num_hyperspheres = 1,
     causal = True,
     add_value_residual = True,
+    attn_mask_value: float | None = None, # address some issue with sdpa
     # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
     alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
     s_logit_init: float = 1.,
@@ -414,6 +421,7 @@ def __init__(
         s_qk_init = s_qk_init_,
         s_qk_scale = s_qk_scale_,
         flash_kwargs = attn_flash_kwargs,
+        mask_value = attn_mask_value,
         norm_eps = norm_eps,
         num_hyperspheres = num_hyperspheres
     )
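The added mask handling turns the boolean key mask into an additive float mask before it reaches scaled_dot_product_attention. Below is a minimal sketch of that conversion on its own, outside the nGPT classes. The tensor sizes, the choice of mask_value = 1e4, and the reading that True marks tokens to keep (so a large positive value boosts kept keys rather than sending padded keys to -inf) are assumptions for illustration, not taken from the repository.

import torch
import torch.nn.functional as F

# toy shapes for the sketch
batch, heads, seq, dim_head = 2, 4, 8, 16

q = torch.randn(batch, heads, seq, dim_head)
k = torch.randn(batch, heads, seq, dim_head)
v = torch.randn(batch, heads, seq, dim_head)

# boolean key mask, True = keep (assumed convention); make one row fully padded
mask = torch.ones(batch, seq, dtype = torch.bool)
mask[1] = False

mask = mask[:, None, None, :]   # b j -> b 1 1 j, mirroring the rearrange above

mask_value = 1e4                # hypothetical value for the new hparam

# bool * float promotes to float: kept keys get +mask_value added to their
# attention scores, padded keys get +0, so padded keys vanish after softmax;
# a row with no kept keys sees a uniform shift of zero and stays finite
# instead of producing NaNs the way an all -inf additive mask would
additive_mask = mask * mask_value

out = F.scaled_dot_product_attention(q, k, v, attn_mask = additive_mask)
assert torch.isfinite(out).all()

With a boolean attn_mask, a query row whose keys are all masked divides zero by zero inside the softmax; keeping the mask additive and finite is one way to sidestep that, which appears to be what the mask_value hyperparameter is for.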
nGPT_pytorch/nGPTExperimental.py — 10 additions & 2 deletions
@@ -190,7 +190,8 @@ def __init__(
         enable_mem_efficient = True
     ),
     norm_eps = 0.,
-    num_hyperspheres = 1
+    num_hyperspheres = 1,
+    mask_value = None
 ):
     super().__init__()
     self.heads = heads
@@ -213,6 +214,8 @@ def __init__(
     sdpa_backends = [SDP_BACKEND_MAP[enable_str] for enable_str, enable in flash_kwargs.items() if enable]
     self.sdpa_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)

+    self.mask_value = mask_value
+
     # rotary

     self.rotary_emb = RotaryEmbedding(dim_head)
@@ -263,6 +266,9 @@ def forward(
     if exists(mask):
         mask = rearrange(mask, 'b j -> b 1 1 j')

+        if exists(self.mask_value):
+            mask = mask * self.mask_value
+
     # scale is sqrt(dk)

     with self.sdpa_context_manager():
@@ -335,6 +341,7 @@ def __init__(
     tied_embedding = False,
     num_hyperspheres = 1,
     causal = True,
+    attn_mask_value: float | None = None,
     # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
     alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
     s_logit_init: float = 1.,
@@ -407,7 +414,8 @@ def __init__(
         s_qk_scale = s_qk_scale_,
         flash_kwargs = attn_flash_kwargs,
         norm_eps = norm_eps,
-        num_hyperspheres = num_hyperspheres
+        num_hyperspheres = num_hyperspheres,
+        mask_value = attn_mask_value
     )

     ff = FeedForward(
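For completeness, a sketch of how the new knob would be set from user code. The attn_mask_value keyword matches the __init__ hunks above; the class name, the remaining constructor arguments, the bare forward call, and the 1e4 value follow the library's usual usage pattern and are assumptions here, not confirmed by this diff.

import torch
from nGPT_pytorch import nGPT

model = nGPT(
    num_tokens = 256,
    dim = 512,
    depth = 4,
    attn_mask_value = 1e4   # the hparam added in this commit; value is an arbitrary choice
)

ids = torch.randint(0, 256, (2, 128))
logits = model(ids)

The value only takes effect when a mask actually reaches Attention.forward; with no mask, the new branch is skipped entirely.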
pyproject.toml — 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.2.2"
+version = "0.2.3"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
