# Provenance: this class was appended to src/brevitas/core/zero_point.py by
# commit 0d30ab1, "Feat (zero_point): dynamic groupwise zero point (#1160)".
class RuntimeDynamicGroupZeroPoint(brevitas.jit.ScriptModule):
    """Zero point derived on every call from the tensor handed to ``forward``.

    The module owns three sub-modules and chains them:

    1. ``input_view_impl`` is applied to the incoming tensor
       (presumably it reshapes the tensor into groups -- confirm against the
       implementation that gets injected).
    2. ``zero_point_stats_impl`` is applied to the result of step 1.
    3. The output of step 2 is negated and passed, together with ``scale`` and
       ``bit_width``, to a ``_ScaleShiftZeroPoint`` built from ``int_quant``
       and ``quantize_zero_point``.

    Args:
        input_view_impl: module applied first to the statistics input.
        zero_point_stats_impl: module applied to the viewed input.
        int_quant: forwarded unchanged to ``_ScaleShiftZeroPoint``.
        quantize_zero_point: forwarded unchanged to ``_ScaleShiftZeroPoint``.
    """

    def __init__(
        self,
        input_view_impl: Module,
        zero_point_stats_impl: Module,
        int_quant: Module,
        quantize_zero_point: bool,
    ) -> None:
        super(RuntimeDynamicGroupZeroPoint, self).__init__()

        self.zero_point_stats_impl = zero_point_stats_impl
        self.input_view_impl = input_view_impl
        self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)

    @brevitas.jit.script_method
    def forward(self, stats_input: torch.Tensor, scale, bit_width) -> torch.Tensor:
        # Step 1: re-view the raw input before any statistic is taken.
        viewed_input = self.input_view_impl(stats_input)
        # Step 2: collect the statistic on the viewed tensor.
        collected_stats = self.zero_point_stats_impl(viewed_input)
        # Step 3: the statistic enters the scale/shift stage with its sign flipped.
        negated_stats = -collected_stats
        return self.scale_shift_zero_point(negated_stats, scale, bit_width)