Skip to content

Commit

Permalink
Feat (export): qonnx minifloat export (#1070)
Browse files Browse the repository at this point in the history
Giuseppe5 authored Nov 7, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 52e0059 commit d7d88c6
Showing 4 changed files with 168 additions and 1 deletion.
49 changes: 49 additions & 0 deletions src/brevitas/export/onnx/qonnx/function.py
Original file line number Diff line number Diff line change
@@ -59,6 +59,55 @@ def forward(ctx, x, scale, zero_point, bit_width, narrow_range, signed, rounding
return y


class BrevitasFloatQuantFn(Function):
    """Autograd Function that maps Brevitas minifloat quantization onto a
    QONNX ``FloatQuant`` custom op at ONNX export time.

    ``forward`` is an identity pass-through: the quantization semantics live
    entirely in the exported op, not in this Function.
    """

    @staticmethod
    def symbolic(
            g,
            x,
            scale,
            exponent_bit_width,
            mantissa_bit_width,
            exponent_bias,
            has_inf,
            has_nan,
            saturating,
            has_subnormal,
            rounding_mode,
            max_val):
        # Emit the QONNX FloatQuant node. Tensor inputs come first; the
        # remaining flags are encoded as node attributes (_i = int, _s = str).
        ret = g.op(
            f'{DOMAIN_STRING}::FloatQuant',
            x,
            scale,
            exponent_bit_width,
            mantissa_bit_width,
            exponent_bias,
            max_val,
            has_inf_i=int(has_inf),
            has_nan_i=int(has_nan),
            has_subnormal_i=int(has_subnormal),
            rounding_mode_s=rounding_mode,
            saturation_i=saturating)
        # Propagate the input's type/shape information to the symbolic output.
        ret.setType(x.type())
        return ret

    @staticmethod
    def forward(
            ctx,  # renamed from `g`: forward's first arg is the autograd context
            x,
            scale,
            exponent_bit_width,
            mantissa_bit_width,
            exponent_bias,
            has_inf,
            has_nan,
            saturating,
            has_subnormal,
            rounding_mode,
            max_val):
        # Identity: quantization is represented only in the exported graph.
        return x


class BrevitasTruncFn(Function):

@staticmethod
102 changes: 102 additions & 0 deletions src/brevitas/export/onnx/qonnx/handler.py
Original file line number Diff line number Diff line change
@@ -14,14 +14,116 @@
from brevitas.proxy import DecoupledWeightQuantProxyFromInjector
from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector
from brevitas.proxy import WeightQuantProxyFromInjector
from brevitas.proxy.float_parameter_quant import WeightFloatQuantProxyFromInjector
from brevitas.proxy.float_runtime_quant import ActFloatQuantProxyFromInjector
from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector

from .function import BrevitasBinaryQuantFn
from .function import BrevitasFloatQuantFn
from .function import BrevitasQuantFn
from .function import BrevitasQuantLSTMCellFn
from .function import BrevitasTruncFn


class BrevitasFloatQuantProxyHandler(ONNXBaseHandler, ABC):
    """Base QONNX export handler for Brevitas minifloat quantization proxies.

    Caches the proxy's minifloat configuration during ``prepare_for_export``
    and replays it through ``BrevitasFloatQuantFn`` at symbolic-trace time.
    """

    def validate(self, module):
        # Groupwise minifloat quantization has no QONNX representation yet.
        assert not module.is_groupwise, "Export with Per Group quantization not supported"

    def prepare_for_export(self, module):
        # Snapshot everything needed to emit the FloatQuant node later.
        # NOTE: dict insertion order matters — symbolic_execution unpacks
        # .values() positionally into BrevitasFloatQuantFn.apply.
        if module.is_quant_enabled:
            self.validate(module)
            self.symbolic_kwargs = {
                'scale':
                    module.scale(),
                'exponent_bit_width':
                    module.exponent_bit_width(),
                'mantissa_bit_width':
                    module.mantissa_bit_width(),
                'exponent_bias':
                    module.exponent_bias(),
                'has_inf':
                    module.inf_values() is not None,
                'has_nan':
                    module.nan_values() is not None,
                'saturating':
                    module.is_saturating(),
                'has_subnormal':
                    True,  # Currently we only support subnormal
                'rounding_mode':
                    module.rounding_mode,
                'max_float':
                    torch.tensor(module.quant_injector.max_available_float).type_as(module.scale())}
            # Metadata returned to the caller alongside the quantized tensor.
            # Minifloat quant is symmetric, hence the all-zero zero_point.
            self.return_args = {
                'scale': module.scale(),
                'zero_point': torch.zeros_like(module.scale()),
                'exponent_bit_width': module.exponent_bit_width(),
                'mantissa_bit_width': module.mantissa_bit_width(),
                'exponent_bias': module.exponent_bias(),
                'saturating': module.is_saturating(),
                'inf_values': module.inf_values(),
                'nan_values': module.nan_values(),}

    def symbolic_execution(self, x: Tensor):
        # Quantize (symbolically), then append the cached metadata so the
        # caller receives (quant_tensor, scale, zero_point, ...).
        x = BrevitasFloatQuantFn.apply(x, *self.symbolic_kwargs.values())
        return_args = (x, *self.return_args.values())
        return return_args


class BrevitasWeightFloatQuantProxyHandler(BrevitasFloatQuantProxyHandler):
    """QONNX export handler for minifloat *weight* quantization proxies.

    Unlike the activation case, the minifloat parameters are read from the
    first tracked module's cached ``quant_weight()`` rather than from the
    proxy module itself. ``symbolic_execution`` is inherited unchanged from
    the base handler.
    """
    handled_layer = WeightFloatQuantProxyFromInjector

    def __init__(self):
        super().__init__()
        # Kept for symmetry with other weight handlers; not populated here.
        self.quant_weights = None

    def validate(self, zero_point):
        # QONNX FloatQuant has no zero-point input, so only symmetric
        # (all-zero zero-point) weight quantization can be exported.
        # torch.all over as_tensor works for scalars and per-channel
        # zero-point tensors alike; a bare `zero_point == 0` would raise
        # "Boolean value ... is ambiguous" on multi-element tensors.
        assert torch.all(torch.as_tensor(zero_point == 0)), \
            "Zero-point not supported for minifloat quant."

    def prepare_for_export(self, module: WeightQuantProxyFromInjector):
        """Cache minifloat parameters from the first tracked quantized weight.

        NOTE: dict insertion order matters — the base class's
        symbolic_execution unpacks .values() positionally into
        BrevitasFloatQuantFn.apply.
        """
        if module.is_quant_enabled:
            first_qweight = module.tracked_module_list[0].quant_weight()
            self.validate(first_qweight.zero_point)
            self.symbolic_kwargs = {
                'scale':
                    first_qweight.scale,
                'exponent_bit_width':
                    first_qweight.exponent_bit_width,
                'mantissa_bit_width':
                    first_qweight.mantissa_bit_width,
                'exponent_bias':
                    first_qweight.exponent_bias,
                'has_inf':
                    first_qweight.inf_values is not None,
                'has_nan':
                    first_qweight.nan_values is not None,
                'saturating':
                    first_qweight.saturating,
                'has_subnormal':
                    True,  # Currently we only support subnormal
                'rounding_mode':
                    module.rounding_mode,
                'max_float':
                    torch.tensor(module.quant_injector.max_available_float
                                ).type_as(first_qweight.scale)}
            # Metadata returned alongside the quantized tensor; zero_point is
            # all-zero by construction (validated above).
            self.return_args = {
                'scale': first_qweight.scale,
                'zero_point': torch.zeros_like(first_qweight.scale),
                'exponent_bit_width': first_qweight.exponent_bit_width,
                'mantissa_bit_width': first_qweight.mantissa_bit_width,
                'exponent_bias': first_qweight.exponent_bias,
                'saturating': first_qweight.saturating,
                'inf_values': first_qweight.inf_values,
                'nan_values': first_qweight.nan_values,}


class BrevitasActFloatQuantProxyHandler(BrevitasFloatQuantProxyHandler):
    """QONNX export handler for minifloat *activation* quantization proxies.

    All behavior comes from BrevitasFloatQuantProxyHandler; this subclass
    only binds the proxy type it handles.
    """
    handled_layer = ActFloatQuantProxyFromInjector


class BrevitasQuantProxyHandler(ONNXBaseHandler, ABC):

def validate(self, module):
6 changes: 5 additions & 1 deletion src/brevitas/export/onnx/qonnx/manager.py
Original file line number Diff line number Diff line change
@@ -17,12 +17,14 @@
from .function import BrevitasQuantFn
from .function import BrevitasQuantLSTMCellFn
from .function import BrevitasTruncFn
from .handler import BrevitasActFloatQuantProxyHandler
from .handler import BrevitasActQuantProxyHandler
from .handler import BrevitasBiasQuantProxyHandler
from .handler import BrevitasDecoupledWeightQuantProxyHandler
from .handler import BrevitasDecoupledWeightQuantWithInputProxyHandler
from .handler import BrevitasQuantLSTMLayerHandler
from .handler import BrevitasTruncQuantProxyHandler
from .handler import BrevitasWeightFloatQuantProxyHandler
from .handler import BrevitasWeightQuantProxyHandler


@@ -42,7 +44,9 @@ class QONNXManager(ONNXBaseManager):
BrevitasDecoupledWeightQuantProxyHandler,
BrevitasDecoupledWeightQuantWithInputProxyHandler,
BrevitasTruncQuantProxyHandler,
BrevitasQuantLSTMLayerHandler]
BrevitasQuantLSTMLayerHandler,
BrevitasWeightFloatQuantProxyHandler,
BrevitasActFloatQuantProxyHandler]

custom_fns = [
DebugMarkerFunction,
12 changes: 12 additions & 0 deletions tests/brevitas/export/test_onnx_fp8.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@

from brevitas import torch_version
from brevitas.export import export_onnx_qcdq
from brevitas.export import export_qonnx
import brevitas.nn as qnn
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
@@ -23,6 +24,17 @@ def test_simple_fp8_export():
assert True


@jit_disabled_for_export()
def test_qonnx_simple_fp8_export():
    """Smoke test: a QuantLinear with OCP FP8 weight+act quant exports to QONNX."""
    if torch_version < version.parse('2.1.0'):
        pytest.skip(f"OCP FP8 types not supported by {torch_version}")

    model = qnn.QuantLinear(
        3, 16, weight_quant=Fp8e4m3OCPWeightPerTensorFloat, input_quant=Fp8e4m3OCPActPerTensorFloat)
    # Export into a temporary directory so the test leaves no .onnx artifact
    # in the working directory; success == no exception raised.
    import os
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdir:
        export_qonnx(model, torch.randn(1, 3), os.path.join(tmpdir, 'qonnx_act_weight_fp8.onnx'))


@jit_disabled_for_export()
def test_fp8_export_activation():
if torch_version < version.parse('2.1.0'):

0 comments on commit d7d88c6

Please sign in to comment.