Skip to content

Commit

Permalink
Attempting fix for export
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 26, 2024
1 parent f74b71c commit 3c9aa0c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
10 changes: 4 additions & 6 deletions src/brevitas/nn/mixin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,10 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]):
if not self.training and not self._export_mode and self.cache_inference_quant_inp:
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
self._cached_inp = cached_inp
else:
# inp = QuantTensor(inp, scale=torch.tensor(1.0, device=inp.device, dtype=inp.dtype), training=self.training)
if not self.training and self.cache_inference_quant_inp:
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
self._cached_inp = cached_inp
# print(inp)
# else:
# if not self.training and self.cache_inference_quant_inp:
# cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
# self._cached_inp = cached_inp
# Remove any naming metadata to avoid dowmstream errors
# Avoid inplace operations on the input in case of forward hooks
if not torch._C._get_tracing_state():
Expand Down
19 changes: 16 additions & 3 deletions src/brevitas/quant_tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from torch import Tensor

import brevitas.config as config
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int
from brevitas.function.ops_ste import ceil_ste
Expand Down Expand Up @@ -97,7 +98,10 @@ def __new__(

@classmethod
def from_fake_quantized(cls, fake_quant_value, scale, zero_point, bit_width, signed, training):
quant_tensor = torch.round(fake_quant_value / scale + zero_point)
if config._ONGOING_EXPORT:
quant_tensor = fake_quant_value
else:
quant_tensor = torch.round(fake_quant_value / scale + zero_point)
return cls(quant_tensor, scale, zero_point, bit_width, signed, training)

@property
Expand Down Expand Up @@ -131,7 +135,7 @@ def tensor(self):

@property
def value(self):
if self.is_valid:
if self.is_valid and not config._ONGOING_EXPORT:
if self.zero_point is None or self.scale is None:
return self.qt_value
return (self.qt_value - self.zero_point) * self.scale
Expand Down Expand Up @@ -232,7 +236,16 @@ def contiguous(self):

def int(self, float_datatype=False):
if self.is_valid:
return self.qt_value
int_value = self.qt_value
if float_datatype:
return int_value
else:
if self.bit_width <= 8. and self.signed_t.item():
return int_value.to(torch.int8)
elif self.bit_width <= 8. and not self.signed_t.item():
return int_value.to(torch.uint8)
else:
return int_value.to(torch.int32)
else:
raise RuntimeError(f"QuantTensor not valid.")

Expand Down
3 changes: 3 additions & 0 deletions tests/brevitas_ort/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,13 @@ def is_brevitas_ort_close(
brevitas_output = brevitas_output.int(float_datatype=False)
elif export_type == 'qcdq':
export_onnx_qcdq(model, input_t, export_path=export_name)
brevitas_output = brevitas_output.value
elif export_type == 'qcdq_opset14':
export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name)
brevitas_output = brevitas_output.value
elif export_type == 'qonnx_opset14':
export_qonnx(model, input_t, opset_version=14, export_path=export_name)
brevitas_output = brevitas_output.value
else:
raise RuntimeError(f"Export type {export_type} not recognized.")

Expand Down

0 comments on commit 3c9aa0c

Please sign in to comment.