From fcc4b167270ff3981b818e7c3236d1b1ae234b7e Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 2 Nov 2024 10:12:01 -0700
Subject: [PATCH] address https://github.com/lucidrains/nGPT-pytorch/issues/11
 with a hparam

---
 nGPT_pytorch/nGPT.py             | 10 +++++++++-
 nGPT_pytorch/nGPTExperimental.py | 12 ++++++++++--
 pyproject.toml                   |  2 +-
 3 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/nGPT_pytorch/nGPT.py b/nGPT_pytorch/nGPT.py
index f90578a..3718831 100644
--- a/nGPT_pytorch/nGPT.py
+++ b/nGPT_pytorch/nGPT.py
@@ -191,7 +191,8 @@ def __init__(
             enable_mem_efficient = True
         ),
         norm_eps = 0.,
-        num_hyperspheres = 1
+        num_hyperspheres = 1,
+        mask_value: float | None = None
     ):
         super().__init__()
         self.heads = heads
@@ -214,6 +215,8 @@ def __init__(
         sdpa_backends = [SDP_BACKEND_MAP[enable_str] for enable_str, enable in flash_kwargs.items() if enable]
         self.sdpa_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
 
+        self.mask_value = mask_value
+
         # qk rmsnorm + scale
 
         self.norm_qk = norm_qk
@@ -263,6 +266,9 @@ def forward(
         if exists(mask):
             mask = rearrange(mask, 'b j -> b 1 1 j')
 
+            if exists(self.mask_value):
+                mask = mask * self.mask_value
+
         # scale is sqrt(dk)
 
         with self.sdpa_context_manager():
@@ -339,6 +345,7 @@ def __init__(
         num_hyperspheres = 1,
         causal = True,
         add_value_residual = True,
+        attn_mask_value: float | None = None, # address some issue with sdpa
         # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
         alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
         s_logit_init: float = 1.,
@@ -414,6 +421,7 @@ def __init__(
                 s_qk_init = s_qk_init_,
                 s_qk_scale = s_qk_scale_,
                 flash_kwargs = attn_flash_kwargs,
+                mask_value = attn_mask_value,
                 norm_eps = norm_eps,
                 num_hyperspheres = num_hyperspheres
             )
diff --git a/nGPT_pytorch/nGPTExperimental.py b/nGPT_pytorch/nGPTExperimental.py
index 5c3c80d..9c7d078 100644
--- a/nGPT_pytorch/nGPTExperimental.py
+++ b/nGPT_pytorch/nGPTExperimental.py
@@ -190,7 +190,8 @@ def __init__(
             enable_mem_efficient = True
         ),
         norm_eps = 0.,
-        num_hyperspheres = 1
+        num_hyperspheres = 1,
+        mask_value = None
     ):
         super().__init__()
         self.heads = heads
@@ -213,6 +214,8 @@ def __init__(
         sdpa_backends = [SDP_BACKEND_MAP[enable_str] for enable_str, enable in flash_kwargs.items() if enable]
         self.sdpa_context_manager = partial(torch.nn.attention.sdpa_kernel, sdpa_backends)
 
+        self.mask_value = mask_value
+
         # rotary
 
         self.rotary_emb = RotaryEmbedding(dim_head)
@@ -263,6 +266,9 @@ def forward(
         if exists(mask):
             mask = rearrange(mask, 'b j -> b 1 1 j')
 
+            if exists(self.mask_value):
+                mask = mask * self.mask_value
+
         # scale is sqrt(dk)
 
         with self.sdpa_context_manager():
@@ -335,6 +341,7 @@ def __init__(
         tied_embedding = False,
         num_hyperspheres = 1,
         causal = True,
+        attn_mask_value: float | None = None,
         # below are all the scale related hyperparameters, for controlling effective relative learning rates throughout the network
         alpha_init: float | None = None, # this would set the alpha init for all residuals, but would be overridden by alpha_attn_init and alpha_ff_init if they are specified
         s_logit_init: float = 1.,
@@ -407,7 +414,8 @@ def __init__(
                 s_qk_scale = s_qk_scale_,
                 flash_kwargs = attn_flash_kwargs,
                 norm_eps = norm_eps,
-                num_hyperspheres = num_hyperspheres
+                num_hyperspheres = num_hyperspheres,
+                mask_value = attn_mask_value
             )
 
             ff = FeedForward(
diff --git a/pyproject.toml b/pyproject.toml
index 0ab1c2c..3639785 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "nGPT-pytorch"
-version = "0.2.2"
+version = "0.2.3"
 description = "nGPT"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
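
For reference, a minimal usage sketch of the new hyperparameter (not part of the patch itself). It assumes the top-level nGPT class exported by nGPT_pytorch forwards attn_mask_value to each Attention layer as in the hunks above; the other constructor arguments and the forward call are illustrative placeholders, not taken from this patch.

    import torch
    from nGPT_pytorch import nGPT

    # illustrative only: attn_mask_value is the hyperparameter added in this patch,
    # the remaining arguments are placeholder values
    model = nGPT(
        num_tokens = 256,
        dim = 512,
        depth = 4,
        attn_mask_value = -1e4  # finite value the boolean mask is multiplied by before sdpa
    )

    ids  = torch.randint(0, 256, (2, 128))
    mask = torch.ones(2, 128, dtype = torch.bool)  # assumed key padding mask, shape (batch, seq)

    logits = model(ids, mask = mask)  # assumes forward threads the mask down to Attention.forward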