From dedb9b86cc5b3a947b7eeb0c54445ead1b897dc9 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Fri, 27 Oct 2023 15:00:07 +0100
Subject: [PATCH] Fix (ptq): fix for ptq_common

---
 .../imagenet_classification/ptq/ptq_common.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index 3e355c503..cbb56e7f3 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -361,6 +361,14 @@ def kwargs_prefix(prefix, weight_kwargs):
         'return_quant_tensor': False}
     # yapf: enable
 
+    quant_act_kwargs = {'act_quant': act_quant, 'return_quant_tensor': True}
+    # For potentially unsigned activations, we create a separate dict
+    unsigned_quant_act_kwargs = quant_act_kwargs.copy()
+    if uint_sym_act_for_unsigned_values:
+        # In case we support unsigned activation, the output of softmax can be unsigned
+        quant_mha_kwargs['attn_output_weights_signed'] = False
+        unsigned_quant_act_kwargs['signed'] = False
+
     # Layerwise is basic quant kwargs + input_quant
     layerwise_quant_wbiol_kwargs = {**quant_wbiol_kwargs, 'input_quant': per_tensor_act_quant}
 
@@ -374,16 +382,6 @@ def kwargs_prefix(prefix, weight_kwargs):
         torch.nn.ConvTranspose1d: (qnn.QuantConvTranspose1d, quant_wbiol_kwargs),
         torch.nn.ConvTranspose2d: (qnn.QuantConvTranspose2d, quant_wbiol_kwargs),}
 
-    act_quant_and_bit_width = {'act_quant': act_quant, 'bit_width': act_bit_width}
-    quant_act_kwargs = {**act_quant_and_bit_width, 'return_quant_tensor': True}
-
-    # For potentially unsigned activations, we create a separate dict
-    unsigned_quant_act_kwargs = quant_act_kwargs.copy()
-    if uint_sym_act_for_unsigned_values:
-        # In case we support unsigned activation, the output of softmax can be unsigned
-        quant_mha_kwargs['attn_output_weights_signed'] = False
-        unsigned_quant_act_kwargs['signed'] = False
-
     quant_act_map = {
         torch.nn.ReLU: (qnn.QuantReLU, {
             **unsigned_quant_act_kwargs}),