Fix (runtime_act): fix negative group_dim handling
Giuseppe5 committed Jan 15, 2025
1 parent a7efcba commit 8004832
Showing 10 changed files with 13 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/export/inference/handler.py
@@ -131,7 +131,7 @@ def forward(self, x: Tensor, unused_scale: Tensor = None) -> Tuple[Tensor]:

# If we skip quant tensor, we return the flattened version of the groupwise tensor
if self.skip_create_quant_tensor:
start_dim = self.group_dim if self.group_dim != -1 else -2
start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1
x = x.flatten(start_dim, start_dim + 1)
output_args = tuple([x] + list(other))
return output_args
@@ -278,7 +278,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:

# If we skip quant tensor, we return the flattened version of the groupwise tensor
if self.skip_create_quant_tensor:
start_dim = self.group_dim if self.group_dim != -1 else -2
start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1
x = x.flatten(start_dim, start_dim + 1)
output_args = tuple([x] + list(other))
return output_args
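
The change above (repeated in the proxies below) selects the first of the two axes to merge back together after groupwise expansion. A minimal sketch of that arithmetic, not part of the commit, assuming a hypothetical expanded activation of shape (batch, seq, num_groups, group_size); the helper name flatten_groups is illustrative only:

    import torch

    def flatten_groups(x, group_dim):
        # After expansion the group_size axis sits right after the group axis,
        # so for a negative group_dim the pair to merge starts one position
        # earlier (-1 -> -2, -2 -> -3); positive indices are unaffected.
        start_dim = group_dim if group_dim > 0 else group_dim - 1
        return x.flatten(start_dim, start_dim + 1)

    x = torch.randn(2, 8, 4, 16)            # (batch, seq, num_groups, group_size)
    print(flatten_groups(x, -1).shape)      # torch.Size([2, 8, 64])
    print(flatten_groups(x, 2).shape)       # torch.Size([2, 8, 64])

With the previous expression only group_dim == -1 was remapped, so any other negative value would have merged the wrong pair of axes.
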
2 changes: 1 addition & 1 deletion src/brevitas/proxy/groupwise_float_parameter_quant.py
@@ -30,7 +30,7 @@ def group_size(self):

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = self.group_dim if self.group_dim != -1 else -2
start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor:
2 changes: 1 addition & 1 deletion src/brevitas/proxy/groupwise_float_runtime_quant.py
@@ -23,7 +23,7 @@ def group_size(self):

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = self.group_dim if self.group_dim != -1 else -2
start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(
2 changes: 1 addition & 1 deletion src/brevitas/proxy/groupwise_int_parameter_quant.py
@@ -30,7 +30,7 @@ def group_size(self):

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = self.group_dim if self.group_dim != -1 else -2
start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor:
2 changes: 1 addition & 1 deletion src/brevitas/proxy/groupwise_int_runtime_quant.py
@@ -23,7 +23,7 @@ def group_size(self):

def apply_input_view(self, x):
x = super().apply_input_view(x)
start_dim = self.group_dim if self.group_dim != -1 else -2
start_dim = self.group_dim if self.group_dim > 0 else self.group_dim - 1
return x.flatten(start_dim, start_dim + 1)

def create_quant_tensor(
2 changes: 1 addition & 1 deletion src/brevitas/quant/solver/common.py
@@ -181,7 +181,7 @@ def stats_reduce_dim(scaling_stats_op, scaling_per_output, group_dim=None):
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return None
elif scaling_per_output == ScalingPerOutputType.GROUP:
reduce_dim = group_dim + 1 if group_dim != -1 else -1
reduce_dim = group_dim + 1 if group_dim > 0 else group_dim
return reduce_dim

@value
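
For per-group scaling the statistics are reduced over the group_size axis of the expanded view, which is where the adjusted reduce_dim above points. A small sketch under the same shape assumption as before (the helper name group_stats_reduce_dim is hypothetical):

    import torch

    def group_stats_reduce_dim(group_dim):
        # Positive group_dim: the group_size axis is inserted right after it.
        # Negative group_dim: expansion pushes the group axis one step left,
        # so the group_size axis keeps the original negative index.
        return group_dim + 1 if group_dim > 0 else group_dim

    expanded = torch.randn(2, 8, 4, 16)     # (batch, seq, num_groups, group_size)
    print(expanded.abs().amax(dim=group_stats_reduce_dim(-1), keepdim=True).shape)
    print(expanded.abs().amax(dim=group_stats_reduce_dim(2), keepdim=True).shape)
    # both: torch.Size([2, 8, 4, 1]), i.e. one statistic per group
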
2 changes: 1 addition & 1 deletion src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
@@ -97,7 +97,7 @@ def expand(self):

@staticmethod
def from_expanded(value, group_size, group_dim, compress=False):
group_dim = group_dim if group_dim != -1 else -2
group_dim = group_dim if group_dim > 0 else group_dim - 1
size = list(value.shape)
assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size'
if compress:
2 changes: 1 addition & 1 deletion src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
@@ -83,7 +83,7 @@ def expand(self):

@staticmethod
def from_expanded(value, group_size, group_dim, compress=False):
group_dim = group_dim if group_dim != -1 else -2
group_dim = group_dim if group_dim > 0 else group_dim - 1
size = list(value.shape)
assert size[group_dim] % group_size == 0, 'Input channel is not divisible by group size'
if compress:
5 changes: 2 additions & 3 deletions src/brevitas/utils/quant_utils.py
@@ -220,9 +220,8 @@ def float_to_int_impl_to_enum(module):


def groupwise_dequant_expand(value_, scale_, zero_point_, group_dim, dequant_shape):
final_shape = dequant_shape
curr_shape = value_.shape
start_dim = group_dim if group_dim != -1 else -2
start_dim = group_dim if group_dim > 0 else group_dim - 1
new_value = value_.flatten(start_dim, start_dim + 1)
if scale_.shape != ():
new_scale = scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
@@ -237,7 +236,7 @@ def groupwise_dequant_expand(value_, scale_, zero_point_, group_dim, dequant_sha
# First, we compute how much we padded along the group_dim shape
# Then, we unbind the tensor along the group_dim shape, and drop the padded columns
# Finally, we stack the remaining tensors
unpadding_shape = final_shape[group_dim]
unpadding_shape = dequant_shape[group_dim]
residual = new_value.shape[group_dim] - unpadding_shape

if residual > 0:
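
The unpadding branch above drops the columns that were appended along group_dim to make the tensor divisible by group_size. A sketch of that step with assumed shapes (the variable names are illustrative, not the ones in quant_utils.py):

    import torch

    dequant_shape = (2, 8, 60)          # requested shape before padding
    padded = torch.randn(2, 8, 64)      # flattened groupwise tensor, padded to 64
    group_dim = -1

    residual = padded.shape[group_dim] - dequant_shape[group_dim]
    if residual > 0:
        # unbind along group_dim, keep only the unpadded columns, re-stack
        pieces = padded.unbind(dim=group_dim)[:dequant_shape[group_dim]]
        unpadded = torch.stack(pieces, dim=group_dim)
    else:
        unpadded = padded
    print(unpadded.shape)               # torch.Size([2, 8, 60])
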
4 changes: 2 additions & 2 deletions src/brevitas_examples/common/generative/quantize.py
@@ -388,10 +388,10 @@ def generate_quantizers(
elif input_quant_granularity == 'per_group':
q_scaled_quant = sym_input_quant.let(
**{
'group_dim': 2, 'group_size': input_group_size})
'group_dim': -1, 'group_size': input_group_size})
k_transposed_quant = sym_input_quant.let(
**{
'group_dim': 1, 'group_size': input_group_size})
'group_dim': -2, 'group_size': input_group_size})
v_quant = q_scaled_quant
attn_output_weights_quant = q_scaled_quant
else:
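
Switching these attention inputs from positive to negative group_dim values presumably keeps the grouping anchored to the last dimensions regardless of the runtime tensor's rank; the shapes below are illustrative assumptions, not taken from the example:

    import torch

    q_3d = torch.randn(2, 10, 64)       # (batch, seq, head_dim)
    q_4d = torch.randn(2, 4, 10, 64)    # (batch, heads, seq, head_dim)

    # dim 2 and dim -1 coincide only for the 3-D layout; -1 keeps pointing
    # at head_dim once an extra heads axis is present.
    assert q_3d.shape[2] == q_3d.shape[-1] == 64
    assert q_4d.shape[-1] == 64 and q_4d.shape[2] == 10
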
