From 1fa2760e7804eb886e354d64f836597ed77cf330 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 16 Aug 2023 22:59:39 +0000 Subject: [PATCH 1/6] add rmsnorm class --- .../modules/layers/normalizations.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/torchmultimodal/modules/layers/normalizations.py b/torchmultimodal/modules/layers/normalizations.py index b4fc64b7..acb610fb 100644 --- a/torchmultimodal/modules/layers/normalizations.py +++ b/torchmultimodal/modules/layers/normalizations.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Any - +import torch from torch import nn, Tensor @@ -45,3 +45,25 @@ def forward(self, x: Tensor) -> Tensor: self.eps, ) return output.type_as(x) + + +class RMSNorm(nn.Module): + """Root Mean Square layer normalization + as proposed in: https://arxiv.org/abs/1910.07467 + + params: + dim = model size + eps = epsilon + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + x_normed = self._norm(x.float()).type_as(x) + return x_normed * self.scale From 5129c67c3ab949e1c21a58457afccfdd0498d7e1 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 16 Aug 2023 22:59:41 +0000 Subject: [PATCH 2/6] add rms norm --- torchmultimodal/modules/layers/normalizations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmultimodal/modules/layers/normalizations.py b/torchmultimodal/modules/layers/normalizations.py index acb610fb..9d8bc809 100644 --- a/torchmultimodal/modules/layers/normalizations.py +++ b/torchmultimodal/modules/layers/normalizations.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Any + import torch from torch import nn, Tensor From 55416980d264a1b77deffc8b57f1650531ae26d2 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 16 Aug 2023 23:31:04 +0000 Subject: [PATCH 3/6] add rms tests --- tests/modules/layers/test_normalizations.py | 52 ++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/tests/modules/layers/test_normalizations.py b/tests/modules/layers/test_normalizations.py index e9ef3817..fea46a82 100644 --- a/tests/modules/layers/test_normalizations.py +++ b/tests/modules/layers/test_normalizations.py @@ -5,7 +5,15 @@ # LICENSE file in the root directory of this source tree. 
import torch -from torchmultimodal.modules.layers.normalizations import Fp32GroupNorm, Fp32LayerNorm +import torch.nn.functional as F + +from tests.test_utils import gpu_test + +from torchmultimodal.modules.layers.normalizations import ( + Fp32GroupNorm, + Fp32LayerNorm, + RMSNorm, +) def test_fp32layernorm(): @@ -20,3 +28,45 @@ def test_fp32groupnorm(): norm = Fp32GroupNorm(2, 4) output = norm(x) assert output.dtype == torch.float16 + + +def test_rms_norm_fp32return(): + """verify type is returned as fp32""" + dims = 512 + x = torch.empty(dims, dtype=torch.float16) + norm = RMSNorm( + dims, + ) + output = norm(x) + assert output.dtype == torch.float32 + + +@gpu_test(1) +def test_rms_norm_core_algo(): + """compare RMSNorm with RMSNorm using F.norm version""" + + dims = 1024 + x = torch.empty(dims, dtype=torch.float16, device="cuda") + x_clone = x.clone().detach() + + class RMSNormFunctional(torch.nn.Module): + def __init__(self, dim, eps=1e-6): + super().__init__() + self.scale = dim**0.5 + self.weights = torch.nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x): + return F.normalize(x, p=2, dim=-1, eps=self.eps) * self.scale * self.weights + + base_norm = RMSNorm( + dims, + ).to("cuda") + backup_norm = RMSNormFunctional( + dims, + ).to("cuda") + + output_base_rms = base_norm(x) + output_backup_rms = backup_norm(x_clone) + + assert torch.allclose(output_base_rms, output_backup_rms) From d15037f8661991b681498dfa7b4dae71e1c4a0ca Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 16 Aug 2023 23:56:17 +0000 Subject: [PATCH 4/6] merge unit tests into one, add Tensor typing --- torchmultimodal/modules/layers/normalizations.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchmultimodal/modules/layers/normalizations.py b/torchmultimodal/modules/layers/normalizations.py index 9d8bc809..1afb6fbd 100644 --- a/torchmultimodal/modules/layers/normalizations.py +++ b/torchmultimodal/modules/layers/normalizations.py @@ -49,9 +49,11 @@ def forward(self, x: Tensor) -> Tensor: class RMSNorm(nn.Module): - """Root Mean Square layer normalization + """Root Mean Square Layer Normalization as proposed in: https://arxiv.org/abs/1910.07467 + Calcs are done in fp32. 
+ params: dim = model size eps = epsilon @@ -62,9 +64,9 @@ def __init__(self, dim: int, eps: float = 1e-6): self.eps = eps self.scale = nn.Parameter(torch.ones(dim)) - def _norm(self, x): + def _norm(self, x: Tensor): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - def forward(self, x): + def forward(self, x: Tensor): x_normed = self._norm(x.float()).type_as(x) return x_normed * self.scale From 4073e65840f3371dbed310e3307ab64e883c8c9a Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 16 Aug 2023 23:57:27 +0000 Subject: [PATCH 5/6] add bit more info to doc_string --- tests/modules/layers/test_normalizations.py | 16 +++------------- torchmultimodal/modules/layers/normalizations.py | 6 +++--- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/tests/modules/layers/test_normalizations.py b/tests/modules/layers/test_normalizations.py index fea46a82..c86aaa45 100644 --- a/tests/modules/layers/test_normalizations.py +++ b/tests/modules/layers/test_normalizations.py @@ -30,17 +30,6 @@ def test_fp32groupnorm(): assert output.dtype == torch.float16 -def test_rms_norm_fp32return(): - """verify type is returned as fp32""" - dims = 512 - x = torch.empty(dims, dtype=torch.float16) - norm = RMSNorm( - dims, - ) - output = norm(x) - assert output.dtype == torch.float32 - - @gpu_test(1) def test_rms_norm_core_algo(): """compare RMSNorm with RMSNorm using F.norm version""" @@ -66,7 +55,8 @@ def forward(self, x): dims, ).to("cuda") - output_base_rms = base_norm(x) + output_core_rms = base_norm(x) output_backup_rms = backup_norm(x_clone) - assert torch.allclose(output_base_rms, output_backup_rms) + assert torch.allclose(output_core_rms, output_backup_rms) + assert output_core_rms.dtype == torch.float32 diff --git a/torchmultimodal/modules/layers/normalizations.py b/torchmultimodal/modules/layers/normalizations.py index 1afb6fbd..d312bc49 100644 --- a/torchmultimodal/modules/layers/normalizations.py +++ b/torchmultimodal/modules/layers/normalizations.py @@ -54,9 +54,9 @@ class RMSNorm(nn.Module): Calcs are done in fp32. - params: - dim = model size - eps = epsilon + Params: + dim = model size + eps = epsilon """ def __init__(self, dim: int, eps: float = 1e-6): From fa99b34c8812d2cc517f923820ad014ad243a2e2 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Thu, 17 Aug 2023 01:43:04 +0000 Subject: [PATCH 6/6] move to fixed tensor tests --- tests/modules/layers/test_normalizations.py | 63 ++++++++++--------- .../modules/layers/normalizations.py | 12 ++-- 2 files changed, 39 insertions(+), 36 deletions(-) diff --git a/tests/modules/layers/test_normalizations.py b/tests/modules/layers/test_normalizations.py index c86aaa45..ed16d052 100644 --- a/tests/modules/layers/test_normalizations.py +++ b/tests/modules/layers/test_normalizations.py @@ -5,9 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import torch -import torch.nn.functional as F - -from tests.test_utils import gpu_test +from tests.test_utils import assert_expected from torchmultimodal.modules.layers.normalizations import ( Fp32GroupNorm, @@ -30,33 +28,36 @@ def test_fp32groupnorm(): assert output.dtype == torch.float16 -@gpu_test(1) def test_rms_norm_core_algo(): """compare RMSNorm with RMSNorm using F.norm version""" - - dims = 1024 - x = torch.empty(dims, dtype=torch.float16, device="cuda") - x_clone = x.clone().detach() - - class RMSNormFunctional(torch.nn.Module): - def __init__(self, dim, eps=1e-6): - super().__init__() - self.scale = dim**0.5 - self.weights = torch.nn.Parameter(torch.ones(dim)) - self.eps = eps - - def forward(self, x): - return F.normalize(x, p=2, dim=-1, eps=self.eps) * self.scale * self.weights - - base_norm = RMSNorm( - dims, - ).to("cuda") - backup_norm = RMSNormFunctional( - dims, - ).to("cuda") - - output_core_rms = base_norm(x) - output_backup_rms = backup_norm(x_clone) - - assert torch.allclose(output_core_rms, output_backup_rms) - assert output_core_rms.dtype == torch.float32 + dims = 10 + rms_norm = RMSNorm(dims) + + input_ones = torch.ones(dims, dtype=torch.float) + + input_fixed = torch.tensor( + [0.999, 1.1111, 2.222, 3.333, 4.444, 5.555, 6.678, 7.987, 8.123, 9.101010], + dtype=torch.float16, + ) + fixed_expected = torch.tensor( + [ + 0.1749, + 0.1946, + 0.3892, + 0.5835, + 0.7783, + 0.9727, + 1.1699, + 1.3984, + 1.4229, + 1.5938, + ], + dtype=torch.float, + ) + + output_fixed = rms_norm(input_fixed) + output_ones = rms_norm(input_ones) + + assert_expected(output_ones, input_ones) + assert_expected(output_fixed, fixed_expected, atol=1e-04, rtol=1e-05) + assert output_fixed.dtype == torch.float32 diff --git a/torchmultimodal/modules/layers/normalizations.py b/torchmultimodal/modules/layers/normalizations.py index d312bc49..63623737 100644 --- a/torchmultimodal/modules/layers/normalizations.py +++ b/torchmultimodal/modules/layers/normalizations.py @@ -54,9 +54,11 @@ class RMSNorm(nn.Module): Calcs are done in fp32. - Params: - dim = model size - eps = epsilon + original impl: https://github.com/facebookresearch/llama/blob/main/llama/model.py + + Args: + dim(int) = model size + eps(float) = epsilon """ def __init__(self, dim: int, eps: float = 1e-6): @@ -64,9 +66,9 @@ def __init__(self, dim: int, eps: float = 1e-6): self.eps = eps self.scale = nn.Parameter(torch.ones(dim)) - def _norm(self, x: Tensor): + def _norm(self, x: Tensor) -> Tensor: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - def forward(self, x: Tensor): + def forward(self, x: Tensor) -> Tensor: x_normed = self._norm(x.float()).type_as(x) return x_normed * self.scale
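
Reviewer sanity check (illustrative sketch, not part of the patches): the F.normalize-based
GPU comparison introduced in PATCH 3/6 and replaced with fixed-tensor expectations in
PATCH 6/6 relied on the identity that RMS normalization over the last dimension equals
L2 normalization rescaled by sqrt(dim), up to where epsilon is applied. The snippet below
reuses that identity as a quick local check of the final RMSNorm class. It assumes
torchmultimodal with this series applied is importable; the rms_norm_reference helper
and the tolerances are illustrative choices, not part of the patch.

import torch
import torch.nn.functional as F

from torchmultimodal.modules.layers.normalizations import RMSNorm


def rms_norm_reference(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(mean(x^2)) == (x / ||x||_2) * sqrt(dim): L2-normalize, then rescale.
    return F.normalize(x.float(), p=2, dim=-1, eps=eps) * x.shape[-1] ** 0.5


if __name__ == "__main__":
    dim = 1024
    x = torch.randn(dim, dtype=torch.float16)

    norm = RMSNorm(dim)  # scale initializes to ones, so the output is just the normalized input
    out = norm(x)

    # fp16 input comes back as fp32 because the learned scale is an fp32 parameter
    assert out.dtype == torch.float32
    assert torch.allclose(out, rms_norm_reference(x), atol=1e-3, rtol=1e-3)
    print("RMSNorm matches the F.normalize-based reference")

The same identity is why RMSNormFunctional in PATCH 3/6 used scale = dim**0.5, and why
the all-ones input in the final test maps to itself: the RMS of a vector of ones is 1,
so at initialization RMSNorm acts as the identity on it.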