Commit
Fix (scaling/standalone): better switch from runtime stats to param (#…
Giuseppe5 authored Jan 7, 2025
1 parent 349057a commit 726ea3c
Showing 3 changed files with 13 additions and 10 deletions.
12 changes: 8 additions & 4 deletions src/brevitas/core/scaling/standalone.py
@@ -375,6 +375,12 @@ def __init__(
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
         self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
 
+    def init_scale(self):
+        if self.counter <= self.collect_stats_steps:
+            self.restrict_inplace_preprocess(self.buffer)
+            inplace_tensor_mul(self.value.detach(), self.buffer)
+            self.counter = self.collect_stats_steps + 1
+
     @brevitas.jit.script_method
     def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
         if self.counter < self.collect_stats_steps:
@@ -396,12 +402,10 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
             self.counter = new_counter
             return abs_binary_sign_grad(clamped_stats / threshold)
         elif self.counter == self.collect_stats_steps:
-            self.restrict_inplace_preprocess(self.buffer)
-            inplace_tensor_mul(self.value.detach(), self.buffer)
-            threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
+            self.init_scale()
             value = self.clamp_scaling(self.restrict_scaling(self.value))
+            threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
             value = value / threshold
-            self.counter = self.counter + 1
             return abs_binary_sign_grad(value)
         else:
             threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
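
The net effect of these hunks: the one-time conversion of collected runtime statistics into the learned scale parameter now lives in init_scale(), and the counter update happens there instead of inside training_forward(). Because the guard pushes counter past collect_stats_steps on the first call, any later call is a no-op, so the restrict preprocessing can no longer be applied twice. A minimal sketch of the guard pattern (a toy class with stand-in names, not the Brevitas module):

# Toy illustration of the guard pattern in init_scale(); the class and
# attribute names here are hypothetical stand-ins, not the Brevitas code.
class StatsToParam:

    def __init__(self, collect_stats_steps):
        self.collect_stats_steps = collect_stats_steps
        self.counter = 0
        self.preprocess_calls = 0

    def init_scale(self):
        # Convert collected stats into the learned parameter exactly once.
        if self.counter <= self.collect_stats_steps:
            self.preprocess_calls += 1  # stands in for restrict_inplace_preprocess
            # Push the counter past the threshold so later calls are no-ops.
            self.counter = self.collect_stats_steps + 1


m = StatsToParam(collect_stats_steps=10)
m.init_scale()
m.init_scale()  # no-op: the guard already tripped
assert m.preprocess_calls == 1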
9 changes: 3 additions & 6 deletions src/brevitas/graph/calibrate.py
@@ -65,7 +65,7 @@ def restore_return_quant_tensor(model, previous_state):
 def extend_collect_stats_steps(module):
     if hasattr(module, 'collect_stats_steps'):
         # We extend the collect steps in PTQ to match potentially long calibrations
-        module.collect_stats_steps = sys.maxsize
+        module.collect_stats_steps = sys.maxsize - 1
 
 
 def set_collect_stats_to_average(module):
@@ -75,11 +75,8 @@ def set_collect_stats_to_average(module):
 
 
 def finalize_collect_stats(module):
-    if hasattr(module, 'collect_stats_steps') and hasattr(module, 'counter'):
-        # If the counter has already reached collect_stats_steps, we do not want to reset it
-        # otherwise the restrict_preprocess might be applied twice: during calibration
-        # (that happens in training mode) and then when the model is evaluated
-        module.counter = max(module.collect_stats_steps, module.counter)
+    if hasattr(module, 'init_scale'):
+        module.init_scale()
 
 
 class calibration_mode:
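
finalize_collect_stats() now simply delegates to the module's init_scale(), whose internal guard already prevents double preprocessing, so the manual counter bookkeeping and its long explanatory comment could be dropped. The change from sys.maxsize to sys.maxsize - 1 plausibly leaves headroom for the counter = collect_stats_steps + 1 assignment inside init_scale(), which would otherwise exceed the maximum 64-bit integer on the scripted path. A self-contained sketch of the assumed calibration flow (FakeScaling is a stand-in, not a Brevitas class):

# Sketch of how the calibration helpers are assumed to cooperate, simplified.
import sys


class FakeScaling:

    def __init__(self):
        self.collect_stats_steps = 10
        self.counter = 0

    def init_scale(self):
        if self.counter <= self.collect_stats_steps:
            self.counter = self.collect_stats_steps + 1  # must stay <= sys.maxsize


def extend_collect_stats_steps(module):
    if hasattr(module, 'collect_stats_steps'):
        # sys.maxsize - 1 leaves room for the `+ 1` inside init_scale()
        module.collect_stats_steps = sys.maxsize - 1


def finalize_collect_stats(module):
    if hasattr(module, 'init_scale'):
        module.init_scale()


m = FakeScaling()
extend_collect_stats_steps(m)  # entering calibration: keep collecting stats
finalize_collect_stats(m)      # exiting calibration: switch to learned param
assert m.counter == sys.maxsize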
2 changes: 2 additions & 0 deletions tests/brevitas/export/test_torch_qcdq.py
@@ -13,6 +13,7 @@
 
 @requires_pt_ge('1.9.1')
 @jit_disabled_for_export()
+@torch.no_grad()
 def test_torch_qcdq_wbiol_export(
         quant_module,
         quant_module_impl,
@@ -57,6 +58,7 @@ def test_torch_qcdq_wbiol_export(
 @requires_pt_ge('1.9.1')
 @jit_disabled_for_export()
 @parametrize('input_signed', [True, False])
+@torch.no_grad()
 def test_torch_qcdq_avgpool_export(input_signed, output_bit_width):
     in_size = (1, IN_CH, FEATURES, FEATURES)
     inp = torch.randn(in_size)
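
torch.no_grad() works both as a context manager and, as used here, as a decorator: the decorated test body runs with autograd disabled, so no gradient graph is built during export. A small standalone illustration (export_path is a hypothetical function, not part of the test suite):

import torch


@torch.no_grad()
def export_path(x):
    # Every op inside runs without recording autograd history.
    return x * 2


y = export_path(torch.randn(3, requires_grad=True))
print(y.requires_grad)  # False: grad tracking was off inside the call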
