diff --git a/src/brevitas/export/common/handler/qcdq.py b/src/brevitas/export/common/handler/qcdq.py
index 1aecaaeb4..a63b6b5b5 100644
--- a/src/brevitas/export/common/handler/qcdq.py
+++ b/src/brevitas/export/common/handler/qcdq.py
@@ -51,6 +51,13 @@ def assert_ge_zero(self, *args):
             assert bools
 
 
+class DQCastMixin(DQMixin, ABC):
+
+    @abstractmethod
+    def cast_fn(self, x, dtype):
+        pass
+
+
 class CDQMixin(DQMixin, ABC):
 
     @abstractmethod
@@ -129,15 +136,21 @@ def prepare_for_export(self, module):
                 for tm in module.tracked_module_list}
             # Get the first quant weight as representative
             quant_weight = module.tracked_module_list[0].quant_weight()
+
+            # (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype
+            # of the scale and we cast it to float32.
+            # The original dtype is then restored during the forward pass
+            scale = quant_weight.scale
+            self.scale_dtype = scale.dtype
+            if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
+                scale = self.cast_fn(scale, torch.float32)
+
             self.symbolic_kwargs['int_weights'] = int_weights
             self.symbolic_kwargs['bit_width'] = quant_weight.bit_width
             self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
                 module.is_narrow_range, module.is_signed, quant_weight.bit_width)
             self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
-                quant_weight.scale,
-                quant_weight.zero_point,
-                quant_weight.bit_width,
-                module.is_signed)
+                scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)
         else:
             self.symbolic_kwargs = None
 
@@ -156,6 +169,10 @@ def symbolic_execution(self, x: Tensor):
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
+        # After dequantization, cast both input and scale to the correct dtype
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, self.scale_dtype)
+            scale = self.cast_fn(scale, self.scale_dtype)
         # Restore the original shapes to guarantee correct shape propagation downstream
         scale = scale.view(scale_orig_shape)
         zero_point = zero_point.view_as(scale)
@@ -200,12 +217,22 @@ def prepare_for_export(self, module):
         if module.is_quant_enabled:
             self.validate(module)
             self.symbolic_kwargs['bit_width'] = module.bit_width()
+
+            # (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype
+            # of the scale and we cast it to float32.
+            # The original dtype is then restored during the forward pass
+            scale = module.scale()
+            self.scale_dtype = scale.dtype
+            if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
+                scale = self.cast_fn(scale, torch.float32)
+
             self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
-                module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
+                scale, module.zero_point(), module.bit_width(), module.is_signed)
             self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
-                module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
+                scale, module.zero_point(), module.bit_width(), module.is_signed)
             self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
                 module.is_narrow_range, module.is_signed, module.bit_width())
+
         else:
             self.symbolic_kwargs = None
 
@@ -221,10 +248,17 @@ def symbolic_execution(self, x: Tensor):
         bit_width = self.symbolic_kwargs['bit_width']
         # Workaround to trick the tracer into believing all return values are used
         self.assert_ge_zero(scale, zero_point, bit_width)
+        # If original dtype of the input is (b)float16, cast the input to float32
+        if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
+            x = self.cast_fn(x, torch.float32)
         x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
+        # After dequantization, cast both output and scale to the correct dtype
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, self.scale_dtype)
+            scale = self.cast_fn(scale, self.scale_dtype)
         # Restore the original shapes to guarantee correct shape propagation downstream
         scale = scale.view(scale_orig_shape)
         zero_point = zero_point.view_as(scale)
@@ -275,7 +309,16 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
             zero_point = to_0dim_if_scalar(zero_point).expand_as(scale)
             zero_point = self.zero_point_with_dtype(
                 True, bit_width, zero_point)  # assume signed is True
+        # If original dtype of scale is (b)float16, store the original dtype
+        # and cast the scale to float32
+        scale_dtype = scale.dtype
+        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
+            scale = self.cast_fn(scale, torch.float32)
         y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis)
+        # After dequantization, cast both output and scale to the correct dtype
+        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
+            y = self.cast_fn(y, scale_dtype)
+            scale = self.cast_fn(scale, scale_dtype)
         # Restore the original shapes to guarantee correct shape propagation downstream
         scale = scale.view(scale_orig_shape)
         zero_point = zero_point.view_as(scale)
@@ -302,6 +345,13 @@ def symbolic_execution(
         output_bit_width = self.symbolic_kwargs['output_bit_width']
         dtype = self.int8_dtype() if signed else self.uint8_dtype()
         trunc_scale = 2.0 ** (input_bit_width - output_bit_width)
+        # If original dtype of scale is (b)float16, store the original scale dtype
+        # and cast the scale and the input to float32
+        scale_dtype = scale.dtype
+        if scale_dtype == torch.bfloat16 or scale_dtype == torch.float16:
+            scale = self.cast_fn(scale, torch.float32)
+        if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
+            x = self.cast_fn(x, torch.float32)
         pre_scale = scale * trunc_scale
         flat_pre_scale = to_0dim_if_scalar(pre_scale.flatten())
         flat_scale = to_0dim_if_scalar(scale.flatten())
@@ -312,4 +362,8 @@ def symbolic_execution(
         if clip_symbolic_kwargs is not None:
             x = self.clip_fn(x, *clip_symbolic_kwargs.values())
         x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale))
+        # After dequantization, cast both output and scale to the correct dtype
+        if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, scale_dtype)
+            scale = self.cast_fn(scale, scale_dtype)
         return x, scale, zero_point, output_bit_width
diff --git a/src/brevitas/export/onnx/standard/function.py b/src/brevitas/export/onnx/standard/function.py
index 1a1de5b15..e19ab4708 100644
--- a/src/brevitas/export/onnx/standard/function.py
+++ b/src/brevitas/export/onnx/standard/function.py
@@ -1,12 +1,19 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import onnx
+import torch
 from torch.autograd import Function
 
 from brevitas.export.onnx import onnx_export_opset
 
 AXIS_OPSET = 13
 
+DATATYPE_DICT = {
+    torch.float32: onnx.TensorProto.DataType.FLOAT,
+    torch.float16: onnx.TensorProto.DataType.FLOAT16,
+    torch.bfloat16: onnx.TensorProto.DataType.BFLOAT16}
+
 
 class DequantizeLinearFn(Function):
 
@@ -39,6 +46,18 @@ def forward(ctx, int_x, min_int_val, max_int_val):
         return int_x
 
 
+class CastFn(Function):
+
+    @staticmethod
+    def symbolic(g, x, dtype):
+        ret = g.op('Cast', x, to_i=DATATYPE_DICT[dtype])
+        return ret
+
+    @staticmethod
+    def forward(ctx, x, dtype):
+        return x.to(dtype)
+
+
 class QuantizeLinearFn(Function):
 
     @staticmethod
diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py
index 165f4ad6c..642ae9174 100644
--- a/src/brevitas/export/onnx/standard/qcdq/handler.py
+++ b/src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -6,7 +6,7 @@
 import torch
 
 from brevitas.export.common.handler.qcdq import CDQMixin
-from brevitas.export.common.handler.qcdq import DQMixin
+from brevitas.export.common.handler.qcdq import DQCastMixin
 from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
@@ -17,16 +17,20 @@
 from brevitas.export.onnx.handler import ONNXBaseHandler
 from brevitas.export.onnx.handler import QuantLSTMLayerHandler
 
+from ..function import CastFn
 from ..function import DequantizeLinearFn
 from ..function import IntClipFn
 from ..function import QuantizeLinearFn
 
 
-class StdDQONNXMixin(DQMixin, ABC):
+class StdDQCastONNXMixin(DQCastMixin, ABC):
 
     def dequantize_fn(self, x, scale, zero_point, axis):
         return DequantizeLinearFn.apply(x, scale, zero_point, axis)
 
+    def cast_fn(self, x, dtype):
+        return CastFn.apply(x, dtype)
+
     @property
     def flatten_dequantize_params(self):
         return True
@@ -39,13 +43,13 @@ def validate(self, module):
         assert module.bit_width() > 1., 'Binary quant not supported'
 
 
-class StdCDQONNXMixin(CDQMixin, StdDQONNXMixin, ABC):
+class StdCDQCastONNXMixin(CDQMixin, StdDQCastONNXMixin, ABC):
 
     def clip_fn(self, x, min_val, max_val):
         return IntClipFn.apply(x, min_val, max_val)
 
 
-class StdQCDQONNXMixin(QMixin, StdCDQONNXMixin, ABC):
+class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC):
 
     @classmethod
     def int8_dtype(cls):
@@ -70,42 +74,42 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
         return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)
 
 
-class StdQCDQONNXWeightQuantProxyHandler(StdCDQONNXMixin,
-                                         QCDQWeightQuantProxyHandlerMixin,
-                                         ONNXBaseHandler):
+class StdQCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
+                                             QCDQWeightQuantProxyHandlerMixin,
+                                             ONNXBaseHandler):
     pass
 
 
-class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQONNXMixin,
-                                                  QCDQDecoupledWeightQuantProxyHandlerMixin,
-                                                  ONNXBaseHandler):
+class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
+                                                      QCDQDecoupledWeightQuantProxyHandlerMixin,
+                                                      ONNXBaseHandler):
     pass
 
 
-class StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler(
-        StdCDQONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
+class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
+        StdCDQCastONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
     pass
 
 
-class StdQCDQONNXActQuantProxyHandler(StdQCDQONNXMixin,
-                                      QCDQActQuantProxyHandlerMixin,
-                                      ONNXBaseHandler):
+class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
+                                          QCDQActQuantProxyHandlerMixin,
+                                          ONNXBaseHandler):
     pass
 
 
-class StdQCDQONNXBiasQuantProxyHandler(StdDQONNXMixin,
-                                       QCDQBiasQuantProxyHandlerMixin,
-                                       ONNXBaseHandler):
+class StdQCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
+                                           QCDQBiasQuantProxyHandlerMixin,
+                                           ONNXBaseHandler):
    pass
 
 
-class StdQCDQONNXTruncQuantProxyHandler(StdQCDQONNXMixin,
-                                        QCDQTruncQuantProxyHandlerMixin,
-                                        ONNXBaseHandler):
+class StdQCDQCastONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin,
+                                            QCDQTruncQuantProxyHandlerMixin,
+                                            ONNXBaseHandler):
     pass
 
 
-class StdQCDQONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler):
+class StdQCDQCastONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler):
 
     def quantized_cell_symbolic_execution(
             self,
diff --git a/src/brevitas/export/onnx/standard/qcdq/manager.py b/src/brevitas/export/onnx/standard/qcdq/manager.py
index ad2e58dff..ec712672a 100644
--- a/src/brevitas/export/onnx/standard/qcdq/manager.py
+++ b/src/brevitas/export/onnx/standard/qcdq/manager.py
@@ -14,13 +14,13 @@
 from ..function import IntClipFn
 from ..function import QuantizeLinearFn
 from ..manager import StdONNXBaseManager
-from .handler import StdQCDQONNXActQuantProxyHandler
-from .handler import StdQCDQONNXBiasQuantProxyHandler
-from .handler import StdQCDQONNXDecoupledWeightQuantProxyHandler
-from .handler import StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler
-from .handler import StdQCDQONNXQuantLSTMLayerHandler
-from .handler import StdQCDQONNXTruncQuantProxyHandler
-from .handler import StdQCDQONNXWeightQuantProxyHandler
+from .handler import StdQCDQCastONNXActQuantProxyHandler
+from .handler import StdQCDQCastONNXBiasQuantProxyHandler
+from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler
+from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
+from .handler import StdQCDQCastONNXQuantLSTMLayerHandler
+from .handler import StdQCDQCastONNXTruncQuantProxyHandler
+from .handler import StdQCDQCastONNXWeightQuantProxyHandler
 
 
 class StdQCDQONNXManager(StdONNXBaseManager):
@@ -33,13 +33,13 @@ class StdQCDQONNXManager(StdONNXBaseManager):
         "eliminate_unused_initializer"]
 
     handlers = [
-        StdQCDQONNXWeightQuantProxyHandler,
-        StdQCDQONNXBiasQuantProxyHandler,
-        StdQCDQONNXActQuantProxyHandler,
-        StdQCDQONNXDecoupledWeightQuantProxyHandler,
-        StdQCDQONNXTruncQuantProxyHandler,
-        StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler,
-        StdQCDQONNXQuantLSTMLayerHandler]
+        StdQCDQCastONNXWeightQuantProxyHandler,
+        StdQCDQCastONNXBiasQuantProxyHandler,
+        StdQCDQCastONNXActQuantProxyHandler,
+        StdQCDQCastONNXDecoupledWeightQuantProxyHandler,
+        StdQCDQCastONNXTruncQuantProxyHandler,
+        StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
+        StdQCDQCastONNXQuantLSTMLayerHandler]
 
     custom_fns = [
         DebugMarkerFunction,
diff --git a/src/brevitas/export/torch/qcdq/handler.py b/src/brevitas/export/torch/qcdq/handler.py
index e80e20dac..b3474a246 100644
--- a/src/brevitas/export/torch/qcdq/handler.py
+++ b/src/brevitas/export/torch/qcdq/handler.py
@@ -6,7 +6,7 @@
 import torch
 
 from brevitas.export.common.handler.base import BaseHandler
-from brevitas.export.common.handler.qcdq import DQMixin
+from brevitas.export.common.handler.qcdq import DQCastMixin
 from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
 from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
@@ -23,7 +23,7 @@ def _itemize_clip_bounds(clip_args):
     return clip_args
 
 
-class TorchDQMixin(DQMixin, ABC):
+class TorchDQCastMixin(DQCastMixin, ABC):
 
     def __init__(self) -> None:
         super().__init__()
@@ -40,6 +40,9 @@ def dequantize_fn(self, x, scale, zero_point, axis):
             zero_point = float(zero_point)
         return (x - zero_point) * scale
 
+    def cast_fn(self, x, dtype):
+        return x.type(dtype)
+
     @property
     def flatten_dequantize_params(self):
         return False
@@ -52,13 +55,13 @@ def validate(self, module):
         assert module.bit_width() > 1., 'Binary quant not supported'
 
 
-class TorchCDQMixin(TorchDQMixin, ABC):
+class TorchCDQCastMixin(TorchDQCastMixin, ABC):
 
     def clip_fn(self, x, min_val, max_val):
         return torch.clamp(x, min_val, max_val)
 
 
-class TorchQCDQMixin(QMixin, TorchCDQMixin, ABC):
+class TorchQCDQCastMixin(QMixin, TorchCDQCastMixin, ABC):
 
     @classmethod
     def int8_dtype(cls):
@@ -90,9 +93,9 @@ def forward(self, *args, **kwargs):
         return self.symbolic_execution(*args, **kwargs)
 
 
-class TorchQCDQWeightQuantProxyHandler(TorchCDQMixin,
-                                       QCDQWeightQuantProxyHandlerMixin,
-                                       TorchQCDQHandler):
+class TorchQCDQCastWeightQuantProxyHandler(TorchCDQCastMixin,
+                                           QCDQWeightQuantProxyHandlerMixin,
+                                           TorchQCDQHandler):
 
     @classmethod
     def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
@@ -100,9 +103,9 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
         return _itemize_clip_bounds(clip_args)
 
 
-class TorchQCDQDecoupledWeightQuantProxyHandler(TorchCDQMixin,
-                                                QCDQDecoupledWeightQuantProxyHandlerMixin,
-                                                TorchQCDQHandler):
+class TorchQCDQCastDecoupledWeightQuantProxyHandler(TorchCDQCastMixin,
+                                                    QCDQDecoupledWeightQuantProxyHandlerMixin,
+                                                    TorchQCDQHandler):
 
     @classmethod
     def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
@@ -110,8 +113,8 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
         return _itemize_clip_bounds(clip_args)
 
 
-class TorchQCDQDecoupledWeightQuantWithInputProxyHandler(
-        TorchCDQMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):
+class TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler(
+        TorchCDQCastMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):
 
     @classmethod
     def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
@@ -119,8 +122,9 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
         return _itemize_clip_bounds(clip_args)
 
 
-class TorchQCDQActQuantProxyHandler(TorchQCDQMixin, QCDQActQuantProxyHandlerMixin,
-                                    TorchQCDQHandler):
+class TorchQCDQCastActQuantProxyHandler(TorchQCDQCastMixin,
+                                        QCDQActQuantProxyHandlerMixin,
+                                        TorchQCDQHandler):
 
     @classmethod
     def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
@@ -128,14 +132,15 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
         return _itemize_clip_bounds(clip_args)
 
 
-class TorchQCDQBiasQuantProxyHandler(TorchDQMixin, QCDQBiasQuantProxyHandlerMixin,
-                                     TorchQCDQHandler):
+class TorchQCDQCastBiasQuantProxyHandler(TorchDQCastMixin,
+                                         QCDQBiasQuantProxyHandlerMixin,
+                                         TorchQCDQHandler):
     pass
 
 
-class TorchQCDQTruncQuantProxyHandler(TorchQCDQMixin,
-                                      QCDQTruncQuantProxyHandlerMixin,
-                                      TorchQCDQHandler):
+class TorchQCDQCastTruncQuantProxyHandler(TorchQCDQCastMixin,
+                                          QCDQTruncQuantProxyHandlerMixin,
+                                          TorchQCDQHandler):
 
     @classmethod
     def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
diff --git a/src/brevitas/export/torch/qcdq/manager.py b/src/brevitas/export/torch/qcdq/manager.py
index 62f3bedb9..1da072a2d 100644
--- a/src/brevitas/export/torch/qcdq/manager.py
+++ b/src/brevitas/export/torch/qcdq/manager.py
@@ -11,24 +11,24 @@
 from brevitas.export.manager import BaseManager
 from brevitas.export.manager import ExportContext
 
-from .handler import TorchQCDQActQuantProxyHandler
-from .handler import TorchQCDQBiasQuantProxyHandler
-from .handler import TorchQCDQDecoupledWeightQuantProxyHandler
-from .handler import TorchQCDQDecoupledWeightQuantWithInputProxyHandler
-from .handler import TorchQCDQTruncQuantProxyHandler
-from .handler import TorchQCDQWeightQuantProxyHandler
+from .handler import TorchQCDQCastActQuantProxyHandler
+from .handler import TorchQCDQCastBiasQuantProxyHandler
+from .handler import TorchQCDQCastDecoupledWeightQuantProxyHandler
+from .handler import TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler
+from .handler import TorchQCDQCastTruncQuantProxyHandler
+from .handler import TorchQCDQCastWeightQuantProxyHandler
 
 
 class TorchQCDQManager(BaseManager):
     target_name = 'torch'
 
     handlers = [
-        TorchQCDQWeightQuantProxyHandler,
-        TorchQCDQDecoupledWeightQuantProxyHandler,
-        TorchQCDQDecoupledWeightQuantWithInputProxyHandler,
-        TorchQCDQActQuantProxyHandler,
-        TorchQCDQBiasQuantProxyHandler,
-        TorchQCDQTruncQuantProxyHandler]
+        TorchQCDQCastWeightQuantProxyHandler,
+        TorchQCDQCastDecoupledWeightQuantProxyHandler,
+        TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler,
+        TorchQCDQCastActQuantProxyHandler,
+        TorchQCDQCastBiasQuantProxyHandler,
+        TorchQCDQCastTruncQuantProxyHandler]
 
     @classmethod
     def set_export_mode(cls, model: Module, enabled: bool):
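Not part of the diff: a minimal self-contained sketch of the numeric pattern the patched handlers now emit for (b)float16 tensors, i.e. cast to float32 before the standard Q/DQ pair and cast the result (and scale) back to the original dtype afterwards. The helper name and the hard-coded signed int8 range are illustrative assumptions for this sketch, not Brevitas API.

```python
import torch


def qdq_with_cast(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the exported Cast -> Q -> DQ -> Cast sequence
    orig_dtype = scale.dtype
    # Standard Q/DQ ops do not support (b)float16 scales, so compute in float32
    if orig_dtype in (torch.float16, torch.bfloat16):
        x = x.to(torch.float32)
        scale = scale.to(torch.float32)
    # Per-tensor quantize to a signed 8-bit range, then dequantize
    int_x = torch.clamp(torch.round(x / scale) + zero_point, -128, 127)
    y = (int_x - zero_point) * scale
    # Restore the original dtype so downstream ops keep seeing (b)float16
    if orig_dtype in (torch.float16, torch.bfloat16):
        y = y.to(orig_dtype)
    return y


x = torch.randn(4, dtype=torch.float16)
scale = torch.tensor(0.05, dtype=torch.float16)
zero_point = torch.tensor(0.0)
print(qdq_with_cast(x, scale, zero_point).dtype)  # torch.float16
```

Wrapping the Q/DQ nodes in Cast ops keeps the QuantizeLinear/DequantizeLinear inputs in float32, where the standard ops are defined, while the rest of the exported graph continues to run in the network's original half-precision dtype.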