Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat (export/qcdq): add handler for explicit weight Q #731

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 59 additions & 19 deletions src/brevitas/export/common/handler/qcdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,24 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
'scale_orig_shape': scale_orig_shape}


class QCDQWeightQuantProxyHandlerMixin(CDQProxyHandlerMixin, ABC):
class QCDQProxyHandlerMixin(CDQProxyHandlerMixin, ABC):

def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
# compute axis before redefining scale
axis = cls.quant_axis(scale)
scale = to_0dim_if_scalar(scale.flatten())
zp = to_0dim_if_scalar(zero_point.flatten())
# expand_as must go after 0-dim check
zp = zp.expand_as(scale)
zp = cls.zero_point_with_dtype(is_signed, bit_width, zp)
if cls.itemize_quantize_scalar_params:
scale = to_item_if_0dim(scale)
zp = to_item_if_0dim(zp)
dtype = cls.signed_dtype(bit_width, is_signed)
return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis}


class CDQWeightQuantProxyHandlerMixin(CDQProxyHandlerMixin, ABC):
handled_layer = WeightQuantProxyFromInjector

def prepare_for_export(self, module):
Expand Down Expand Up @@ -162,7 +179,44 @@ def symbolic_execution(self, x: Tensor):
return x, scale, zero_point, bit_width


class QCDQDecoupledWeightQuantProxyHandlerMixin(QCDQWeightQuantProxyHandlerMixin, ABC):
class QCDQWeightQuantProxyHandlerMixin(QCDQProxyHandlerMixin, ABC):
handled_layer = WeightQuantProxyFromInjector

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
self.symbolic_kwargs['bit_width'] = module.bit_width()
self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
module.is_narrow_range, module.is_signed, module.bit_width())
self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
else:
self.symbolic_kwargs = None

def symbolic_execution(self, x: Tensor):
assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs']
dequantize_symbolic_kwargs = self.symbolic_kwargs['dequantize_symbolic_kwargs']
scale = dequantize_symbolic_kwargs['scale']
zero_point = dequantize_symbolic_kwargs['zero_point']
bit_width = self.symbolic_kwargs['bit_width']
scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
# Workaround to trick the tracer into believing all return values are used
self.assert_ge_zero(scale, zero_point, bit_width)
x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
return x, scale, zero_point, bit_width


class CDQDecoupledWeightQuantProxyHandlerMixin(CDQWeightQuantProxyHandlerMixin, ABC):
handled_layer = DecoupledWeightQuantProxyFromInjector

def symbolic_execution(self, x: Tensor):
Expand All @@ -171,31 +225,17 @@ def symbolic_execution(self, x: Tensor):
return out, scale, zero_point, scale, zero_point, bit_width


class QCDQDecoupledWeightQuantWithInputProxyHandlerMixin(QCDQDecoupledWeightQuantProxyHandlerMixin,
ABC):
class CDQDecoupledWeightQuantWithInputProxyHandlerMixin(CDQDecoupledWeightQuantProxyHandlerMixin,
ABC):
handled_layer = DecoupledWeightQuantWithInputProxyFromInjector

def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_signed: bool):
return super().symbolic_execution(x)


class QCDQActQuantProxyHandlerMixin(QMixin, CDQProxyHandlerMixin, ABC):
class QCDQActQuantProxyHandlerMixin(QMixin, QCDQProxyHandlerMixin, ABC):
handled_layer = ActQuantProxyFromInjector

def quantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
# compute axis before redefining scale
axis = cls.quant_axis(scale)
scale = to_0dim_if_scalar(scale.flatten())
zp = to_0dim_if_scalar(zero_point.flatten())
# expand_as must go after 0-dim check
zp = zp.expand_as(scale)
zp = cls.zero_point_with_dtype(is_signed, bit_width, zp)
if cls.itemize_quantize_scalar_params:
scale = to_item_if_0dim(scale)
zp = to_item_if_0dim(zp)
dtype = cls.signed_dtype(bit_width, is_signed)
return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis}

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
Expand Down
17 changes: 12 additions & 5 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

import torch

from brevitas.export.common.handler.qcdq import CDQDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQMixin
from brevitas.export.common.handler.qcdq import CDQWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin
Expand Down Expand Up @@ -70,20 +71,26 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)


class StdQCDQONNXWeightQuantProxyHandler(StdCDQONNXMixin,
class StdCDQONNXWeightQuantProxyHandler(StdCDQONNXMixin,
CDQWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXWeightQuantProxyHandler(StdQCDQONNXMixin,
QCDQWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQONNXMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
CDQDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler(
StdCDQONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
StdCDQONNXMixin, CDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
pass


Expand Down
21 changes: 16 additions & 5 deletions src/brevitas/export/torch/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
import torch

from brevitas.export.common.handler.base import BaseHandler
from brevitas.export.common.handler.qcdq import CDQDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import CDQWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import DQMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantWithInputProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQTruncQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQWeightQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QMixin
Expand Down Expand Up @@ -90,7 +91,17 @@ def forward(self, *args, **kwargs):
return self.symbolic_execution(*args, **kwargs)


class TorchQCDQWeightQuantProxyHandler(TorchCDQMixin,
class TorchCDQWeightQuantProxyHandler(TorchCDQMixin,
CDQWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
clip_args = super().int_clip_symbolic_kwargs(narrow, signed, bit_width)
return _itemize_clip_bounds(clip_args)


class TorchQCDQWeightQuantProxyHandler(TorchQCDQMixin,
QCDQWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

Expand All @@ -101,7 +112,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):


class TorchQCDQDecoupledWeightQuantProxyHandler(TorchCDQMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
CDQDecoupledWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
Expand All @@ -111,7 +122,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):


class TorchQCDQDecoupledWeightQuantWithInputProxyHandler(
TorchCDQMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):
TorchCDQMixin, CDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
Expand Down
Loading