Skip to content

Commit

Permalink
Feat (zero_point): dynamic groupwise zero point (#1160)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Jan 15, 2025
1 parent 9721379 commit 0d30ab1
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/brevitas/core/zero_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,25 @@ def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
# pre-zero centering before rounding and clipping
z = self.get_zero_center(x) / scale # need to scale the norm by s
return z


class RuntimeDynamicGroupZeroPoint(brevitas.jit.ScriptModule):

def __init__(
self,
input_view_impl: Module,
zero_point_stats_impl: Module,
int_quant: Module,
quantize_zero_point: bool) -> None:
super(RuntimeDynamicGroupZeroPoint, self).__init__()

self.zero_point_stats_impl = zero_point_stats_impl
self.input_view_impl = input_view_impl
self.scale_shift_zero_point = _ScaleShiftZeroPoint(int_quant, quantize_zero_point)

@brevitas.jit.script_method
def forward(self, stats_input: torch.Tensor, scale, bit_width) -> torch.Tensor:

stats_input_reshaped = self.input_view_impl(stats_input)
out = self.zero_point_stats_impl(stats_input_reshaped)
return self.scale_shift_zero_point(-out, scale, bit_width)

0 comments on commit 0d30ab1

Please sign in to comment.