Fix (scale): correct output scale compute (#1077)
Giuseppe5 authored Nov 15, 2024
1 parent 552d24f commit c3208cf
Showing 3 changed files with 24 additions and 4 deletions.
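In short: Brevitas derives a quantized layer's output scale as the product of the weight scale and the input scale, both viewed to a common per-channel shape. When the two views cannot be broadcast together, the multiplication fails (or silently produces a wrong-shaped result), so this commit adds an is_broadcastable guard and falls back to returning a plain Tensor instead of a QuantTensor when no valid output scale exists. A minimal sketch of the failure mode, with assumed channel counts:

import torch

# Per-channel scales viewed to (1, C, 1, 1); the shapes here are assumptions.
weight_scale = torch.rand(16).view(1, -1, 1, 1)  # 16 output channels
input_scale = torch.rand(3).view(1, -1, 1, 1)    # 3 channels: mismatch

try:
    output_scale = weight_scale * input_scale    # 16 vs 3 cannot broadcast
except RuntimeError as e:
    print(f"broadcast failed: {e}")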
10 changes: 8 additions & 2 deletions src/brevitas/proxy/parameter_quant.py
@@ -23,6 +23,7 @@
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO
 from brevitas.utils.torch_utils import compute_channel_view_shape
+from brevitas.utils.torch_utils import is_broadcastable

 from .quant_proxy import QuantProxyFromInjector
 from .quant_proxy import QuantProxyProtocol
@@ -309,7 +310,12 @@ def quant_output_scale_impl(
         channel_dim = -1 if isinstance(module, torch.nn.Linear) else 1
         output_scale_shape = compute_channel_view_shape(input, channel_dim=channel_dim)
         output_scale = weight.scale.view(output_scale_shape)
-        output_scale = output_scale * input.scale.view(output_scale_shape)
+
+        input_scale_view = input.scale.view(output_scale_shape)
+        if not is_broadcastable(output_scale.shape, input_scale_view.shape):
+            return None
+
+        output_scale = output_scale * input_scale_view
         return output_scale

     def compute_bias_scale(
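For context: channel_dim is -1 for Linear (channels last) and 1 for convolutions (NCHW), and compute_channel_view_shape produces an all-ones view shape with -1 at the channel dimension, so both scales become per-channel views before the broadcast check. A sketch of that view shape, with the helper paraphrased (an assumption, not the committed implementation):

import torch

def compute_channel_view_shape(tensor, channel_dim):
    # Paraphrase (assumption) of brevitas.utils.torch_utils.compute_channel_view_shape.
    shape = [1] * tensor.dim()
    shape[channel_dim] = -1
    return tuple(shape)

conv_input = torch.randn(8, 16, 32, 32)              # NCHW -> channel_dim = 1
linear_input = torch.randn(8, 64)                    # Linear -> channel_dim = -1
print(compute_channel_view_shape(conv_input, 1))     # (1, -1, 1, 1)
print(compute_channel_view_shape(linear_input, -1))  # (1, -1)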
@@ -336,8 +342,8 @@ def forward(
             weight: Optional[Union[Tensor,
                                    IntQuantTensor]] = None) -> Union[Tensor, IntQuantTensor]:
         out = x
+        input_scale = self.compute_bias_scale(input, weight)
         if self.is_quant_enabled:
-            input_scale = self.compute_bias_scale(input, weight)
             impl = self.export_handler if self.export_mode else self.tensor_quant
             if self.requires_input_scale and input_scale is None and self.is_quant_enabled:
                 input_scale = self.scale()
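Note that compute_bias_scale is now called unconditionally before the is_quant_enabled check, so input_scale is bound on every path through forward rather than only inside the quantization-enabled branch.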
9 changes: 7 additions & 2 deletions src/brevitas/quant_tensor/int_torch_handler.py
@@ -10,6 +10,7 @@
 from brevitas.function.ops import max_int
 from brevitas.function.ops_ste import ceil_ste
 from brevitas.utils.torch_utils import compute_channel_view_shape
+from brevitas.utils.torch_utils import is_broadcastable

 INT_QUANT_TENSOR_FN_HANDLER = {}

@@ -198,6 +199,9 @@ def quant_layer(fn, quant_input, quant_weight, bias, *args, **kwargs):
             (quant_weight.zero_point != 0.0).any()):
         warnings.warn("Computing zero point of output accumulator not supported yet.")
         compute_output_quant_tensor = False
+    if output_scale is None:
+        warnings.warn("Could not compute output scale factor, returning Tensor")
+        compute_output_quant_tensor = False

     if compute_output_quant_tensor:
         if output_zero_point is None:
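The new check downgrades the result to a plain torch.Tensor as soon as the output scale is unavailable, instead of failing later when the output IntQuantTensor is assembled. A sketch of the fallback decision, with assumed inputs:

import warnings

output_scale = None  # what quant_output_scale_impl now returns on a shape mismatch
compute_output_quant_tensor = True

if output_scale is None:
    warnings.warn("Could not compute output scale factor, returning Tensor")
    compute_output_quant_tensor = False

# With the flag cleared, quant_layer returns the floating-point Tensor as-is.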
@@ -230,8 +234,9 @@ def quant_output_scale_impl(
     output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)

     quant_weight_scale = quant_weight_scale.view(output_scale_shape)
-    if len(quant_input_scale.shape) == 0:
-        quant_input_scale = quant_input_scale.view(output_scale_shape)
+    quant_input_scale = quant_input_scale.view(output_scale_shape)
+    if not is_broadcastable(quant_weight_scale.shape, quant_input_scale.shape):
+        return None

     output_scale = quant_weight_scale * quant_input_scale
     return output_scale
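The old code reshaped the input scale only when it was a 0-dim (per-tensor) scalar; a non-scalar scale skipped the view and was multiplied raw against the (1, C, 1, 1) weight-scale view, which PyTorch broadcasting can accept while silently producing a wrong-shaped result (the motivation is an inference from the diff; the shape arithmetic below is plain torch behavior, with assumed channel counts):

import torch

weight_scale = torch.rand(1, 16, 1, 1)  # viewed weight scale, 16 channels
input_scale = torch.rand(16)            # per-channel input scale, not viewed

# Old path: the raw (16,) scale pairs with the trailing width dimension.
print((weight_scale * input_scale).shape)                    # (1, 16, 1, 16) - wrong

# New path: view first, then multiply (guarded by is_broadcastable).
print((weight_scale * input_scale.view(1, -1, 1, 1)).shape)  # (1, 16, 1, 1)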
9 changes: 9 additions & 0 deletions src/brevitas/utils/torch_utils.py
@@ -113,3 +113,12 @@ def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]:
     padding[2 * group_dim] = group_size - size[group_dim] % group_size
     padding = list(reversed(padding))
     return padding
+
+
+def is_broadcastable(tensor, other):
+    for a, b in zip(tensor[::-1], other[::-1]):
+        if a == 1 or b == 1 or a == b:
+            pass
+        else:
+            return False
+    return True
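The helper walks both shapes right to left, numpy-style: each aligned pair of dimensions must match or contain a 1, and any leading dimensions left over once zip exhausts the shorter shape are implicitly compatible. Despite the parameter names, the call sites pass shape tuples (torch.Size), not tensors. Usage sketch:

from brevitas.utils.torch_utils import is_broadcastable

print(is_broadcastable((1, 16, 1, 1), (1, 16, 1, 1)))  # True: equal shapes
print(is_broadcastable((1, 16, 1, 1), (1, 1, 1, 1)))   # True: 1s broadcast
print(is_broadcastable((1, 16, 1, 1), ()))             # True: scalar scale
print(is_broadcastable((1, 16, 1, 1), (1, 3, 1, 1)))   # False: 16 vs 3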
