From d9a934b6e5eff5f02a4bb11160fbfc51736741d7 Mon Sep 17 00:00:00 2001 From: Ian Colbert Date: Thu, 24 Oct 2024 21:02:52 -0700 Subject: [PATCH 1/9] Feat (float): add bit_width to proxy --- src/brevitas/proxy/float_parameter_quant.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/brevitas/proxy/float_parameter_quant.py b/src/brevitas/proxy/float_parameter_quant.py index 1e886ee6b..88e78beaa 100644 --- a/src/brevitas/proxy/float_parameter_quant.py +++ b/src/brevitas/proxy/float_parameter_quant.py @@ -14,6 +14,13 @@ class WeightFloatQuantProxyFromInjectorBase(WeightQuantProxyFromInjectorBase, ABC): + def bit_width(self): + if not self.is_quant_enabled: + return None + x = self.__call__(self.tracked_parameter_list[0]) + bit_width = x.mantissa_bit_width + x.exponent_bit_width + 1 + return bit_width + def scale(self): if not self.is_quant_enabled: return None From 23589d1563175fd50216b5a3ca9948b891d964f5 Mon Sep 17 00:00:00 2001 From: Ian Colbert Date: Thu, 24 Oct 2024 21:03:31 -0700 Subject: [PATCH 2/9] Feat (float): adding quant_tensor --- src/brevitas/quant_tensor/float_quant_tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/brevitas/quant_tensor/float_quant_tensor.py b/src/brevitas/quant_tensor/float_quant_tensor.py index 459f0eec7..8db6fda90 100644 --- a/src/brevitas/quant_tensor/float_quant_tensor.py +++ b/src/brevitas/quant_tensor/float_quant_tensor.py @@ -109,6 +109,10 @@ def _pre_round_float_value(self): minifloat_value = minifloat_value / int_scale return minifloat_value + def int(self): + fx_value = torch.round(self._pre_round_float_value) + return fx_value + @property def is_valid(self): with torch.no_grad(): From d4c760e7f02aead125f22958402bae2629421044 Mon Sep 17 00:00:00 2001 From: Ian Colbert Date: Fri, 25 Oct 2024 21:35:03 -0700 Subject: [PATCH 3/9] Feat (float): adding .int() to groupwise quant_tensor --- src/brevitas/quant_tensor/groupwise_float_quant_tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index 4a99b0207..7f252eddf 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -150,6 +150,10 @@ def _pre_round_float_value(self): minifloat_value = minifloat_value / int_scale return minifloat_value + def int(self): + fx_value = torch.round(self._pre_round_float_value) + return fx_value + @property def is_valid(self): with torch.no_grad(): From 7d4b8fafbce7dbc74c14ef50fb060a59657128f9 Mon Sep 17 00:00:00 2001 From: Ian Colbert Date: Fri, 25 Oct 2024 23:01:09 -0700 Subject: [PATCH 4/9] Pre-commit fixes --- src/brevitas/quant_tensor/groupwise_float_quant_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index 7f252eddf..16f75c49e 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -153,7 +153,7 @@ def _pre_round_float_value(self): def int(self): fx_value = torch.round(self._pre_round_float_value) return fx_value - + @property def is_valid(self): with torch.no_grad(): From dc790f1cc7400b0a42b8a7a92477ef74e51e5839 Mon Sep 17 00:00:00 2001 From: Ian Colbert Date: Fri, 25 Oct 2024 23:10:19 -0700 Subject: [PATCH 5/9] Feat (tests): adding tests for quant tensor attributes --- .../quant_tensor/test_quant_tensor.py | 70 +++++++++++++++++-- 1 file changed, 66 insertions(+), 4 deletions(-) diff --git a/tests/brevitas/quant_tensor/test_quant_tensor.py b/tests/brevitas/quant_tensor/test_quant_tensor.py index 6f6b4c7d2..cfe97325d 100644 --- a/tests/brevitas/quant_tensor/test_quant_tensor.py +++ b/tests/brevitas/quant_tensor/test_quant_tensor.py @@ -1,7 +1,9 @@ # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause + from enum import Enum +import numpy as np from packaging import version import pytest import pytest_cases @@ -24,14 +26,40 @@ class Operator(Enum): MATMUL = 4 -def to_quant_tensor(input: torch.Tensor) -> IntQuantTensor: - mod = QuantIdentity(bit_width=8, return_quant_tensor=True) +def to_quant_tensor(input: torch.Tensor, bit_width=8) -> IntQuantTensor: + mod = QuantIdentity(bit_width=bit_width, return_quant_tensor=True) return mod(input) -def to_float_quant_tensor(input: torch.Tensor) -> FloatQuantTensor: +def to_float_quant_tensor( + input: torch.Tensor, + bit_width=8, + exponent_bit_width=4, + mantissa_bit_width=3) -> FloatQuantTensor: mod = QuantIdentity( - bit_width=8, return_quant_tensor=True, act_quant=Fp8e5m2OCPActPerTensorFloat) + bit_width=bit_width, + exponent_bit_width=exponent_bit_width, + mantissa_bit_width=mantissa_bit_width, + return_quant_tensor=True, + act_quant=Fp8e5m2OCPActPerTensorFloat) + return mod(input) + + +def to_mx_quant_tensor( + input: torch.Tensor, + bit_width=8, + exponent_bit_width=4, + mantissa_bit_width=3, + group_size=32, + group_dim=1) -> FloatQuantTensor: + mod = QuantIdentity( + bit_width=bit_width, + group_size=group_size, + group_dim=group_dim, + exponent_bit_width=exponent_bit_width, + mantissa_bit_width=mantissa_bit_width, + return_quant_tensor=True, + act_quant=MXFloat8e4m3Act) return mod(input) @@ -138,3 +166,37 @@ def test_minifloat(quant_class_key_vale): qx = q(x) # Check that minifloat doesn't raise error qx.minifloat() + + +def test_int_quant_tensor(bit_width=8): + limit = np.exp2(bit_width) - 1 + w = torch.randn(32, 1024) + q = to_quant_tensor(w, bit_width=bit_width) + i = q.int().float() + assert ((i.max() - i.min()) <= limit).all() + + +def test_float_quant_tensor(bit_width=8, exponent_bit_width=4, mantissa_bit_width=3): + assert mantissa_bit_width + exponent_bit_width + 1 == bit_width + limit = (np.exp2(mantissa_bit_width + 1) - 1) * np.exp2(np.exp2(exponent_bit_width) - 2) + w = torch.randn(32, 1024) + q = to_float_quant_tensor( + w, + bit_width=bit_width, + exponent_bit_width=exponent_bit_width, + mantissa_bit_width=mantissa_bit_width) + i = q.int().float() + assert ((i.max() - i.min()) <= limit).all() + + +def test_mx_quant_tensor(bit_width=8, exponent_bit_width=4, mantissa_bit_width=3): + assert mantissa_bit_width + exponent_bit_width + 1 == bit_width + limit = (np.exp2(mantissa_bit_width + 1) - 1) * np.exp2(np.exp2(exponent_bit_width) - 2) + w = torch.randn(32, 1024) + q = to_mx_quant_tensor( + w, + bit_width=bit_width, + exponent_bit_width=exponent_bit_width, + mantissa_bit_width=mantissa_bit_width) + i = q.int().float() + assert ((i.max() - i.min()) <= limit).all() From e55a8532cd731dc6342286c641d0a56ef7e249a9 Mon Sep 17 00:00:00 2001 From: Ian Colbert Date: Fri, 25 Oct 2024 23:23:16 -0700 Subject: [PATCH 6/9] Fix (cache): fixing quant_tensor.set(...) --- src/brevitas/utils/quant_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas/utils/quant_utils.py b/src/brevitas/utils/quant_utils.py index 6fd519b41..62290b1de 100644 --- a/src/brevitas/utils/quant_utils.py +++ b/src/brevitas/utils/quant_utils.py @@ -91,7 +91,7 @@ def __init__(self, quant_tensor: GroupwiseFloatQuantTensor, metadata_only: bool) self.shape = quant_tensor.value.shape if metadata_only: self.value = None - self.quant_tensor = quant_tensor.set(value=None) + self.quant_tensor = quant_tensor.set(value_=None) else: self.quant_tensor = quant_tensor # torch.compile compatibility @@ -146,7 +146,7 @@ def __init__(self, quant_tensor: GroupwiseIntQuantTensor, metadata_only: bool): self.shape = quant_tensor.value.shape if metadata_only: self.value = None - self.quant_tensor = quant_tensor.set(value=None) + self.quant_tensor = quant_tensor.set(value_=None) else: self.quant_tensor = quant_tensor # torch.compile compatibility From c79378295cd15f4dac4c4d8f456bdf41ec8b5db4 Mon Sep 17 00:00:00 2001 From: Ian Colbert Date: Fri, 25 Oct 2024 23:24:14 -0700 Subject: [PATCH 7/9] Feat (test): adding caching to the testing --- .../quant_tensor/test_quant_tensor.py | 33 ++++++++++++++++--- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/tests/brevitas/quant_tensor/test_quant_tensor.py b/tests/brevitas/quant_tensor/test_quant_tensor.py index cfe97325d..b4617eba7 100644 --- a/tests/brevitas/quant_tensor/test_quant_tensor.py +++ b/tests/brevitas/quant_tensor/test_quant_tensor.py @@ -15,7 +15,11 @@ from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act from brevitas.quant_tensor import FloatQuantTensor +from brevitas.quant_tensor import GroupwiseFloatQuantTensor from brevitas.quant_tensor import IntQuantTensor +from brevitas.utils.quant_utils import _CachedIO +from brevitas.utils.quant_utils import _CachedIOFloat +from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat class Operator(Enum): @@ -167,16 +171,20 @@ def test_minifloat(quant_class_key_vale): # Check that minifloat doesn't raise error qx.minifloat() - -def test_int_quant_tensor(bit_width=8): +@pytest.mark.parametrize("metadata_only", [True, False]) +def test_int_quant_tensor(metadata_only, bit_width=8): limit = np.exp2(bit_width) - 1 w = torch.randn(32, 1024) q = to_quant_tensor(w, bit_width=bit_width) i = q.int().float() assert ((i.max() - i.min()) <= limit).all() + # test caching works + cache = _CachedIO(q, metadata_only=metadata_only) + assert cache.bit_width == bit_width -def test_float_quant_tensor(bit_width=8, exponent_bit_width=4, mantissa_bit_width=3): +@pytest.mark.parametrize("metadata_only", [True, False]) +def test_float_quant_tensor(metadata_only, bit_width=8, exponent_bit_width=4, mantissa_bit_width=3): assert mantissa_bit_width + exponent_bit_width + 1 == bit_width limit = (np.exp2(mantissa_bit_width + 1) - 1) * np.exp2(np.exp2(exponent_bit_width) - 2) w = torch.randn(32, 1024) @@ -185,11 +193,17 @@ def test_float_quant_tensor(bit_width=8, exponent_bit_width=4, mantissa_bit_widt bit_width=bit_width, exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width) + # test that the integer API returns fixed point values in the right range i = q.int().float() assert ((i.max() - i.min()) <= limit).all() + # test caching works + cache = _CachedIOFloat(q, metadata_only=metadata_only) + assert cache.mantissa_bit_width == mantissa_bit_width + assert cache.exponent_bit_width == exponent_bit_width -def test_mx_quant_tensor(bit_width=8, exponent_bit_width=4, mantissa_bit_width=3): +@pytest.mark.parametrize("metadata_only", [True, False]) +def test_mx_quant_tensor(metadata_only, bit_width=8, exponent_bit_width=4, mantissa_bit_width=3): assert mantissa_bit_width + exponent_bit_width + 1 == bit_width limit = (np.exp2(mantissa_bit_width + 1) - 1) * np.exp2(np.exp2(exponent_bit_width) - 2) w = torch.randn(32, 1024) @@ -197,6 +211,15 @@ def test_mx_quant_tensor(bit_width=8, exponent_bit_width=4, mantissa_bit_width=3 w, bit_width=bit_width, exponent_bit_width=exponent_bit_width, - mantissa_bit_width=mantissa_bit_width) + mantissa_bit_width=mantissa_bit_width, + group_size=32, + group_dim=1) + # test that the integer API returns fixed point values in the right range i = q.int().float() assert ((i.max() - i.min()) <= limit).all() + # test caching works + cache = _CachedIOGroupwiseFloat(q, metadata_only=metadata_only) + assert cache.mantissa_bit_width == mantissa_bit_width + assert cache.exponent_bit_width == exponent_bit_width + assert cache.group_size == 32 + assert cache.group_dim == 1 From 1bfb3324515339cfc034e4bf3db504c00322a6ab Mon Sep 17 00:00:00 2001 From: Ian Colbert Date: Sat, 26 Oct 2024 17:12:35 -0700 Subject: [PATCH 8/9] Pre-commit fixes --- tests/brevitas/quant_tensor/test_quant_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/brevitas/quant_tensor/test_quant_tensor.py b/tests/brevitas/quant_tensor/test_quant_tensor.py index b4617eba7..9c0df2c1b 100644 --- a/tests/brevitas/quant_tensor/test_quant_tensor.py +++ b/tests/brevitas/quant_tensor/test_quant_tensor.py @@ -171,6 +171,7 @@ def test_minifloat(quant_class_key_vale): # Check that minifloat doesn't raise error qx.minifloat() + @pytest.mark.parametrize("metadata_only", [True, False]) def test_int_quant_tensor(metadata_only, bit_width=8): limit = np.exp2(bit_width) - 1 From 4fadb9dec9a7b8b34e4038b10abe1408b63e711c Mon Sep 17 00:00:00 2001 From: Ian Colbert Date: Mon, 28 Oct 2024 10:57:09 -0700 Subject: [PATCH 9/9] Fixing typing hint --- tests/brevitas/quant_tensor/test_quant_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brevitas/quant_tensor/test_quant_tensor.py b/tests/brevitas/quant_tensor/test_quant_tensor.py index 9c0df2c1b..ea833f3f2 100644 --- a/tests/brevitas/quant_tensor/test_quant_tensor.py +++ b/tests/brevitas/quant_tensor/test_quant_tensor.py @@ -55,7 +55,7 @@ def to_mx_quant_tensor( exponent_bit_width=4, mantissa_bit_width=3, group_size=32, - group_dim=1) -> FloatQuantTensor: + group_dim=1) -> GroupwiseFloatQuantTensor: mod = QuantIdentity( bit_width=bit_width, group_size=group_size,