Skip to content

Commit

Permalink
Fix Qop export
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jan 26, 2024
1 parent 3c9aa0c commit a5d76b2
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def input_quant_symbolic_kwargs(cls, module):

@classmethod
def input_dequant_symbolic_kwargs(cls, module):
if module._cached_inp.scale is not None:
if module._cached_inp is not None:
return cls.dequant_symbolic_kwargs_from_cached_io(module._cached_inp)
else:
return None
Expand Down
3 changes: 2 additions & 1 deletion src/brevitas/nn/quant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe

# shortcut execution through the export impl during export
if self.export_mode:
out = self.export_handler(inp.value)
inp_value = getattr(inp, 'value', inp)
out = self.export_handler(inp_value)
self._set_global_is_quant_layer(False)
return out

Expand Down
14 changes: 6 additions & 8 deletions tests/brevitas_ort/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def is_brevitas_ort_close(
model, np_input, export_name, export_type, tolerance=None, first_output_only=False):
input_t = torch.from_numpy(np_input)
brevitas_output = model(input_t)
computed_out = brevitas_output.value

if tolerance is not None and export_type == 'qcdq':
tolerance = tolerance * brevitas_output.scale # Float Output, tolerance is +/- output scale
Expand All @@ -129,16 +130,13 @@ def is_brevitas_ort_close(
else:
if export_type == 'qop':
export_onnx_qop(model, input_t, export_path=export_name)
brevitas_output = brevitas_output.int(float_datatype=False)
computed_out = brevitas_output.int(float_datatype=False)
elif export_type == 'qcdq':
export_onnx_qcdq(model, input_t, export_path=export_name)
brevitas_output = brevitas_output.value
elif export_type == 'qcdq_opset14':
export_onnx_qcdq(model, input_t, opset_version=14, export_path=export_name)
brevitas_output = brevitas_output.value
elif export_type == 'qonnx_opset14':
export_qonnx(model, input_t, opset_version=14, export_path=export_name)
brevitas_output = brevitas_output.value
else:
raise RuntimeError(f"Export type {export_type} not recognized.")

Expand All @@ -147,13 +145,13 @@ def is_brevitas_ort_close(
if first_output_only:
if isinstance(ort_output, (tuple, list)):
ort_output = ort_output[0]
if isinstance(brevitas_output, tuple):
brevitas_output = brevitas_output[0]
if isinstance(computed_out, tuple):
computed_out = computed_out[0]
# make sure we are not comparing 0s
if (ort_output == 0).all() and (brevitas_output == 0).all():
if (ort_output == 0).all() and (computed_out == 0).all():
pytest.skip("Skip testing against all 0s.")

return recursive_allclose(ort_output, brevitas_output, tolerance)
return recursive_allclose(ort_output, computed_out, tolerance)


def gen_linspaced_data(num_samples, min_val=-1.0, max_val=1.0):
Expand Down

0 comments on commit a5d76b2

Please sign in to comment.