Feat (export): (b)float16 support for qcdq export (#776)
Giuseppe5 authored Dec 21, 2023 · 1 parent d0c10a5 · commit 7f2dbbf
Showing 6 changed files with 155 additions and 73 deletions.
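Before the per-file diff, a minimal usage sketch of what the change enables (an editor's illustration, not part of the commit): exporting a half-precision quantized layer through the QCDQ path. The layer sizes and file name are made up, and `brevitas.nn.QuantLinear` plus `brevitas.export.export_onnx_qcdq` are the assumed entry points.

```python
# Editor's illustration (not part of the commit): export a half-precision
# quantized layer through the QCDQ path. Layer sizes and file name are made up.
import torch
import brevitas.nn as qnn
from brevitas.export import export_onnx_qcdq

# Assumes the layer tolerates .half() and the torch build can trace float16
# Linear layers (otherwise run the same snippet on GPU).
model = qnn.QuantLinear(16, 8, bias=True, weight_bit_width=8).half().eval()
inp = torch.randn(1, 16, dtype=torch.float16)

# The weight scale is now float16; the handlers in this diff cast it to float32
# around QuantizeLinear/DequantizeLinear and cast the dequantized output back.
export_onnx_qcdq(model, inp, export_path='quant_linear_fp16.onnx')
```

With float16 parameters the quantizer scales come out as float16, which is exactly the case the handlers below work around.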
66 changes: 60 additions & 6 deletions src/brevitas/export/common/handler/qcdq.py
@@ -51,6 +51,13 @@ def assert_ge_zero(self, *args):
assert bools


class DQCastMixin(DQMixin, ABC):

@abstractmethod
def cast_fn(self, x, dtype):
pass


class CDQMixin(DQMixin, ABC):

@abstractmethod
@@ -129,15 +136,21 @@ def prepare_for_export(self, module):
for tm in module.tracked_module_list}
# Get the first quant weight as representative
quant_weight = module.tracked_module_list[0].quant_weight()

# (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype
# of the scale and we cast it to float32.
# The original dtype is then restored during the forward pass
scale = quant_weight.scale
self.scale_dtype = scale.dtype
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)

self.symbolic_kwargs['int_weights'] = int_weights
self.symbolic_kwargs['bit_width'] = quant_weight.bit_width
self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
module.is_narrow_range, module.is_signed, quant_weight.bit_width)
self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
- quant_weight.scale,
- quant_weight.zero_point,
- quant_weight.bit_width,
- module.is_signed)
+ scale, quant_weight.zero_point, quant_weight.bit_width, module.is_signed)
else:
self.symbolic_kwargs = None

@@ -156,6 +169,10 @@ def symbolic_execution(self, x: Tensor):
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
# After dequantization, cast both input and scale to the correct dtype
if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
x = self.cast_fn(x, self.scale_dtype)
scale = self.cast_fn(scale, self.scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
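The two hunks above stash the original scale dtype, run the standard Q/DQ ops in float32, and cast the result back afterwards. A plain-PyTorch sketch of the dequantize-side pattern (an illustration, not the exported graph; `dequantize_with_cast` is a made-up name):

```python
# Plain-PyTorch sketch of the dequantize-side cast pattern (illustration only).
import torch

def dequantize_with_cast(int_w, scale, zero_point):
    scale_dtype = scale.dtype
    if scale_dtype in (torch.float16, torch.bfloat16):
        scale = scale.to(torch.float32)  # standard DQ only handles float32 scales
    x = (int_w.to(torch.float32) - zero_point.to(torch.float32)) * scale
    if scale_dtype in (torch.float16, torch.bfloat16):
        x = x.to(scale_dtype)            # restore the original dtype for downstream ops
        scale = scale.to(scale_dtype)
    return x, scale

int_w = torch.randint(-128, 127, (8, 16), dtype=torch.int8)
x, s = dequantize_with_cast(int_w, torch.tensor(0.05, dtype=torch.float16), torch.tensor(0))
print(x.dtype, s.dtype)  # torch.float16 torch.float16
```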
@@ -200,12 +217,22 @@ def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
self.symbolic_kwargs['bit_width'] = module.bit_width()

# (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype
# of the scale and we cast it to float32.
# The original dtype is then restored during the forward pass
scale = module.scale()
self.scale_dtype = scale.dtype
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)

self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
- module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
+ scale, module.zero_point(), module.bit_width(), module.is_signed)
self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
- module.scale(), module.zero_point(), module.bit_width(), module.is_signed)
+ scale, module.zero_point(), module.bit_width(), module.is_signed)
self.symbolic_kwargs['clip_symbolic_kwargs'] = self.int_clip_symbolic_kwargs(
module.is_narrow_range, module.is_signed, module.bit_width())

else:
self.symbolic_kwargs = None

@@ -221,10 +248,17 @@ def symbolic_execution(self, x: Tensor):
bit_width = self.symbolic_kwargs['bit_width']
# Workaround to trick the tracer into believing all return values are used
self.assert_ge_zero(scale, zero_point, bit_width)
# If original dtype of the input is (b)float16, cast the input to float32
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = self.cast_fn(x, torch.float32)
x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
# After dequantization, cast both output and scale to the correct dtype
if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
x = self.cast_fn(x, self.scale_dtype)
scale = self.cast_fn(scale, self.scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
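For the activation path the input itself is also cast to float32 before QuantizeLinear, since the standard opset-13 Q/DQ ops do not accept (b)float16. A hedged sketch of the node sequence one would expect in the exported graph, built with `onnx.helper`; tensor names are illustrative, not taken from the exporter:

```python
# Sketch of the expected node sequence for a float16 activation quantizer.
from onnx import TensorProto, helper

nodes = [
    helper.make_node('Cast', ['x_fp16'], ['x_fp32'], to=TensorProto.FLOAT),
    helper.make_node('QuantizeLinear', ['x_fp32', 'scale_fp32', 'zero_point'], ['x_int']),
    helper.make_node('Clip', ['x_int', 'int_min', 'int_max'], ['x_int_clipped']),
    helper.make_node('DequantizeLinear', ['x_int_clipped', 'scale_fp32', 'zero_point'], ['y_fp32']),
    helper.make_node('Cast', ['y_fp32'], ['y_fp16'], to=TensorProto.FLOAT16),
]
```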
@@ -275,7 +309,16 @@ def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
zero_point = to_0dim_if_scalar(zero_point).expand_as(scale)
zero_point = self.zero_point_with_dtype(
True, bit_width, zero_point) # assume signed is True
# If original dtype of scale is (b)float16, store the original dtype
# and cast the scale to float32
scale_dtype = scale.dtype
if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
scale = self.cast_fn(scale, torch.float32)
y = self.dequantize_fn(int_bias, scale, zero_point, quant_axis)
# After dequantization, cast both output and scale to the correct dtype
if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
y = self.cast_fn(y, scale_dtype)
scale = self.cast_fn(scale, scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
@@ -302,6 +345,13 @@ def symbolic_execution(
output_bit_width = self.symbolic_kwargs['output_bit_width']
dtype = self.int8_dtype() if signed else self.uint8_dtype()
trunc_scale = 2.0 ** (input_bit_width - output_bit_width)
# If original dtype of scale is (b)float16, store the original scale dtype
# and cast the scale and the input to float32
scale_dtype = scale.dtype
if scale_dtype == torch.bfloat16 or scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)
if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
x = self.cast_fn(x, torch.float32)
pre_scale = scale * trunc_scale
flat_pre_scale = to_0dim_if_scalar(pre_scale.flatten())
flat_scale = to_0dim_if_scalar(scale.flatten())
@@ -312,4 +362,8 @@ def symbolic_execution(
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, flat_scale, zp, self.quant_axis(scale))
# After dequantization, cast both output and scale to the correct dtype
if scale_dtype == torch.float16 or scale_dtype == torch.bfloat16:
x = self.cast_fn(x, scale_dtype)
scale = self.cast_fn(scale, scale_dtype)
return x, scale, zero_point, output_bit_width
19 changes: 19 additions & 0 deletions src/brevitas/export/onnx/standard/function.py
@@ -1,12 +1,19 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import onnx
import torch
from torch.autograd import Function

from brevitas.export.onnx import onnx_export_opset

AXIS_OPSET = 13

DATATYPE_DICT = {
torch.float32: onnx.TensorProto.DataType.FLOAT,
torch.float16: onnx.TensorProto.DataType.FLOAT16,
torch.bfloat16: onnx.TensorProto.DataType.BFLOAT16}


class DequantizeLinearFn(Function):

@@ -39,6 +46,18 @@ def forward(ctx, int_x, min_int_val, max_int_val):
return int_x


class CastFn(Function):

@staticmethod
def symbolic(g, x, dtype):
ret = g.op('Cast', x, to_i=DATATYPE_DICT[dtype])
return ret

@staticmethod
def forward(ctx, x, dtype):
return x.to(dtype)


class QuantizeLinearFn(Function):

@staticmethod
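`CastFn` above follows the usual pattern for Brevitas export functions: `forward` keeps the eager behaviour (`x.to(dtype)`), while `symbolic` emits a single ONNX `Cast` node during `torch.onnx.export`. A small usage sketch (assumption; the wrapper module and output file are made up):

```python
# Usage sketch (assumption): wrapper module and output file are made up.
import torch
from brevitas.export.onnx.standard.function import CastFn

class ToFloat32(torch.nn.Module):

    def forward(self, x):
        # Eager mode: plain x.to(); under torch.onnx.export: one Cast node via symbolic().
        return CastFn.apply(x, torch.float32)

torch.onnx.export(
    ToFloat32(), torch.randn(4, dtype=torch.float16), 'cast_only.onnx', opset_version=13)
```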
48 changes: 26 additions & 22 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -6,7 +6,7 @@
import torch

from brevitas.export.common.handler.qcdq import CDQMixin
- from brevitas.export.common.handler.qcdq import DQMixin
+ from brevitas.export.common.handler.qcdq import DQCastMixin
from brevitas.export.common.handler.qcdq import QCDQActQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQBiasQuantProxyHandlerMixin
from brevitas.export.common.handler.qcdq import QCDQDecoupledWeightQuantProxyHandlerMixin
@@ -17,16 +17,20 @@
from brevitas.export.onnx.handler import ONNXBaseHandler
from brevitas.export.onnx.handler import QuantLSTMLayerHandler

from ..function import CastFn
from ..function import DequantizeLinearFn
from ..function import IntClipFn
from ..function import QuantizeLinearFn


- class StdDQONNXMixin(DQMixin, ABC):
+ class StdDQCastONNXMixin(DQCastMixin, ABC):

def dequantize_fn(self, x, scale, zero_point, axis):
return DequantizeLinearFn.apply(x, scale, zero_point, axis)

def cast_fn(self, x, dtype):
return CastFn.apply(x, dtype)

@property
def flatten_dequantize_params(self):
return True
@@ -39,13 +43,13 @@ def validate(self, module):
assert module.bit_width() > 1., 'Binary quant not supported'


- class StdCDQONNXMixin(CDQMixin, StdDQONNXMixin, ABC):
+ class StdCDQCastONNXMixin(CDQMixin, StdDQCastONNXMixin, ABC):

def clip_fn(self, x, min_val, max_val):
return IntClipFn.apply(x, min_val, max_val)


- class StdQCDQONNXMixin(QMixin, StdCDQONNXMixin, ABC):
+ class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC):

@classmethod
def int8_dtype(cls):
@@ -70,42 +74,42 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)


- class StdQCDQONNXWeightQuantProxyHandler(StdCDQONNXMixin,
- QCDQWeightQuantProxyHandlerMixin,
- ONNXBaseHandler):
+ class StdQCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
+ QCDQWeightQuantProxyHandlerMixin,
+ ONNXBaseHandler):
pass


- class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQONNXMixin,
- QCDQDecoupledWeightQuantProxyHandlerMixin,
- ONNXBaseHandler):
+ class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
+ QCDQDecoupledWeightQuantProxyHandlerMixin,
+ ONNXBaseHandler):
pass


- class StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler(
- StdCDQONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
+ class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
+ StdCDQCastONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
pass


- class StdQCDQONNXActQuantProxyHandler(StdQCDQONNXMixin,
- QCDQActQuantProxyHandlerMixin,
- ONNXBaseHandler):
+ class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
+ QCDQActQuantProxyHandlerMixin,
+ ONNXBaseHandler):
pass


- class StdQCDQONNXBiasQuantProxyHandler(StdDQONNXMixin,
- QCDQBiasQuantProxyHandlerMixin,
- ONNXBaseHandler):
+ class StdQCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
+ QCDQBiasQuantProxyHandlerMixin,
+ ONNXBaseHandler):
pass


- class StdQCDQONNXTruncQuantProxyHandler(StdQCDQONNXMixin,
- QCDQTruncQuantProxyHandlerMixin,
- ONNXBaseHandler):
+ class StdQCDQCastONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin,
+ QCDQTruncQuantProxyHandlerMixin,
+ ONNXBaseHandler):
pass


- class StdQCDQONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler):
+ class StdQCDQCastONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler):

def quantized_cell_symbolic_execution(
self,
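The renames above thread the new `DQCastMixin` through the ONNX handler hierarchy; the only extra requirement it places on a backend is a concrete `cast_fn`. A hypothetical sketch of what another backend would have to supply (the class is made up and intentionally left abstract):

```python
# Hypothetical backend mixin (made up): DQCastMixin only adds the cast_fn
# requirement on top of the existing DQMixin interface.
from abc import ABC
from brevitas.export.common.handler.qcdq import DQCastMixin

class EagerDQCastMixin(DQCastMixin, ABC):

    def dequantize_fn(self, x, scale, zero_point, axis):
        return (x - zero_point) * scale

    def cast_fn(self, x, dtype):
        return x.to(dtype)

    @property
    def flatten_dequantize_params(self):
        return False
```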
28 changes: 14 additions & 14 deletions src/brevitas/export/onnx/standard/qcdq/manager.py
@@ -14,13 +14,13 @@
from ..function import IntClipFn
from ..function import QuantizeLinearFn
from ..manager import StdONNXBaseManager
- from .handler import StdQCDQONNXActQuantProxyHandler
- from .handler import StdQCDQONNXBiasQuantProxyHandler
- from .handler import StdQCDQONNXDecoupledWeightQuantProxyHandler
- from .handler import StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler
- from .handler import StdQCDQONNXQuantLSTMLayerHandler
- from .handler import StdQCDQONNXTruncQuantProxyHandler
- from .handler import StdQCDQONNXWeightQuantProxyHandler
+ from .handler import StdQCDQCastONNXActQuantProxyHandler
+ from .handler import StdQCDQCastONNXBiasQuantProxyHandler
+ from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler
+ from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
+ from .handler import StdQCDQCastONNXQuantLSTMLayerHandler
+ from .handler import StdQCDQCastONNXTruncQuantProxyHandler
+ from .handler import StdQCDQCastONNXWeightQuantProxyHandler


class StdQCDQONNXManager(StdONNXBaseManager):
@@ -33,13 +33,13 @@ class StdQCDQONNXManager(StdONNXBaseManager):
"eliminate_unused_initializer"]

handlers = [
- StdQCDQONNXWeightQuantProxyHandler,
- StdQCDQONNXBiasQuantProxyHandler,
- StdQCDQONNXActQuantProxyHandler,
- StdQCDQONNXDecoupledWeightQuantProxyHandler,
- StdQCDQONNXTruncQuantProxyHandler,
- StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler,
- StdQCDQONNXQuantLSTMLayerHandler]
+ StdQCDQCastONNXWeightQuantProxyHandler,
+ StdQCDQCastONNXBiasQuantProxyHandler,
+ StdQCDQCastONNXActQuantProxyHandler,
+ StdQCDQCastONNXDecoupledWeightQuantProxyHandler,
+ StdQCDQCastONNXTruncQuantProxyHandler,
+ StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
+ StdQCDQCastONNXQuantLSTMLayerHandler]

custom_fns = [
DebugMarkerFunction,
(The remaining two changed files of the six are not shown here.)
