diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index e1ba13a00..44457d264 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -167,12 +167,10 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]): if not self.training and not self._export_mode and self.cache_inference_quant_inp: cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) self._cached_inp = cached_inp - else: - # inp = QuantTensor(inp, scale=torch.tensor(1.0, device=inp.device, dtype=inp.dtype), training=self.training) - if not self.training and self.cache_inference_quant_inp: - cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) - self._cached_inp = cached_inp - # print(inp) + # else: + # if not self.training and self.cache_inference_quant_inp: + # cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only) + # self._cached_inp = cached_inp # Remove any naming metadata to avoid dowmstream errors # Avoid inplace operations on the input in case of forward hooks if not torch._C._get_tracing_state(): diff --git a/src/brevitas/quant_tensor/__init__.py b/src/brevitas/quant_tensor/__init__.py index 3423bc212..401fa06dd 100644 --- a/src/brevitas/quant_tensor/__init__.py +++ b/src/brevitas/quant_tensor/__init__.py @@ -8,6 +8,7 @@ import torch from torch import Tensor +import brevitas.config as config from brevitas.function.ops import max_int from brevitas.function.ops import min_int from brevitas.function.ops_ste import ceil_ste @@ -97,7 +98,10 @@ def __new__( @classmethod def from_fake_quantized(cls, fake_quant_value, scale, zero_point, bit_width, signed, training): - quant_tensor = torch.round(fake_quant_value / scale + zero_point) + if config._ONGOING_EXPORT: + quant_tensor = fake_quant_value + else: + quant_tensor = torch.round(fake_quant_value / scale + zero_point) return cls(quant_tensor, scale, zero_point, bit_width, signed, training) @property @@ -131,7 +135,7 @@ def tensor(self): @property def value(self): - if self.is_valid: + if self.is_valid and not config._ONGOING_EXPORT: if self.zero_point is None or self.scale is None: return self.qt_value return (self.qt_value - self.zero_point) * self.scale @@ -232,7 +236,16 @@ def contiguous(self): def int(self, float_datatype=False): if self.is_valid: - return self.qt_value + int_value = self.qt_value + if float_datatype: + return int_value + else: + if self.bit_width <= 8. and self.signed_t.item(): + return int_value.to(torch.int8) + elif self.bit_width <= 8. and not self.signed_t.item(): + return int_value.to(torch.uint8) + else: + return int_value.to(torch.int32) else: raise RuntimeError(f"QuantTensor not valid.") diff --git a/tests/brevitas_ort/common.py b/tests/brevitas_ort/common.py index c05fd59b9..735d23eec 100644 --- a/tests/brevitas_ort/common.py +++ b/tests/brevitas_ort/common.py @@ -132,10 +132,13 @@ def is_brevitas_ort_close( brevitas_output = brevitas_output.int(float_datatype=False) elif export_type == 'qcdq': export_onnx_qcdq(model, input_t, export_path=export_name) + brevitas_output = brevitas_output.value elif export_type == 'qcdq_opset14': export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name) + brevitas_output = brevitas_output.value elif export_type == 'qonnx_opset14': export_qonnx(model, input_t, opset_version=14, export_path=export_name) + brevitas_output = brevitas_output.value else: raise RuntimeError(f"Export type {export_type} not recognized.")