From e0cefdac14ae0e80cd39c609d503bb92f830e594 Mon Sep 17 00:00:00 2001 From: Alessandro Pappalardo Date: Fri, 20 Oct 2023 12:15:14 +0100 Subject: [PATCH] Feat (export/qcdq): add handler for explicit weight Q Signed-off-by: Alessandro Pappalardo --- src/brevitas/export/common/handler/qcdq.py | 78 ++++++++++++++----- .../export/onnx/standard/qcdq/handler.py | 17 ++-- src/brevitas/export/torch/qcdq/handler.py | 21 +++-- 3 files changed, 87 insertions(+), 29 deletions(-) diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py index 1aecaaeb4..30017d1bc 100644 --- a/src/brevitas/export/common/handler/qcdq.py +++ b/src/brevitas/export/common/handler/qcdq.py @@ -118,7 +118,24 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): 'scale_orig_shape': scale_orig_shape} -class QCDQWeightQuantProxyHandlerMixin(CDQProxyHandlerMixin, ABC): +class QCDQProxyHandlerMixin(CDQProxyHandlerMixin, ABC): + + def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): + # compute axis before redefining scale + axis = cls.quant_axis(scale) + scale = to_0dim_if_scalar(scale.flatten()) + zp = to_0dim_if_scalar(zero_point.flatten()) + # expand_as must go after 0-dim check + zp = zp.expand_as(scale) + zp = cls.zero_point_with_dtype(is_signed, bit_width, zp) + if cls.itemize_quantize_scalar_params: + scale = to_item_if_0dim(scale) + zp = to_item_if_0dim(zp) + dtype = cls.signed_dtype(bit_width, is_signed) + return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis} + + +class CDQWeightQuantProxyHandlerMixin(CDQProxyHandlerMixin, ABC): handled_layer = WeightQuantProxyFromInjector def prepare_for_export(self, module): @@ -162,7 +179,44 @@ def symbolic_execution(self, x: Tensor): return x, scale, zero_point, bit_width -class QCDQDecoupledWeightQuantProxyHandlerMixin(QCDQWeightQuantProxyHandlerMixin, ABC): +class QCDQWeightQuantProxyHandlerMixin(QCDQProxyHandlerMixin, ABC): + handled_layer = WeightQuantProxyFromInjector + + def prepare_for_export(self, module): + if module.is_quant_enabled: + self.validate(module) + self.symbolic_kwargs['bit_width'] = module.bit_width() + self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs( + module.scale(), module.zero_point(), module.bit_width(), module.is_signed) + self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs( + module.is_narrow_range, module.is_signed, module.bit_width()) + self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs( + module.scale(), module.zero_point(), module.bit_width(), module.is_signed) + else: + self.symbolic_kwargs = None + + def symbolic_execution(self, x: Tensor): + assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled' + quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs'] + clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs'] + dequantize_symbolic_kwargs = self.symbolic_kwargs['dequantize_symbolic_kwargs'] + scale = dequantize_symbolic_kwargs['scale'] + zero_point = dequantize_symbolic_kwargs['zero_point'] + bit_width = self.symbolic_kwargs['bit_width'] + scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape') + # Workaround to trick the tracer into believing all return values are used + self.assert_ge_zero(scale, zero_point, bit_width) + x = self.quantize_fn(x, *quantize_symbolic_kwargs.values()) + if clip_symbolic_kwargs is not None: + x = self.clip_fn(x, *clip_symbolic_kwargs.values()) + x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values()) + # Restore the original shapes to guarantee correct shape propagation downstream + scale = scale.view(scale_orig_shape) + zero_point = zero_point.view_as(scale) + return x, scale, zero_point, bit_width + + +class CDQDecoupledWeightQuantProxyHandlerMixin(CDQWeightQuantProxyHandlerMixin, ABC): handled_layer = DecoupledWeightQuantProxyFromInjector def symbolic_execution(self, x: Tensor): @@ -171,31 +225,17 @@ def symbolic_execution(self, x: Tensor): return out, scale, zero_point, scale, zero_point, bit_width -class QCDQDecoupledWeightQuantWithInputProxyHandlerMixin(QCDQDecoupledWeightQuantProxyHandlerMixin, - ABC): +class CDQDecoupledWeightQuantWithInputProxyHandlerMixin(CDQDecoupledWeightQuantProxyHandlerMixin, + ABC): handled_layer = DecoupledWeightQuantWithInputProxyFromInjector def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_signed: bool): return super().symbolic_execution(x) -class QCDQActQuantProxyHandlerMixin(QMixin, CDQProxyHandlerMixin, ABC): +class QCDQActQuantProxyHandlerMixin(QMixin, QCDQProxyHandlerMixin, ABC): handled_layer = ActQuantProxyFromInjector - def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed): - # compute axis before redefining scale - axis = cls.quant_axis(scale) - scale = to_0dim_if_scalar(scale.flatten()) - zp = to_0dim_if_scalar(zero_point.flatten()) - # expand_as must go after 0-dim check - zp = zp.expand_as(scale) - zp = cls.zero_point_with_dtype(is_signed, bit_width, zp) - if cls.itemize_quantize_scalar_params: - scale = to_item_if_0dim(scale) - zp = to_item_if_0dim(zp) - dtype = cls.signed_dtype(bit_width, is_signed) - return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis} - def prepare_for_export(self, module): if module.is_quant_enabled: self.validate(module) diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index 165f4ad6c..772e2b1c3 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -5,12 +5,13 @@ import torch +from brevitas.export.common.handler.qcdq import CDQDecoupledWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import CDQDecoupledWeightQuantWithInputProxyHandlerMixin from brevitas.export.common.handler.qcdq import CDQMixin +from brevitas.export.common.handler.qcdq import CDQWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import DQMixin from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QMixin @@ -70,20 +71,26 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis): return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis) -class StdQCDQONNXWeightQuantProxyHandler(StdCDQONNXMixin, +class StdCDQONNXWeightQuantProxyHandler(StdCDQONNXMixin, + CDQWeightQuantProxyHandlerMixin, + ONNXBaseHandler): + pass + + +class StdQCDQONNXWeightQuantProxyHandler(StdQCDQONNXMixin, QCDQWeightQuantProxyHandlerMixin, ONNXBaseHandler): pass class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQONNXMixin, - QCDQDecoupledWeightQuantProxyHandlerMixin, + CDQDecoupledWeightQuantProxyHandlerMixin, ONNXBaseHandler): pass class StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler( - StdCDQONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler): + StdCDQONNXMixin, CDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler): pass diff --git a/src/brevitas/export/torch/qcdq/handler.py b/src/brevitas/export/torch/qcdq/handler.py index e80e20dac..b79bfa171 100644 --- a/src/brevitas/export/torch/qcdq/handler.py +++ b/src/brevitas/export/torch/qcdq/handler.py @@ -6,11 +6,12 @@ import torch from brevitas.export.common.handler.base import BaseHandler +from brevitas.export.common.handler.qcdq import CDQDecoupledWeightQuantProxyHandlerMixin +from brevitas.export.common.handler.qcdq import CDQDecoupledWeightQuantWithInputProxyHandlerMixin +from brevitas.export.common.handler.qcdq import CDQWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import DQMixin from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin -from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin from brevitas.export.common.handler.qcdq import QMixin @@ -90,7 +91,17 @@ def forward(self, *args, **kwargs): return self.symbolic_execution(*args, **kwargs) -class TorchQCDQWeightQuantProxyHandler(TorchCDQMixin, +class TorchCDQWeightQuantProxyHandler(TorchCDQMixin, + CDQWeightQuantProxyHandlerMixin, + TorchQCDQHandler): + + @classmethod + def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): + clip_args = super().int_clip_symbolic_kwargs(narrow, signed, bit_width) + return _itemize_clip_bounds(clip_args) + + +class TorchQCDQWeightQuantProxyHandler(TorchQCDQMixin, QCDQWeightQuantProxyHandlerMixin, TorchQCDQHandler): @@ -101,7 +112,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): class TorchQCDQDecoupledWeightQuantProxyHandler(TorchCDQMixin, - QCDQDecoupledWeightQuantProxyHandlerMixin, + CDQDecoupledWeightQuantProxyHandlerMixin, TorchQCDQHandler): @classmethod @@ -111,7 +122,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): class TorchQCDQDecoupledWeightQuantWithInputProxyHandler( - TorchCDQMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler): + TorchCDQMixin, CDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler): @classmethod def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):