From 8752c4c191af6bad6ea0e0f0efd217a4032bf667 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Fri, 10 Jan 2025 19:58:52 -0800
Subject: [PATCH] Clean up linear_int8_dynamic_activation_intx_weight_subclass

Summary:
Cleans up the layout and quantization API:

```
int8_dynamic_activation_intx_weight(
    group_size: int = 128,
    bit_width: int = 4,
    has_weight_zeros: bool = False,
    weight_mapping_type=MappingType.ASYMMETRIC,
    act_mapping_type=MappingType.ASYMMETRIC,
    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
)
```

int8_dynamic_activation_intx_weight is now very similar to
int8_dynamic_activation_int4_weight.  By passing bit_width=4,
has_weight_zeros=False, and layout=PlainLayout(), it should be numerically
identical to int8_dynamic_activation_int4_weight (but slower).  The fallback
target is removed; use layout=PlainLayout() instead.

Reviewed By: jerryzh168

Differential Revision: D67821939
---
 torchao/_models/llama/generate.py             |  20 +-
 torchao/dtypes/uintx/plain_layout.py          |  18 +-
 torchao/dtypes/utils.py                       |   6 +-
 .../_linear_8bit_act_xbit_weight_layout.py    | 374 ------------------
 torchao/experimental/docs/readme.md           |   4 +-
 ...8_dynamic_activation_intx_weight_layout.py | 275 +++++++++++++
 torchao/experimental/quant_api.py             | 145 +++++--
 ..._dynamic_activation_intx_weight_layout.py} |  76 ++--
 8 files changed, 436 insertions(+), 482 deletions(-)
 delete mode 100644 torchao/experimental/_linear_8bit_act_xbit_weight_layout.py
 create mode 100644 torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py
 rename torchao/experimental/tests/{test_linear_int8_dynamic_activation_intx_weight_subclass.py => test_packed_linear_int8_dynamic_activation_intx_weight_layout.py} (62%)

diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 5635ed8d23..36aebd3dcd 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -543,32 +543,18 @@ def ffn_or_attn_only(mod, fqn):
         from torchao.experimental.quant_api import (
             int8_dynamic_activation_intx_weight,
         )
-
-        assert (
-            precision == torch.float32
-        ), "int8_dynamic_activation_intx_weight requires fp32 precision"
-
-        try:
-            torch.ops.torchao._pack_8bit_act_4bit_weight
-        except:
-            print(
-                "Unable to load experimental torchao kernels. Performance will be slow."
- ) - print( - "To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU" - ) + assert precision == torch.float32, "int8_dynamic_activation_intx_weight requires using precision=torch.float32" # Quantize model _quant_args = quantization.split("-") - nbit = int(_quant_args[1]) - assert nbit >= 1 and nbit <= 8, "nbits must be 1 to 8" + bit_width = int(_quant_args[1]) group_size = int(_quant_args[2]) has_weight_zeros = bool(_quant_args[3]) quantize_( model, int8_dynamic_activation_intx_weight( + bit_width=bit_width, group_size=group_size, - nbit=nbit, has_weight_zeros=has_weight_zeros, ), ) diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py index 502e3c13e9..f47757fb77 100644 --- a/torchao/dtypes/uintx/plain_layout.py +++ b/torchao/dtypes/uintx/plain_layout.py @@ -38,7 +38,7 @@ def __new__( cls, int_data: torch.Tensor, scale: torch.Tensor, - zero_point: torch.Tensor, + zero_point: Optional[torch.Tensor], _layout: Layout, ): kwargs = {} @@ -55,7 +55,7 @@ def __init__( self, int_data: torch.Tensor, scale: torch.Tensor, - zero_point: torch.Tensor, + zero_point: Optional[torch.Tensor], _layout: Layout, ): self.int_data = int_data @@ -64,6 +64,8 @@ def __init__( self._layout = _layout def __tensor_flatten__(self): + if self.zero_point is None: + return ["int_data", "scale"], [self._layout] return ["int_data", "scale", "zero_point"], [self._layout] @classmethod @@ -73,7 +75,7 @@ def __tensor_unflatten__( int_data, scale, zero_point = ( tensor_data_dict["int_data"], tensor_data_dict["scale"], - tensor_data_dict["zero_point"], + tensor_data_dict.get("zero_point", None), ) (_layout,) = tensor_attributes return cls(int_data, scale, zero_point, _layout) @@ -83,7 +85,9 @@ def to(self, *args, **kwargs): return self.__class__( self.int_data.to(kwargs["device"]), self.scale.to(kwargs["device"]), - self.zero_point.to(kwargs["device"]), + self.zero_point.to(kwargs["device"]) + if self.zero_point is not None + else None, self._layout, ) @@ -91,7 +95,7 @@ def _apply_fn_to_data(self, fn): return self.__class__( fn(self.int_data), fn(self.scale), - fn(self.zero_point), + fn(self.zero_point) if self.zero_point is not None else None, self._layout, ) @@ -134,7 +138,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): return PlainAQTTensorImpl( aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), - self.zero_point.view(-1), + self.zero_point.view(-1) if self.zero_point is not None else None, self._layout, ) else: @@ -148,7 +152,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): __torch_function__ = torch._C._disabled_torch_function_impl - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: return self.int_data, self.scale, self.zero_point def get_layout(self) -> Layout: diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 774071f856..0952b2a4bf 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch @@ -87,7 +87,7 @@ class AQTTensorImpl(TorchAOBaseTensor): the underlying implementation of a AQT based on layout """ - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """Get the plain (unpacked) Tensor for the tensor impl Returns data, 
scale and zero_point @@ -103,7 +103,7 @@ def from_plain( cls, data: torch.Tensor, scale: torch.Tensor, - zero_point: torch.Tensor, + zero_point: Optional[torch.Tensor], _layout: Layout, ): """Construct a TensorImpl from data, scale, zero_point and the _layout""" diff --git a/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py b/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py deleted file mode 100644 index 1f24c91ed2..0000000000 --- a/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py +++ /dev/null @@ -1,374 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from enum import Enum, auto -from typing import Optional, Tuple - -import torch -from torch.utils._python_dispatch import return_and_correct_aliasing - -from torchao.dtypes.affine_quantized_tensor import ( - register_layout, -) -from torchao.dtypes.affine_quantized_tensor_ops import ( - register_aqt_quantized_linear_dispatch, -) -from torchao.dtypes.utils import AQTTensorImpl, Layout -from torchao.quantization.quant_api import to_affine_quantized_intx -from torchao.quantization.quant_primitives import ( - MappingType, - ZeroPointDomain, -) - -logger = logging.getLogger(__name__) -logger.setLevel(logging.WARNING) - -import sys - -handler = logging.StreamHandler(sys.stdout) -formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -handler.setFormatter(formatter) -logger.addHandler(handler) - - -class Target(Enum): - """Enum that indicates the backend target""" - - NATIVE = auto() - FALLBACK = auto() - - -def target_from_str(target: str) -> Target: - if target.lower() == "native": - return Target.NATIVE - elif target.lower() == "fallback": - return Target.FALLBACK - else: - raise ValueError(f"Invalid target: {target}") - - -# This format is intended for use with int8 dynamic quantization -class Linear8BitActXBitWeightLayout(Layout): - nbit: int - group_size: int - - # The target platform for the layout, either 'native' or 'fallback'. 
- target: Target - - def __init__( - self, - nbit: int, - group_size: int, - target: str, - ): - assert nbit <= 8 - self.nbit = nbit - self.group_size = group_size - self.target = target_from_str(target) - - def extra_repr(self): - return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}" - - -def _pack_weights_native( - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - layout: Layout, -): - assert isinstance(layout, Linear8BitActXBitWeightLayout) - assert layout.target == Target.NATIVE - nbit = layout.nbit - group_size = layout.group_size - has_weight_zeros = zero_point is not None - - if has_weight_zeros: - args = [ - int_data.to(torch.int8), - scale.reshape(-1), - zero_point.reshape(-1).to(torch.int8), - torch.empty(0, group_size, dtype=torch.int8), - ] - else: - args = [ - int_data.to(torch.int8), - scale.reshape(-1), - torch.empty(0, group_size, dtype=torch.int8), - ] - - wzp_suffix = "" if has_weight_zeros else "0zp" - return getattr(torch.ops.torchao, f"_pack_8bit_act_{nbit}bit{wzp_suffix}_weight")( - *args - ) - - -@register_layout(Linear8BitActXBitWeightLayout) -class Linear8BitActXBitWeightAQTTensorImpl(AQTTensorImpl): - def __new__( - cls, - packed_weight: torch.Tensor, - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = packed_weight.device - kwargs["dtype"] = packed_weight.dtype - assert not packed_weight.requires_grad - kwargs["requires_grad"] = False - shape = packed_weight.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_weight: torch.Tensor, - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert isinstance(_layout, Linear8BitActXBitWeightLayout) - - # In the native case, scale and zero_point information is inside - # the packed_weight - if _layout.target == Target.NATIVE: - assert scale is None - assert zero_point is None - - self.packed_weight = packed_weight - self.scale = scale - self.zero_point = zero_point - self._layout = _layout - - def __repr__(self): - layout = self.get_layout() - return f"{self.__class__.__name__}(packed_weight={str(self.packed_weight)}, scale={str(self.scale)}, zero_point={str(self.zero_point)}, layout={layout})" - - def get_layout(self) -> Layout: - return self._layout - - def get_plain( - self, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: - if self.get_layout().target == Target.FALLBACK: - return self.packed_weight, self.scale, self.zero_point - raise NotImplementedError( - "get_plain is not supported for Linear8BitActXBitWeightAQTTensorImpl when target is not fallback" - ) - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - layout: Layout, - ): - assert isinstance(layout, Linear8BitActXBitWeightLayout) - - try: - if layout.target == Target.NATIVE: - packed_weight = _pack_weights_native( - int_data, scale, zero_point, layout - ) - scale = None - zero_point = None - return cls(packed_weight, scale, zero_point, layout) - except Exception as e: - logger.warning( - f"A failure occurred when packing weights with Linear8BitActXBitWeightLayout.target={layout.target}: {e}\n" - + "Falling back to **slow** implementation Linear8BitActXBitWeightLayout.target=fallback." 
- ) - layout.target = Target.FALLBACK - - # Fallback - assert layout.target == Target.FALLBACK - packed_weight = int_data.to(torch.int32) - return cls(packed_weight, scale, zero_point, layout) - - def _apply_fn_to_data(self, fn): - self.packed_weight = fn(self.packed_weight) - if self.scale is not None: - self.scale = fn(self.scale) - - if self.zero_point is not None: - self.zero_point = fn(self.zero_point) - return self - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is torch.ops.aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - if func is torch.ops.aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - raise NotImplementedError( - f"Linear8BitActXBitWeightAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - def __tensor_flatten__(self): - if self.get_layout().target == Target.NATIVE: - return ["packed_weight"], [self.get_layout()] - - # fallback - assert self.get_layout().target == Target.FALLBACK - if self.zero_point is None: - return ["packed_weight", "scale"], [self.get_layout()] - return ["packed_weight", "scale", "zero_point"], [self.get_layout()] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_weight, scale, zero_point = ( - tensor_data_dict["packed_weight"], - tensor_data_dict.get("scale", None), - tensor_data_dict.get("zero_point", None), - ) - (layout,) = tensor_attributes - return cls(packed_weight, scale, zero_point, layout) - - -def _linear_int8_dynamic_activation_intx_weight_check( - input_tensor, weight_tensor, bias -): - layout = weight_tensor.tensor_impl.get_layout() - return isinstance(layout, Linear8BitActXBitWeightLayout) and bias is None - - -def _linear_int8_dynamic_activation_intx_weight_fallback_impl( - input_tensor, weight_tensor, bias -): - assert weight_tensor.tensor_impl.get_layout().target == Target.FALLBACK - assert bias is None - - def _impl_2d(input_tensor, weight_tensor): - assert input_tensor.dim() == 2 - assert weight_tensor.dim() == 2 - - m, k = input_tensor.shape - n, k_ = weight_tensor.shape - assert k_ == k - - weights_dequantized = weight_tensor.dequantize() - - # Quantize activations - activations_dequantized = to_affine_quantized_intx( - input_tensor, - mapping_type=MappingType.ASYMMETRIC, - block_size=(1, k), - target_dtype=torch.int32, - quant_min=-128, - quant_max=127, - eps=0.0, - zero_point_dtype=torch.int32, - preserve_zero=True, - zero_point_domain=ZeroPointDomain.INT, - use_hqq=False, - ).dequantize() - - return torch.matmul( - activations_dequantized, weights_dequantized.transpose(1, 0) - ) - - if input_tensor.dim() == 2: - return _impl_2d(input_tensor, weight_tensor) - - assert input_tensor.dim() >= 3 - lead_shape = input_tensor.shape[0:-2] - m, k = input_tensor.shape[-2], input_tensor.shape[-1] - n, k_ = weight_tensor.shape - assert k_ == k - - res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) - res = res.reshape(*lead_shape, m, n) - - return res - - -def _linear_int8_dynamic_activation_intx_weight_native_impl( - input_tensor, weight_tensor, bias -): - assert weight_tensor.tensor_impl.get_layout().target == Target.NATIVE - assert bias is None - - def _impl_2d(input_tensor, weight_tensor): - assert input_tensor.dim() == 2 - assert weight_tensor.dim() == 2 - - m, k = input_tensor.shape - n, k_ = 
weight_tensor.shape - assert k_ == k - group_size = weight_tensor.tensor_impl.get_layout().group_size - packed_weight = weight_tensor.tensor_impl.packed_weight - - # TODO(T200095131): convert self.n, self.k, self.group_size to - # int when supported by AOTI - args = ( - input_tensor, - packed_weight, - torch.empty(0, group_size, dtype=torch.int8), - torch.empty(0, n, dtype=torch.int8), - torch.empty(0, k, dtype=torch.int8), - ) - - has_weight_zeros = weight_tensor.zero_point_domain != ZeroPointDomain.NONE - - assert len(weight_tensor.block_size) == 2 - assert weight_tensor.block_size[0] == 1 - group_size = weight_tensor.block_size[1] - assert group_size == weight_tensor.tensor_impl.get_layout().group_size - nbit = weight_tensor.tensor_impl.get_layout().nbit - - n, k = weight_tensor.shape - m, k_ = input_tensor.shape - assert k_ == k - - packed_weight = weight_tensor.tensor_impl.packed_weight - wzp_suffix = "" if has_weight_zeros else "0zp" - return getattr( - torch.ops.torchao, f"_linear_8bit_act_{nbit}bit{wzp_suffix}_weight" - )(*args) - - if input_tensor.dim() == 2: - return _impl_2d(input_tensor, weight_tensor) - - assert input_tensor.dim() >= 3 - lead_shape = input_tensor.shape[0:-2] - m, k = input_tensor.shape[-2], input_tensor.shape[-1] - n, k_ = weight_tensor.shape - assert k_ == k - - res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) - res = res.reshape(*lead_shape, m, n) - return res - - -def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor, bias): - target = weight_tensor.tensor_impl.get_layout().target - if target == Target.NATIVE: - return _linear_int8_dynamic_activation_intx_weight_native_impl( - input_tensor, weight_tensor, bias - ) - - if target == Target.FALLBACK: - return _linear_int8_dynamic_activation_intx_weight_fallback_impl( - input_tensor, weight_tensor, bias - ) - - assert False, f"Unknown target {target}" - - -register_aqt_quantized_linear_dispatch( - _linear_int8_dynamic_activation_intx_weight_check, - _linear_int8_dynamic_activation_intx_weight_impl, -) diff --git a/torchao/experimental/docs/readme.md b/torchao/experimental/docs/readme.md index c1bfa5c32a..1da05ae21f 100644 --- a/torchao/experimental/docs/readme.md +++ b/torchao/experimental/docs/readme.md @@ -60,15 +60,15 @@ from torchao.quantization.quant_api import quantize_ quantize_( my_model, int8_dynamic_activation_intx_weight( + bit_width=4, group_size=256, - nbit=4, has_weight_zeros=False, ), ) ``` If you get stuck, consult -`tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py`. +`tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py`. ## Available in torchchat diff --git a/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py new file mode 100644 index 0000000000..7b2b1da145 --- /dev/null +++ b/torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -0,0 +1,275 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
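+#
+# This layout packs low-bit weights (together with their scales and optional zero
+# points) into the format expected by torchao's experimental CPU kernels, so that
+# linear with int8 dynamically quantized activations and intx weights runs as a
+# single fused op.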
+ +import logging +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.dtypes.affine_quantized_tensor import ( + register_layout, +) +from torchao.dtypes.affine_quantized_tensor_ops import ( + register_aqt_quantized_linear_dispatch, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout +from torchao.quantization.quant_primitives import ( + ZeroPointDomain, +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +import sys + +handler = logging.StreamHandler(sys.stdout) +formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +handler.setFormatter(formatter) +logger.addHandler(handler) + + +class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout): + bit_width: Optional[int] + group_size: Optional[int] + has_weight_zeros: Optional[bool] + + def __init__( + self, + bit_width: Optional[int] = None, + group_size: Optional[int] = None, + has_weight_zeros: Optional[bool] = None, + ): + if bit_width is not None: + assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8" + if group_size is not None: + assert group_size >= 1, f"group_size must be positive, got {group_size}" + + self.bit_width = bit_width + self.group_size = group_size + self.has_weight_zeros = has_weight_zeros + + if not self.has_params_set(): + assert ( + self.bit_width is None + and self.group_size is None + and self.has_weight_zeros is None + ), "bit_width, group_size, and has_weight_zeros must be None if has_params_set is False" + + def extra_repr(self): + return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}" + + def has_params_set(self) -> bool: + return ( + (self.bit_width is not None) + and (self.group_size is not None) + and (self.has_weight_zeros is not None) + ) + + +@register_layout(PackedLinearInt8DynamicActivationIntxWeightLayout) +class PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl(AQTTensorImpl): + def __new__( + cls, + packed_weight: torch.Tensor, + _layout: Layout, + # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor + # when AOTI supports int + group_size_tensor: torch.Tensor, + n_tensor: torch.Tensor, + k_tensor: torch.Tensor, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["dtype"] = packed_weight.dtype + assert not packed_weight.requires_grad + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + _layout: Layout, + # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor + # when AOTI supports int + group_size_tensor: torch.Tensor, + n_tensor: torch.Tensor, + k_tensor: torch.Tensor, + ): + assert isinstance(_layout, PackedLinearInt8DynamicActivationIntxWeightLayout) + self.packed_weight = packed_weight + self._layout = _layout + self.group_size_tensor = group_size_tensor + self.n_tensor = n_tensor + self.k_tensor = k_tensor + + def __repr__(self): + return f"{self.__class__.__name__}(packed_weight={str(self.packed_weight)}, layout={self.get_layout()})" + + def get_layout(self) -> Layout: + return self._layout + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + raise NotImplementedError( + "get_plain is not implemented for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl" + ) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, 
+ scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + layout: Layout, + ): + assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) + assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain" + + # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor + # when AOTI supports int + n, k = int_data.shape + group_size_tensor = torch.empty(0, layout.group_size, dtype=torch.int8) + n_tensor = torch.empty(0, n, dtype=torch.int8) + k_tensor = torch.empty(0, k, dtype=torch.int8) + + if layout.has_weight_zeros: + args = [ + int_data.to(torch.int8), + scale.reshape(-1), + zero_point.reshape(-1).to(torch.int8), + group_size_tensor, + ] + else: + args = [ + int_data.to(torch.int8), + scale.reshape(-1), + group_size_tensor, + ] + + wzp_suffix = "" if layout.has_weight_zeros else "0zp" + packed_weight = getattr( + torch.ops.torchao, + f"_pack_8bit_act_{layout.bit_width}bit{wzp_suffix}_weight", + )(*args) + + return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor) + + def _apply_fn_to_data(self, fn): + self.packed_weight = fn(self.packed_weight) + + # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor + # when AOTI supports int + self.group_size_tensor = fn(self.group_size_tensor) + self.n_tensor = fn(self.n_tensor) + self.k_tensor = fn(self.k_tensor) + return self + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is torch.ops.aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + if func is torch.ops.aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + raise NotImplementedError( + f"PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor + # when AOTI supports int + return ["packed_weight", "group_size_tensor", "n_tensor", "k_tensor"], [ + self.get_layout() + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight = tensor_data_dict["packed_weight"] + + # TODO(T200095131): remove group_size_tensor, n_tensor, k_tensor + # when AOTI supports int + group_size_tensor = tensor_data_dict["group_size_tensor"] + n_tensor = tensor_data_dict["n_tensor"] + k_tensor = tensor_data_dict["k_tensor"] + + (layout,) = tensor_attributes + return cls(packed_weight, layout, group_size_tensor, n_tensor, k_tensor) + + +def _linear_check(input_tensor, weight_tensor, bias): + layout = weight_tensor.tensor_impl.get_layout() + return isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) and ( + bias is None + ) + + +def _linear_impl(input_tensor, weight_tensor, bias): + assert ( + bias is None + ), "bias in linear is not supported for PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl" + + def _impl_2d(input_tensor, weight_tensor): + assert input_tensor.dim() == 2 + assert weight_tensor.dim() == 2 + + m, k = input_tensor.shape + n, k_ = weight_tensor.shape + assert k_ == k + group_size = weight_tensor.tensor_impl.get_layout().group_size + + assert group_size == weight_tensor.tensor_impl.group_size_tensor.shape[1] + assert n == weight_tensor.tensor_impl.n_tensor.shape[1] + assert k == 
weight_tensor.tensor_impl.k_tensor.shape[1] + + # TODO(T200095131): convert self.n, self.k, self.group_size to + # int when supported by AOTI + args = ( + input_tensor, + weight_tensor.tensor_impl.packed_weight, + weight_tensor.tensor_impl.group_size_tensor, + weight_tensor.tensor_impl.n_tensor, + weight_tensor.tensor_impl.k_tensor, + ) + + has_weight_zeros = weight_tensor.zero_point_domain != ZeroPointDomain.NONE + + assert len(weight_tensor.block_size) == 2 + assert weight_tensor.block_size[0] == 1 + assert group_size == weight_tensor.block_size[1] + bit_width = weight_tensor.tensor_impl.get_layout().bit_width + + wzp_suffix = "" if has_weight_zeros else "0zp" + return getattr( + torch.ops.torchao, f"_linear_8bit_act_{bit_width}bit{wzp_suffix}_weight" + )(*args) + + if input_tensor.dim() == 2: + return _impl_2d(input_tensor, weight_tensor) + + assert input_tensor.dim() >= 3 + lead_shape = input_tensor.shape[0:-2] + m, k = input_tensor.shape[-2], input_tensor.shape[-1] + n, k_ = weight_tensor.shape + assert k_ == k + + res = _impl_2d(input_tensor.reshape(-1, k), weight_tensor) + res = res.reshape(*lead_shape, m, n) + return res + + +register_aqt_quantized_linear_dispatch( + _linear_check, + _linear_impl, +) diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py index ce99e250ef..ef96acbf11 100644 --- a/torchao/experimental/quant_api.py +++ b/torchao/experimental/quant_api.py @@ -482,58 +482,121 @@ def quantize(self, model: nn.Module) -> nn.Module: return model +from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( + PackedLinearInt8DynamicActivationIntxWeightLayout, +) +from torchao.quantization.linear_activation_quantized_tensor import ( + to_linear_activation_quantized, +) +from torchao.quantization.quant_api import ( + MappingType, + ZeroPointDomain, + _get_linear_subclass_inserter, + to_affine_quantized_intx, +) +from torchao.quantization.utils import _get_per_token_block_size + + def int8_dynamic_activation_intx_weight( + bit_width: int = 4, group_size: int = 128, - nbit: int = 4, has_weight_zeros: bool = False, - target: str = "native", + weight_mapping_type=MappingType.ASYMMETRIC, + act_mapping_type=MappingType.ASYMMETRIC, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # PlainLayout() also works, but will be slow ): - from torchao.experimental._linear_8bit_act_xbit_weight_layout import ( - Linear8BitActXBitWeightLayout, - ) - from torchao.quantization.quant_api import ( - MappingType, - ZeroPointDomain, - _get_linear_subclass_inserter, - to_affine_quantized_intx, - ) + """ + Dynamically quantizes activations with 8-bits and weights with a low-bit value for linear layers. + More specifically, activations are dynamically quantized to 8-bits in a channelwise manner with scales and zeros. + Weights are quantized with scales and optionally zeros (controlled by has_weight_zeros) in a groupwise + manner using bit_width-bits. + + args: + bit_width: The number of bits to use for the weight quantization. Must be between 1 and 8. + group_size: The number of weight values to quantize together. Must be a divisor of the last dimension of the weight tensor. + If group_size is -1 or None, then group_size will be set to the last dimension of the weight tensor (channelwise quantization). + has_weight_zeros: Whether or not to include zeros in the weight quantization. + weight_mapping_type: The type of mapping to use for the weight quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC. 
+ act_mapping_type: The type of mapping to use for the activation quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC. + layout: The layout to use for the packed weight tensor. Must be PackedLinearInt8DynamicActivationIntxWeightLayout (default) or PlainLayout. + The layout does not affect the quantization numerically and both layouts will give the same results. PlainLayout is a generic layout + that works on all devices, but it is much slower than PackedLinearInt8DynamicActivationIntxWeightLayout on CPU. + PackedLinearInt8DynamicActivationIntxWeightLayout is a specialized layout for CPU performance. + When using PackedLinearInt8DynamicActivationIntxWeightLayout, + - The weight tensor must have device=CPU + - The weight tensor must have dtype=float32 (note that after applying quantization, the weights will no longer be float32) + - act_mapping_type must be MappingType.ASYMMETRIC + """ + try: + torch.ops.torchao._pack_8bit_act_4bit_weight + except AttributeError: + raise Exception( + "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU." + + " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance." + ) + + group_size_arg = group_size + layout_arg = layout def apply(weight): + group_size = group_size_arg + if group_size is None or group_size == -1: + group_size = weight.shape[-1] + assert weight.shape[-1] % group_size == 0 - assert weight.device == torch.device("cpu"), "Only CPU is supported" - use_hqq = False - layout = Linear8BitActXBitWeightLayout( - nbit=nbit, group_size=group_size, target=target - ) - mapping_type = MappingType.ASYMMETRIC - eps = torch.finfo(torch.float32).eps - block_size = (1, group_size) - target_dtype = torch.int32 - quant_min = -(1 << (nbit - 1)) - quant_max = (1 << (nbit - 1)) - 1 - zero_point_dtype = torch.int8 - preserve_zero = has_weight_zeros - zero_point_domain = ( - ZeroPointDomain.INT if has_weight_zeros else ZeroPointDomain.NONE - ) - # Note: this works differently than other quantizers because the dynamic - # activation quantization is fused with the kernel/op (and static activation quantization - # is not supported). 
- return to_affine_quantized_intx( + + layout = layout_arg + if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout): + assert ( + weight.device == torch.device("cpu") + ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires weight.device=CPU" + assert ( + weight.dtype == torch.float32 + ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires weight.dtype=float32" + assert ( + act_mapping_type == MappingType.ASYMMETRIC + ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC" + assert not layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set" + layout = PackedLinearInt8DynamicActivationIntxWeightLayout( + bit_width=bit_width, + group_size=group_size, + has_weight_zeros=has_weight_zeros, + ) + + quant_min = -(1 << (bit_width - 1)) + quant_max = (1 << (bit_width - 1)) - 1 + weight = to_affine_quantized_intx( weight, - mapping_type, - block_size, - target_dtype, - quant_min, - quant_max, - eps, - zero_point_dtype=zero_point_dtype, - preserve_zero=preserve_zero, - zero_point_domain=zero_point_domain, + mapping_type=weight_mapping_type, + block_size=(1, group_size), + target_dtype=torch.int32, + quant_min=quant_min, + quant_max=quant_max, + eps=torch.finfo(torch.float32).eps, + zero_point_dtype=torch.int8, + preserve_zero=has_weight_zeros, + zero_point_domain=ZeroPointDomain.INT + if has_weight_zeros + else ZeroPointDomain.NONE, _layout=layout, - use_hqq=use_hqq, ) + # Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused + # with the kernel and it should not be applied separately + if not isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout): + activation_quant_func = lambda x: to_affine_quantized_intx( + x, + mapping_type=act_mapping_type, + block_size=_get_per_token_block_size(x), + target_dtype=torch.int32, + quant_min=-128, # lower bound of int8 + quant_max=127, # upper bound of int8 + scale_dtype=torch.float32, + zero_point_dtype=torch.int32, + ) + weight = to_linear_activation_quantized(weight, activation_quant_func) + return weight + return _get_linear_subclass_inserter(apply) diff --git a/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py similarity index 62% rename from torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py rename to torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py index 61f6c6cc01..f7bae61fec 100644 --- a/torchao/experimental/tests/test_linear_int8_dynamic_activation_intx_weight_subclass.py +++ b/torchao/experimental/tests/test_packed_linear_int8_dynamic_activation_intx_weight_layout.py @@ -10,33 +10,43 @@ import torch +from torchao.dtypes import PlainLayout +from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( + PackedLinearInt8DynamicActivationIntxWeightLayout, +) from torchao.experimental.quant_api import ( - _Int8DynActIntxWeightQuantizedLinearFallback, int8_dynamic_activation_intx_weight, ) from torchao.quantization.quant_api import quantize_ from torchao.utils import unwrap_tensor_subclass -class TestInt8DynamicActivationIntxWeight(unittest.TestCase): +class TestPackedLinearInt8DynamicActivationIntxWeightLayout(unittest.TestCase): def test_accuracy(self): + """ + Checks the accuracy of 
PackedLinearInt8DynamicActivationIntxWeightLayout() by comparing + its results to the results of a reference model that uses PlainLayout() + """ group_size = 128 m = 1 n = 1071 k = 4096 - activations = torch.randn(m, k, dtype=torch.float32) + activations = torch.randn(m, k) model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)]) - for nbit in [1, 2, 3, 4, 5, 6, 7, 8]: + for bit_width in [1, 2, 3, 4, 5, 6, 7, 8]: for has_weight_zeros in [True, False]: - print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}") + print( + f"Testing bit_width={bit_width}, has_weight_zeros={has_weight_zeros}" + ) quantized_model = copy.deepcopy(model) quantize_( quantized_model, int8_dynamic_activation_intx_weight( + bit_width=bit_width, group_size=group_size, - nbit=nbit, has_weight_zeros=has_weight_zeros, + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), # default ), ) @@ -44,10 +54,10 @@ def test_accuracy(self): quantize_( quantized_model_reference, int8_dynamic_activation_intx_weight( + bit_width=bit_width, group_size=group_size, - nbit=nbit, has_weight_zeros=has_weight_zeros, - target="fallback", + layout=PlainLayout(), ), ) @@ -55,44 +65,30 @@ def test_accuracy(self): result = quantized_model(activations) expected_result = quantized_model_reference(activations) - # TODO: remove expected_result2 checks when we deprecate non-subclass API - reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback() - reference_impl.quantize_and_pack_weights( - model[0].weight, nbit, group_size, has_weight_zeros - ) - expected_result2 = reference_impl(activations) - num_mismatch_at_low_tol = 0 - num_mismatch_at_low_tol2 = 0 num_total = result.reshape(-1).shape[0] for i in range(num_total): actual_val = result.reshape(-1)[i] expected_val = expected_result.reshape(-1)[i] - expected_val2 = expected_result2.reshape(-1)[i] self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6)) if not torch.allclose(actual_val, expected_val): num_mismatch_at_low_tol += 1 - self.assertTrue( - torch.allclose( - expected_val, expected_val2, atol=1e-2, rtol=1e-1 - ) - ) - if not torch.allclose(expected_val, expected_val2): - num_mismatch_at_low_tol2 += 1 - # Assert at most 5% of entries are not close at a low tolerance self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05) - self.assertTrue(num_mismatch_at_low_tol2 / num_total <= 0.01) def test_export_compile_aoti(self): - group_size = 32 + """ + Checks that models quantized with PackedLinearInt8DynamicActivationIntxWeightLayout() work with + torch.export.export, torch.compile, and AOTI. 
+ """ + group_size = -1 m = 3 k0 = 512 k1 = 256 k2 = 128 k3 = 1024 - nbit = 4 + bit_width = 4 has_weight_zeros = True layers = [ torch.nn.Linear(k0, k1, bias=False), @@ -106,35 +102,39 @@ def test_export_compile_aoti(self): quantize_( model, int8_dynamic_activation_intx_weight( + bit_width=bit_width, group_size=group_size, - nbit=nbit, has_weight_zeros=has_weight_zeros, - target="native", + layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), ), ) + eager_results = model(activations) unwrapped_model = copy.deepcopy(model) unwrap_tensor_subclass(model) print("Exporting quantized model") - torch.export.export(model, (activations,), strict=True) + exported = torch.export.export(model, (activations,), strict=True) + exported_results = exported.module()(activations) + self.assertTrue(torch.allclose(eager_results, exported_results)) print("Compiling quantized model") compiled = torch.compile(unwrapped_model) with torch.no_grad(): - compiled(activations) + compiled_results = compiled(activations) + self.assertTrue(torch.allclose(eager_results, compiled_results)) with tempfile.TemporaryDirectory() as tmpdirname: + package_path = f"{tmpdirname}/model.pt2" print("Exporting quantized model with AOTI") - torch._export.aot_compile( - model, - (activations,), - options={"aot_inductor.output_path": f"{tmpdirname}/model.so"}, + torch._inductor.aoti_compile_and_package( + exported, package_path=package_path ) print("Running quantized model in AOTI") - fn = torch._export.aot_load(f"{tmpdirname}/model.so", "cpu") - fn(activations) + fn = torch._inductor.aoti_load_package(package_path) + aoti_results = fn(activations) + self.assertTrue(torch.allclose(eager_results, aoti_results)) if __name__ == "__main__":