From 726ea3cbf4bd24cd09752fac6af36f99a2f428f8 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Tue, 7 Jan 2025 11:29:56 +0100
Subject: [PATCH] Fix (scaling/standalone): better switch from runtime stats
 to param (#1099)

---
 src/brevitas/core/scaling/standalone.py  | 12 ++++++++----
 src/brevitas/graph/calibrate.py          |  9 +++------
 tests/brevitas/export/test_torch_qcdq.py |  2 ++
 3 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/brevitas/core/scaling/standalone.py b/src/brevitas/core/scaling/standalone.py
index 26bd12c9a..703fed5a4 100644
--- a/src/brevitas/core/scaling/standalone.py
+++ b/src/brevitas/core/scaling/standalone.py
@@ -375,6 +375,12 @@ def __init__(
         self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
         self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()
 
+    def init_scale(self):
+        if self.counter <= self.collect_stats_steps:
+            self.restrict_inplace_preprocess(self.buffer)
+            inplace_tensor_mul(self.value.detach(), self.buffer)
+            self.counter = self.collect_stats_steps + 1
+
     @brevitas.jit.script_method
     def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
         if self.counter < self.collect_stats_steps:
@@ -396,12 +402,10 @@ def training_forward(self, stats_input: Tensor, threshold: Tensor) -> Tensor:
             self.counter = new_counter
             return abs_binary_sign_grad(clamped_stats / threshold)
         elif self.counter == self.collect_stats_steps:
-            self.restrict_inplace_preprocess(self.buffer)
-            inplace_tensor_mul(self.value.detach(), self.buffer)
-            threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
+            self.init_scale()
             value = self.clamp_scaling(self.restrict_scaling(self.value))
+            threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
             value = value / threshold
-            self.counter = self.counter + 1
             return abs_binary_sign_grad(value)
         else:
             threshold = self.restrict_threshold(self.restrict_threshold_pre(threshold))
diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index 9c753952e..a9c090c47 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -65,7 +65,7 @@ def restore_return_quant_tensor(model, previous_state):
 def extend_collect_stats_steps(module):
     if hasattr(module, 'collect_stats_steps'):
         # We extend the collect steps in PTQ to match potentially long calibrations
-        module.collect_stats_steps = sys.maxsize
+        module.collect_stats_steps = sys.maxsize - 1
 
 
 def set_collect_stats_to_average(module):
@@ -75,11 +75,8 @@ def set_collect_stats_to_average(module):
 
 
 def finalize_collect_stats(module):
-    if hasattr(module, 'collect_stats_steps') and hasattr(module, 'counter'):
-        # If the counter has already reached collect_stats_steps, we do not want to reset it
-        # otherwise the restrict_preprocess might be applied twice: during calibration
-        # (that happens in training mode) and then when the model is evaluated
-        module.counter = max(module.collect_stats_steps, module.counter)
+    if hasattr(module, 'init_scale'):
+        module.init_scale()
 
 
 class calibration_mode:
diff --git a/tests/brevitas/export/test_torch_qcdq.py b/tests/brevitas/export/test_torch_qcdq.py
index 6019bf417..6333be174 100644
--- a/tests/brevitas/export/test_torch_qcdq.py
+++ b/tests/brevitas/export/test_torch_qcdq.py
@@ -13,6 +13,7 @@
 
 @requires_pt_ge('1.9.1')
 @jit_disabled_for_export()
+@torch.no_grad()
 def test_torch_qcdq_wbiol_export(
         quant_module,
         quant_module_impl,
@@ -57,6 +58,7 @@
 @requires_pt_ge('1.9.1')
 @jit_disabled_for_export()
 @parametrize('input_signed', [True, False])
+@torch.no_grad()
 def test_torch_qcdq_avgpool_export(input_signed, output_bit_width):
     in_size = (1, IN_CH, FEATURES, FEATURES)
     inp = torch.randn(in_size)
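
A usage sketch (not part of the patch) of the path this change fixes: `init_scale()` now performs the one-shot switch from collected runtime stats to the learned scale parameter, and `finalize_collect_stats` delegates to it, so the preprocess step can no longer run twice. The snippet below assumes `QuantReLU`'s default activation quantizer uses `ParameterFromRuntimeStatsScaling` and that exiting `calibration_mode` invokes `finalize_collect_stats`; the random tensors are stand-ins for real calibration data.

```python
import torch
from brevitas.graph.calibrate import calibration_mode
from brevitas.nn import QuantReLU

# Default act quantizer collects runtime stats before switching to a parameter.
model = QuantReLU()
model.eval()

with torch.no_grad(), calibration_mode(model):
    for _ in range(8):  # collect runtime stats over a few calibration batches
        model(torch.randn(4, 16))
# On exit, finalize_collect_stats calls init_scale(): the stats buffer is folded
# into the scale parameter once, and the counter jumps past collect_stats_steps.

out = model(torch.randn(4, 16))  # inference now uses the initialized parameter
```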