Feat (mx): unpadding during dequantization #1134

Merged: 10 commits, merged on Jan 11, 2025.
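This PR moves unpadding of groupwise (MX) tensors into dequantization. Groupwise quant tensors gain an optional dequant_shape field recording the original, pre-padding shape, and a new shared helper, groupwise_dequant_expand in brevitas.utils.quant_utils, flattens the group dimension back into its axis and slices off any zero padding. With that in place, GPxQ no longer pre-reshapes groupwise weights, and the proxies no longer apply the grouped input view when quantization is disabled.

The helper itself is not shown in this diff; below is a minimal sketch of what it plausibly does, reconstructed from the expand() bodies it replaces plus the new unpadding step (the actual implementation may differ):

import torch

def groupwise_dequant_expand(value, scale, zero_point, group_dim, dequant_shape):
    # Fold the extra group dimension back into the axis it was split from.
    start_dim = group_dim if group_dim != -1 else -2
    new_value = value.flatten(start_dim, start_dim + 1)
    # Broadcast the per-group scale/zero-point over the group, then flatten
    # them identically so they align elementwise with the value.
    if scale.shape != ():
        new_scale = scale.expand(value.shape).flatten(start_dim, start_dim + 1)
    else:
        new_scale = scale
    if zero_point.shape != ():
        new_zp = zero_point.expand(value.shape).flatten(start_dim, start_dim + 1)
    else:
        new_zp = zero_point
    # The step this PR adds (assumed): drop the zero padding that made the
    # grouped axis a multiple of the group size, restoring the original shape.
    if dequant_shape is not None and tuple(new_value.shape) != tuple(dequant_shape):
        slices = tuple(slice(0, s) for s in dequant_shape)
        new_value = new_value[slices]
        new_scale = new_scale[slices] if new_scale.shape != () else new_scale
        new_zp = new_zp[slices] if new_zp.shape != () else new_zp
    return new_value, new_scale, new_zp

Call sites either keep the whole (value, scale, zero_point) triple (the expand() methods) or take just the value with [0] (the inference handlers).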
notebooks/minifloat_mx_tutorial.ipynb (10 changes: 5 additions, 5 deletions)

@@ -206,15 +206,15 @@
 },
 {
 "cell_type": "code",
-"execution_count": 16,
+"execution_count": 3,
 "metadata": {},
 "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
-  "Non padding weights shape torch.Size([64, 8, 3, 3])\n",
-  "Padded weights shape torch.Size([64, 32, 3, 3])\n"
+  "Non padding weights shape torch.Size([64, 1, 8, 3, 3])\n",
+  "Padded weights shape torch.Size([64, 1, 32, 3, 3])\n"
  ]
  }
 ],

@@ -257,8 +257,8 @@
 "o = mx_model(x)\n",
 "\n",
 "# The quant weight of the padded model is different from the non padding one\n",
-"print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value.shape}\")\n",
-"print(f\"Padded weights shape {mx_model.conv.quant_weight().value.shape}\")\n",
+"print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value_.shape}\")\n",
+"print(f\"Padded weights shape {mx_model.conv.quant_weight().value_.shape}\")\n",
 "\n",
 "# However, results are still the same \n",
 "assert torch.allclose(o, o_no_padding)"
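Note that the tutorial now prints value_ (the raw grouped, padded storage) instead of value: since this PR unpads during dequantization, value has the same shape for the padded and non-padded models, so the difference is only visible on the underlying storage. A small illustration using the notebook's own names (the value/value_ distinction is assumed from the quant tensor API):

qw = mx_model.conv.quant_weight()
print(qw.value_.shape)  # grouped, padded storage: torch.Size([64, 1, 32, 3, 3])
print(qw.value.shape)   # dequantized result, unpadded back to the weight's shape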
src/brevitas/export/inference/handler.py (10 changes: 7 additions, 3 deletions)

@@ -24,6 +24,7 @@
 from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
 from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
 from brevitas.quant.experimental.mx_quant_ocp import GroupwiseActQuantProxyFromInjector
+from brevitas.utils.quant_utils import groupwise_dequant_expand
 from brevitas.utils.torch_utils import float_internal_scale

@@ -146,8 +147,8 @@ def __init__(self):
     def prepare_for_export(self, module):
         super().prepare_for_export(module)
         if module.is_quant_enabled:
+            self.group_dim = module.group_dim
             self.input_view = module.input_view_impl
-            self.flattened_view = module.apply_input_view
             if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
                 self.cached_weight = module._cached_weight.quant_tensor.value_
             else:

@@ -165,12 +166,13 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         if self.cached_weight is not None:
             out = self.cached_weight
         else:
+            inp_shape = x.shape
             x = self.input_view(x)
             out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)

         # If we skip quant tensor, we return the flattened version of the groupwise tensor
         if self.skip_create_quant_tensor:
-            out = self.flattened_view(out)
+            out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0]
         return out, scale, zero_point, self.bit_width

@@ -294,6 +296,7 @@ def prepare_for_export(self, module: nn.Module):
         if module.is_quant_enabled:
             self.input_view = module.input_view_impl
             self.flattened_view = module.apply_input_view
+            self.group_dim = module.group_dim
             if module._cached_weight is not None and not module.cache_inference_quant_weight_metadata_only:
                 self.cached_weight = module._cached_weight.quant_tensor.value_
             else:

@@ -311,11 +314,12 @@ def forward(self, x: Tensor) -> Tuple[Tensor]:
         if self.cached_weight is not None:
             out = self.cached_weight
         else:
+            inp_shape = x.shape
             x = self.input_view(x)
             out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)

         # If we skip quant tensor, we return the flattened version of the groupwise tensor
         if self.skip_create_quant_tensor:
-            out = self.flattened_view(out)
+            out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0]

         return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
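Both handlers now follow the same pattern: capture the unpadded shape before applying the grouped view, then expand and unpad in one step on the way out. In sketch form, with comments on what each line does:

inp_shape = x.shape      # original, unpadded shape, captured before grouping
x = self.input_view(x)   # reshape into groups, padding the grouped axis if needed
out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
if self.skip_create_quant_tensor:
    # flatten the groups back and slice off the padding; [0] keeps only the value
    out = groupwise_dequant_expand(out, scale, zero_point, self.group_dim, inp_shape)[0]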
src/brevitas/graph/gpxq.py (6 changes: 0 additions, 6 deletions)

@@ -187,12 +187,6 @@ def __init__(
         self.layer = layer
         self.name = name
         self.act_order = act_order
-        if self.layer.weight_quant.is_groupwise:
-            weight = self.layer.weight_quant.apply_input_view(self.layer.weight)
-            weight = weight.view(self.layer.weight_quant.quant_injector.reshaped_groupwise_shape)
-            self.layer.weight.data = weight.data
-            self.layer.in_channels = weight.shape[1] if is_conv_transposed(
-                self.layer) else weight.shape[0]

         weight_shape = torch.tensor(layer.weight.shape)
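With unpadding handled inside dequantization, GPxQ no longer needs to rewrite the layer weight into the reshaped groupwise layout or patch in_channels up front: the weight keeps its original shape, and grouping/padding stays confined to the quantize/dequantize round trip.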
src/brevitas/proxy/float_runtime_quant.py (2 changes: 1 addition, 1 deletion)

@@ -77,7 +77,7 @@ def create_quant_tensor(
             self,
             qt_args: Union[torch.Tensor, Tuple[Any]],
             x: Optional[FloatQuantTensor] = None) -> FloatQuantTensor:
-        if x is None:
+        if isinstance(qt_args, tuple):
             out = FloatQuantTensor(*qt_args, signed=self.is_signed, training=self.training)
         else:
             out = FloatQuantTensor(
src/brevitas/proxy/groupwise_float_parameter_quant.py (4 changes: 3 additions, 1 deletion)

@@ -34,6 +34,7 @@ def apply_input_view(self, x):
         return x.flatten(start_dim, start_dim + 1)

     def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor:
+        shape = self.tracked_parameter_list[0].shape
         out, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
         return GroupwiseFloatQuantTensor(
             out,

Review comment on the added line (PR author): "We don't support weight quant sharing for groupwise anyway, so this is safe, but it is ugly."

@@ -48,4 +49,5 @@ def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseFloatQuantTensor:
             inf_values,
             nan_values,
             self.is_signed,
-            self.training)
+            self.training,
+            shape)
src/brevitas/proxy/groupwise_float_runtime_quant.py (10 changes: 6 additions, 4 deletions)

@@ -29,8 +29,8 @@ def apply_input_view(self, x):
     def create_quant_tensor(
             self,
             qt_args: Union[torch.Tensor, Tuple[Any]],
-            x: Optional[GroupwiseFloatQuantTensor] = None) -> GroupwiseFloatQuantTensor:
-        if x is None:
+            x: Union[torch.Tensor, GroupwiseFloatQuantTensor]) -> GroupwiseFloatQuantTensor:
+        if isinstance(qt_args, tuple):
             value, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values = qt_args
             out = GroupwiseFloatQuantTensor(
                 value,

@@ -45,7 +45,8 @@ def create_quant_tensor(
                 inf_values,
                 nan_values,
                 self.is_signed,
-                self.training)
+                self.training,
+                dequant_shape=x.shape)
         else:
             out = GroupwiseFloatQuantTensor(
                 qt_args,

@@ -60,5 +61,6 @@ def create_quant_tensor(
                 x.inf_values,
                 x.nan_values,
                 x.signed,
-                self.training)
+                self.training,
+                x.shape)
         return out
src/brevitas/proxy/groupwise_int_parameter_quant.py (4 changes: 3 additions, 1 deletion)

@@ -34,6 +34,7 @@ def apply_input_view(self, x):
         return x.flatten(start_dim, start_dim + 1)

     def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor:
+        shape = self.tracked_parameter_list[0].shape
         out, scale, zero_point, bit_width = qt_args
         return GroupwiseIntQuantTensor(
             out,

@@ -43,4 +44,5 @@ def create_quant_tensor(self, qt_args: Tuple[Any]) -> GroupwiseIntQuantTensor:
             self.group_dim,
             bit_width,
             self.is_signed,
-            self.training)
+            self.training,
+            shape)
src/brevitas/proxy/groupwise_int_runtime_quant.py (10 changes: 6 additions, 4 deletions)

@@ -29,8 +29,8 @@ def apply_input_view(self, x):
     def create_quant_tensor(
             self,
             qt_args: Union[torch.Tensor, Tuple[Any]],
-            x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor:
-        if x is None:
+            x: Union[torch.Tensor, GroupwiseIntQuantTensor]) -> GroupwiseIntQuantTensor:
+        if isinstance(qt_args, tuple):
             value, scale, zero_point, bit_width = qt_args
             out = GroupwiseIntQuantTensor(
                 value,

@@ -40,7 +40,8 @@ def create_quant_tensor(
                 self.group_dim,
                 bit_width,
                 self.is_signed,
-                self.training)
+                self.training,
+                x.shape)
         else:
             out = GroupwiseIntQuantTensor(
                 qt_args,

@@ -50,5 +51,6 @@ def create_quant_tensor(
                 self.group_dim,
                 x.bit_width,
                 x.signed,
-                self.training)
+                self.training,
+                x.shape)
         return out
src/brevitas/proxy/parameter_quant.py (2 changes: 1 addition, 1 deletion)

@@ -157,7 +157,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
                 out.detach(),
                 metadata_only=self.cache_inference_quant_weight_metadata_only)
         else:  # quantization disabled
-            out = self.apply_input_view(x)
+            out = x
         return out
src/brevitas/proxy/runtime_quant.py (18 changes: 7 additions, 11 deletions)

@@ -157,9 +157,8 @@ def init_tensor_quant(self):

     @abstractmethod
     def create_quant_tensor(
-            self,
-            qt_args: Union[torch.Tensor, Tuple[Any]],
-            x: Optional[QuantTensor] = None) -> QuantTensor:
+            self, qt_args: Union[torch.Tensor, Tuple[Any]], x: Union[Tensor,
+                                                                     QuantTensor]) -> QuantTensor:
         # Supports the following:
         # - qt_args as tuple of Tensors and bools = standard quant activations
         # - qt_args as Tensor and x as QuantTensor = passthrough activation

@@ -181,8 +180,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         elif not self.is_quant_enabled:
             # A tuple helps later with control flows
             # The second None value is used later
-            # If quant is not enabled, we still apply input_view in the case of groupwise + padding
-            y = self.apply_input_view(self.fused_activation_quant_proxy.activation_impl(y))
+            y = self.fused_activation_quant_proxy.activation_impl(y)
             y = (y, None)
         else:
             y = self.fused_activation_quant_proxy(y)

@@ -194,7 +192,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         else:
             # If the second value (i.e., scale) is None, then quant is disabled
             if y[1] is not None:
-                out = self.create_quant_tensor(y)
+                out = self.create_quant_tensor(y, x=x)
             elif self.is_passthrough_act and isinstance(x, QuantTensor):
                 # preserve scale/zp/bit/sign even without output quant
                 y = y[0]

@@ -224,11 +222,9 @@ def bit_width(self, force_eval=True):
         return self.retrieve_attribute('bit_width', force_eval)

     def create_quant_tensor(
-            self,
-            qt_args: Union[Tensor, Tuple[Any]],
-            x: Optional[IntQuantTensor] = None) -> IntQuantTensor:
-
-        if x is None:
+            self, qt_args: Union[torch.Tensor, Tuple[Any]],
+            x: Union[Tensor, IntQuantTensor]) -> IntQuantTensor:
+        if isinstance(qt_args, tuple):
             out = IntQuantTensor(*qt_args, self.is_signed, self.training)
         else:
             out = IntQuantTensor(
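The contract of create_quant_tensor changes here: dispatch used to be on whether x was provided, and is now on whether the quantizer returned a tuple, with x always forwarded (create_quant_tensor(y, x=x)). That gives the groupwise subclasses access to x.shape for the new dequant_shape field on both branches.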
src/brevitas/quant_tensor/base_quant_tensor.py (4 changes: 3 additions, 1 deletion)

@@ -1,4 +1,4 @@
-from typing import List, NamedTuple, Optional
+from typing import List, NamedTuple, Optional, Tuple

 from torch import Tensor

@@ -129,6 +129,7 @@ class GroupwiseFloatQuantTensorBase(NamedTuple):
     nan_values: List[str]
     signed_t: Tensor
     training_t: Tensor
+    dequant_shape: Optional[Tuple] = None


 class GroupwisIntQuantTensorBase(NamedTuple):

@@ -140,6 +141,7 @@ class GroupwisIntQuantTensorBase(NamedTuple):
     bit_width: Tensor
     signed_t: Tensor
     training_t: Tensor
+    dequant_shape: Optional[Tuple] = None


 def _unpack_quant_tensor(input_data):
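Since the new field defaults to None, every existing constructor call keeps working; only call sites that know the original shape pass it. A hedged construction example (shapes mirror the tutorial above; the import path follows the file layout in this diff):

import torch
from brevitas.quant_tensor.groupwise_int_quant_tensor import GroupwiseIntQuantTensor

padded = torch.zeros(64, 1, 32, 3, 3)  # grouped storage, channel axis padded 8 -> 32
qt = GroupwiseIntQuantTensor(
    padded,
    scale=torch.ones(64, 1, 1, 3, 3),
    zero_point=torch.tensor(0.),
    group_size=32,
    group_dim=1,
    bit_width=torch.tensor(8.),
    signed=True,
    training=False,
    dequant_shape=torch.Size([64, 8, 3, 3]))  # original, pre-padding shape
value, scale, zp = qt.expand()  # value.shape == torch.Size([64, 8, 3, 3])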
src/brevitas/quant_tensor/groupwise_float_quant_tensor.py (22 changes: 7 additions, 15 deletions)

@@ -31,7 +31,8 @@ def __new__(
             inf_values,
             nan_values,
             signed,
-            training):
+            training,
+            dequant_shape=None):

         if not isinstance(scale, torch.Tensor):
             scale = torch.tensor(scale, dtype=torch.float)

@@ -63,7 +64,8 @@ def __new__(
             inf_values,
             nan_values,
             signed,
-            training)
+            training,
+            dequant_shape)
         return quant_tensor

     @property

@@ -89,19 +91,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
         return func(*args, **kwargs)

     def expand(self):
-        curr_shape = self.value_.shape
-        start_dim = self.group_dim if self.group_dim != -1 else -2
-        new_value = self.value_.flatten(start_dim, start_dim + 1)
-        if self.scale_.shape != ():
-            new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
-        else:
-            new_scale = self.scale_
-        if self.zero_point_.shape != ():
-            new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
-        else:
-            new_zp = self.zero_point_
-
-        return new_value, new_scale, new_zp
+        from brevitas.utils.quant_utils import groupwise_dequant_expand
+        return groupwise_dequant_expand(
+            self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape)

     @staticmethod
     def from_expanded(value, group_size, group_dim, compress=False):
src/brevitas/quant_tensor/groupwise_int_quant_tensor.py (39 changes: 24 additions, 15 deletions)

@@ -18,7 +18,17 @@

 class GroupwiseIntQuantTensor(GroupwisIntQuantTensorBase, QuantTensor):

-    def __new__(cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training):
+    def __new__(
+            cls,
+            value,
+            scale,
+            zero_point,
+            group_size,
+            group_dim,
+            bit_width,
+            signed,
+            training,
+            dequant_shape=None):

         if not isinstance(scale, torch.Tensor):
             scale = torch.tensor(scale, dtype=torch.float)

@@ -31,7 +41,16 @@ def __new__(
         if not isinstance(training, torch.Tensor):
             training = torch.tensor(training, dtype=torch.bool)
         quant_tensor = super().__new__(
-            cls, value, scale, zero_point, group_size, group_dim, bit_width, signed, training)
+            cls,
+            value,
+            scale,
+            zero_point,
+            group_size,
+            group_dim,
+            bit_width,
+            signed,
+            training,
+            dequant_shape)
         return quant_tensor

     @property

@@ -58,19 +77,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
         return func(*args, **kwargs)

     def expand(self):
-        curr_shape = self.value_.shape
-        start_dim = self.group_dim if self.group_dim != -1 else -2
-        new_value = self.value_.flatten(start_dim, start_dim + 1)
-        if self.scale_.shape != ():
-            new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
-        else:
-            new_scale = self.scale_
-        if self.zero_point_.shape != ():
-            new_zp = self.zero_point_.expand(curr_shape).flatten(start_dim, start_dim + 1)
-        else:
-            new_zp = self.zero_point_
-
-        return new_value, new_scale, new_zp
+        from brevitas.utils.quant_utils import groupwise_dequant_expand
+        return groupwise_dequant_expand(
+            self.value_, self.scale_, self.zero_point_, self.group_dim, self.dequant_shape)

     @staticmethod
     def from_expanded(value, group_size, group_dim, compress=False):
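Both expand() implementations now delegate to the shared groupwise_dequant_expand helper, so the flatten/broadcast/unpad logic lives in one place; the function-local import (rather than a top-level one) presumably avoids a circular dependency between the quant_tensor and utils modules.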