From 79a7f52d74459d45f8928942e19e207856f088b4 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 6 Nov 2024 08:53:35 -0800 Subject: [PATCH 01/11] Refactored files --- test/dtypes/test_affine_quantized.py | 2 +- torchao/dtypes/__init__.py | 15 +- torchao/dtypes/affine_quantized_tensor.py | 1520 +---------------- torchao/dtypes/affine_quantized_tensor_ops.py | 295 ++++ torchao/dtypes/floatx/__init__.py | 3 +- torchao/dtypes/floatx/float8_layout.py | 261 +++ ...floatx.py => floatx_tensor_core_layout.py} | 66 +- torchao/dtypes/uintx/__init__.py | 5 + torchao/dtypes/uintx/block_sparse_layout.py | 204 +++ torchao/dtypes/uintx/marlin_sparse_layout.py | 240 +++ torchao/dtypes/uintx/semi_sparse_layout.py | 94 + .../dtypes/uintx/tensor_core_tiled_layout.py | 323 ++++ torchao/dtypes/{ => uintx}/uint4.py | 0 torchao/dtypes/uintx/uint8.py | 114 ++ torchao/dtypes/uintx/uint8_layout.py | 118 ++ torchao/dtypes/uintx/uintx.py | 5 +- torchao/dtypes/utils.py | 40 + torchao/sparsity/utils.py | 1 + 18 files changed, 1795 insertions(+), 1511 deletions(-) create mode 100644 torchao/dtypes/affine_quantized_tensor_ops.py create mode 100644 torchao/dtypes/floatx/float8_layout.py rename torchao/dtypes/floatx/{floatx.py => floatx_tensor_core_layout.py} (87%) create mode 100644 torchao/dtypes/uintx/block_sparse_layout.py create mode 100644 torchao/dtypes/uintx/marlin_sparse_layout.py create mode 100644 torchao/dtypes/uintx/semi_sparse_layout.py create mode 100644 torchao/dtypes/uintx/tensor_core_tiled_layout.py rename torchao/dtypes/{ => uintx}/uint4.py (100%) create mode 100644 torchao/dtypes/uintx/uint8.py create mode 100644 torchao/dtypes/uintx/uint8_layout.py diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py index dd7e679e56..540d85e032 100644 --- a/test/dtypes/test_affine_quantized.py +++ b/test/dtypes/test_affine_quantized.py @@ -92,7 +92,7 @@ def test_to_device(self, apply_quant): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_register_new_dispatch(self): - from torchao.dtypes.affine_quantized_tensor import ( + from torchao.dtypes.affine_quantized_tensor_ops import ( register_aqt_quantized_linear_dispatch, deregister_aqt_quantized_linear_dispatch, ) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 4ab0c3f701..032948cd51 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,6 +1,6 @@ from .nf4tensor import NF4Tensor, to_nf4 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor -from .uint4 import UInt4Tensor +from .uintx import UInt4Tensor from .affine_quantized_tensor import ( AffineQuantizedTensor, to_affine_quantized_intx, @@ -9,15 +9,22 @@ to_affine_quantized_fpx, to_affine_quantized_floatx, to_affine_quantized_floatx_static, + PlainAQTTensorImpl, +) +from .affine_quantized_tensor_ops import * +from .utils import ( Layout, PlainLayout, - SemiSparseLayout, - TensorCoreTiledLayout, +) +from .floatx import ( Float8Layout, Float8AQTTensorImpl, +) +from .uintx import ( + SemiSparseLayout, + TensorCoreTiledLayout, MarlinSparseLayout, ) - __all__ = [ "NF4Tensor", "to_nf4", diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 03cec525f4..31f775e92e 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1,126 +1,37 @@ -import torch -from typing import Tuple, Optional, Union, List -import torchao.ops -from collections import defaultdict -import functools +from dataclasses import 
dataclass +import logging import math +from typing import Optional, Tuple, Union + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.dtypes.utils import Layout, PlainLayout from torchao.quantization.quant_primitives import ( - _get_reduction_params, - choose_qparams_affine, - quantize_affine, - dequantize_affine, - ZeroPointDomain, - MappingType, - int_scaled_matmul, - choose_qparams_and_quantize_affine_hqq, FP8_TYPES, + MappingType, + ZeroPointDomain, + choose_qparams_affine, choose_qparams_affine_floatx, - quantize_affine_floatx, + choose_qparams_and_quantize_affine_hqq, + dequantize_affine, dequantize_affine_floatx, + quantize_affine, + quantize_affine_floatx, ) -from torchao.quantization.utils import ( - pack_tinygemm_scales_and_zeros, -) -from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.dtypes.utils import ( - Layout, - PlainLayout, - is_device, - get_out_shape, -) -from torchao.float8.inference import ( - preprocess_data, - Float8MMConfig, - addmm_float8_unwrapped_inference, - _is_rowwise_scaled -) -from torch.utils._python_dispatch import is_traceable_wrapper_subclass -from dataclasses import dataclass from torchao.utils import ( - find_multiple, - TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_5, - _is_float8_type, - fill_defaults, + TorchAOBaseTensor, +) +from torchao.dtypes.utils import ( + AQTTensorImpl, ) -import logging logger = logging.getLogger(__name__) - -from torchao.float8.inference import Float8MMConfig aten = torch.ops.aten -############################### -# Base Tensor Impl Subclass # -############################### -class AQTTensorImpl(TorchAOBaseTensor): - """ - Base class for the tensor impl for `AffineQuantizedTensor` - - Note: This is not a user facing API, it's used by AffineQuantizedTensor to construct - the underlying implementation of a AQT based on layout - """ - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Get the plain (unpacked) Tensor for the tensor impl - - Returns data, scale and zero_point - Can be overwritten if other types of AQTTensorImpl has different numbers of plain tensors - """ - pass - - def get_layout(self) -> Layout: - pass - - @classmethod - def from_plain( - cls, - data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - """ Construct a TensorImpl from data, scale, zero_point and the _layout""" - pass - - def __repr__(self): - data, scale, zero_point = self.get_plain() - _layout = self.get_layout() - return f"{self.__class__.__name__}(data={str(data)}... , scale={str(scale)}... , zero_point={str(zero_point)}... 
, _layout={_layout})" - - ############################## # Tensor Subclass Definition # ############################## - - -class QuantizedLinearNotImplementedError(NotImplementedError): - """ Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table """ - pass - - -_AQT_QLINEAR_DISPATCH_TABLE = {} -def register_aqt_quantized_linear_dispatch(dispatch_condition, impl): - """Register a dispatch for quantized linear op with dispatch_condition function and impl function - both takes three arguments: - input_tensor: dimension is (M1, M2, ..., in_features) - weight_tensor: dimension is (out_features, in_features) - bias: dimension is (out_features,) - so that these can be shared by F.linear, aten.mm, aten.addmm dispatches - - Args: - `dispatch_condition` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], bool]: the dispatch - condition for a specialized quantized linear implementation, e.g. bfloat16 activation + uint4 weight - `impl` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: the specialized - quantized linear implementation - """ - _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl - -def deregister_aqt_quantized_linear_dispatch(dispatch_condition): - if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE: - del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] - else: - logger.warn(f"Attempting to remove non-existant dispatch condition {dispatch_condition}") - class AffineQuantizedTensor(TorchAOBaseTensor): """ Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation: @@ -227,13 +138,6 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor dq = dq.narrow(dim, 0, dim_size) return dq - @staticmethod - def _quantized_linear_op(input_tensor, weight_tensor, bias): - for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items(): - if dispatch_condition(input_tensor, weight_tensor, bias): - return impl(input_tensor, weight_tensor, bias) - raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op") - def __tensor_flatten__(self): return ["tensor_impl"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] @@ -451,7 +355,7 @@ def _apply_fn_to_data(self, fn): strides=self.stride(), ) - # following are the comments for __torch_function__/__torch_dispatch__, we can clean this up + # following are the comments for __torch_function__/__torch_dispatch__, -> this is defined in affine_quantized_tensor_ops.py # a bit later # Note: we only added cpu path here for 8da4w, this is for executorch, in the future # 1. we'll add cpu/cuda version (int4mm etc.) 
@@ -472,93 +376,6 @@ def _apply_fn_to_data(self, fn): register_layout = AffineQuantizedTensor.register_layout get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor -@dataclass(frozen=True) -class SemiSparseLayout(Layout): - - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - # prune to 2:4 if not already - temp = input.detach() - pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2] - temp.view(-1, 4).scatter_(1, pruning_inds, value=0) - return temp - - -@dataclass(frozen=True) -class BlockSparseLayout(Layout): - blocksize: int = 64 - - -@dataclass(frozen=True) -class TensorCoreTiledLayout(Layout): - """ - inner_k_tiles is an internal argument for packing function of tensor core tiled layout - that can affect the performance of the matmul kernel - """ - inner_k_tiles: int = 8 - - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - orig_out_features, orig_in_features = input.shape - in_features = find_multiple(orig_in_features, 1024) - out_features = find_multiple(orig_out_features, 8) - input = torch.nn.functional.pad( - input, - (0, in_features - orig_in_features, 0, out_features - orig_out_features), - ) - return input - - def pre_process_static(self, input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, block_size: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - input = self.pre_process(input) - orig_qparam_shape = scale.shape - new_qparam_shape, reduction_dims = _get_reduction_params(block_size, input.size()) - for dim in reduction_dims: - new_qparam_shape.pop(dim) - change_in_qparam_shape = [new_dim_size-orig_dim_size for new_dim_size, orig_dim_size in zip(new_qparam_shape, orig_qparam_shape)] - padding_changes=[] - for dim_change in change_in_qparam_shape: - padding_changes = [0, dim_change] + padding_changes - scale = torch.nn.functional.pad(scale, padding_changes) - zero_point = torch.nn.functional.pad(zero_point, padding_changes) - return input, scale, zero_point - - def post_process(self, input: torch.Tensor) -> torch.Tensor: - orig_out_features, orig_in_features = input.shape - in_features = find_multiple(orig_in_features, 1024) - out_features = find_multiple(orig_out_features, 8) - input = torch.nn.functional.pad( - input, - (0, in_features - orig_in_features, 0, out_features - orig_out_features), - ) - return input - - def extra_repr(self): - return f"inner_k_tiles={self.inner_k_tiles}" - - -@dataclass(frozen=True) -class Float8Layout(Layout): - mm_config: Optional[Float8MMConfig] = None - - -@dataclass(frozen=True) -class MarlinSparseLayout(Layout): - - def pre_process(self, input: torch.Tensor) -> torch.Tensor: - """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. 
- - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format - - 2º: tensor is injected with 2:4 sparsity - - 3º: transposes it again because the quantization process will compute the scales for dim=-1 - - Args: - input (torch.Tensor): the input tensor to preprocess - - Returns: - torch.Tensor: the preprocessed tensor - """ - from torchao.sparsity.marlin import inject_24 # avoid circular import - input_t = input.t() - w_24, _ = inject_24(input_t, *input_t.shape) - return w_24.t() - @register_layout(PlainLayout) class PlainAQTTensorImpl(AQTTensorImpl): @@ -684,1310 +501,11 @@ def from_plain( assert isinstance(_layout, PlainLayout) return cls(int_data, scale, zero_point, _layout) -@register_layout(SemiSparseLayout) -class SemiSparseAQTTensorImpl(PlainAQTTensorImpl): - """ - TensorImpl for semi_sparse_cusparselt layout for affine quantized tensor - """ - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - raise NotImplementedError( - f"SparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - def get_plain(self): - # Currently we don't have cuSPARSELt expansion routines, so we matmul by - # the identity matrix to get the original dense matrix. This is slow though. - cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) - int_data_expanded = torch._cslt_sparse_mm(self.int_data, - torch.eye(cols, - dtype=self.int_data.dtype, - device=self.int_data.device).t()) - return int_data_expanded, self.scale, self.zero_point - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert isinstance(_layout, SemiSparseLayout) - int_data_compressed = torch._cslt_compress(int_data) - return cls(int_data_compressed, scale, zero_point, _layout) - -@register_layout(BlockSparseLayout) -class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): - bsr_crow_indices: Optional[torch.Tensor] - bsr_col_indices: Optional[torch.Tensor] - bsr_values: Optional[torch.Tensor] - scale: Optional[torch.Tensor] - zero_point: Optional[torch.Tensor] - - __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"] - - @staticmethod - def __new__( # noqa: PYI034 - cls, - shape: torch.Size, - bsr_crow_indices: Optional[torch.Tensor], - bsr_col_indices: Optional[torch.Tensor], - bsr_values: Optional[torch.Tensor], - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - requires_grad: bool = False, - ): - if bsr_values is None: - raise ValueError("bsr values must be provided!") - else: - previous_tensor = bsr_values - - kwargs = { - "device": previous_tensor.device, - "dtype": previous_tensor.dtype, - "layout": previous_tensor.layout, - "requires_grad": requires_grad, - } - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( # noqa: PYI034 - self, - shape: torch.Size, - bsr_crow_indices: Optional[torch.Tensor], - bsr_col_indices: Optional[torch.Tensor], - bsr_values: Optional[torch.Tensor], - scale: Optional[torch.Tensor], - zero_point: Optional[torch.Tensor], - _layout: Layout, - requires_grad: bool = False, - ): - self.bsr_crow_indices = bsr_crow_indices - self.bsr_col_indices = bsr_col_indices - self.bsr_values = bsr_values - self.scale = 
scale - self.zero_point = zero_point - self._layout = _layout - - def __tensor_flatten__(self): - inner_tensors = list( - filter(lambda x: getattr(self, x) is not None, self.__slots__) - ) - tensor_meta = (self.shape, self._layout, self.requires_grad) - return inner_tensors, tensor_meta - - @classmethod - def __tensor_unflatten__( - cls, - inner_tensors, - tensor_meta: Tuple[torch.Size, bool], - outer_size, - outer_stride, - ) -> torch.Tensor: - shape, _layout, requires_grad = tensor_meta - return cls( - shape=shape, - bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), - bsr_col_indices=inner_tensors.get("bsr_col_indices", None), - bsr_values=inner_tensors.get("bsr_values", None), - scale=inner_tensors.get("scale", None), - zero_point=inner_tensors.get("zero_point", None), - _layout=_layout, - requires_grad=requires_grad, - ) - - @classmethod - def from_plain(cls, int_data, scale, zero_point, _layout): - bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize) - return cls( - shape=int_data.shape, - bsr_crow_indices=bsr_tensor.crow_indices(), - bsr_col_indices=bsr_tensor.col_indices(), - bsr_values=bsr_tensor.values(), - scale=scale, - zero_point=zero_point, - _layout = _layout, - requires_grad=False, - ) - - def get_plain(self): - int_data_expanded = torch.ops.blocksparse.bsr_to_dense(self.crow_indices(), self.col_indices(), self.values(), self.shape[0], self.shape[1]) - return int_data_expanded, self.scale, self.zero_point - - def _apply_fn_to_data(self, func): - return self.__class__( - shape = self.shape, - bsr_crow_indices=func(self.bsr_crow_indices), - bsr_col_indices=func(self.bsr_col_indices), - bsr_values=func(self.bsr_values), - scale=self.scale, - zero_point=self.zero_point, - _layout=self._layout, - requires_grad=self.requires_grad, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - # Need the following for bsr specific functions - if func is aten.crow_indices.default: - return args[0].bsr_crow_indices.detach() - - if func is aten.col_indices.default: - return args[0].bsr_col_indices.detach() - - if func is aten.values.default: - return args[0].bsr_values.detach() - - if func is aten._nnz.default: - return args[0].bsr_values.shape[0] - - raise NotImplementedError( - f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - -@register_layout(MarlinSparseLayout) -class MarlinSparseAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl for sparse_marlin_24 layout for affine quantized tensor. - - Can be used with 4 bits and 8 bits quantization. - - Original marlin documentation and information: - https://github.com/IST-DASLab/marlin/tree/master - - Sparse marlin documentation and information: - https://github.com/IST-DASLab/Sparse-Marlin?tab=readme-ov-file - - fields: - original_shape (torch.Size): the original shape of the tensor. 
used to unpack the tensor to the original shape - group_size (int): the group size used to pack the tensor - num_bits (int): the number of bits used to quantize the tensor - """ - @staticmethod - def __new__( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - meta: torch.Tensor, - _layout: Layout, - original_shape: torch.Size, - group_size: int, - num_bits: int, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = int_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - meta: torch.Tensor, - _layout: Layout, - original_shape: torch.Size, - group_size: int, - num_bits: int, - ): - self.int_data = int_data - self.scale = scale - self.zero_point = zero_point - self.meta = meta - self._layout = _layout - self.original_shape = original_shape - self.group_size = group_size - self.num_bits = num_bits - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - raise NotImplementedError( - f"MarlinSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point", "meta"], [self._layout, self.original_shape, self.group_size, self.num_bits] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data = tensor_data_dict["int_data"] - scale = tensor_data_dict["scale"] - zero_point = tensor_data_dict["zero_point"] - meta = tensor_data_dict["meta"] - _layout, original_shape, group_size, num_bits = tensor_attributes - return cls(int_data, scale, zero_point, meta, _layout, original_shape, group_size, num_bits) - - def get_plain(self): - from torchao.sparsity.marlin import unpack_from_marlin_24 # avoid circular import - int_data_expanded, scales_expanded = unpack_from_marlin_24( - self.int_data, - self.scale, - self.meta, - self.original_shape, - self.group_size, - self.num_bits, - ) - int_data_expanded_t = int_data_expanded.t() - scales_expanded_t = scales_expanded.t() - return int_data_expanded_t, scales_expanded_t, self.zero_point - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - from torchao.sparsity.marlin import pack_to_marlin_24, const # avoid circular import - assert isinstance(_layout, MarlinSparseLayout) - - # Linear layers are (in_features, out_features) but the int_data that is reaching this point - # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. - q_w_24 = int_data.t() - scale_t = scale.t() - - if not torch.cuda.get_device_capability()[0] >= 8: - raise ValueError( - f'Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel.' 
- ) - - if q_w_24.dtype != torch.int32: - raise ValueError("Only `torch.int32` weights are supported.") - - in_features, out_features = q_w_24.shape - if in_features % 128 != 0 or out_features != 256 == 0: - raise ValueError( - "`in_features` must be divisible by 64 and `out_features` by 256." - ) - - # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8 - # will require a bit more work to get our current quantization flow to work with it. - # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main - num_bits = 4 if torch.max(q_w_24) < 16 else -1 - if num_bits not in [4]: - raise ValueError( - f"Only {[4]} bits are supported, got {num_bits}." - ) - - group_size = in_features // scale_t.shape[0] - if group_size == 0: - group_size = in_features - assert group_size <= in_features, "Group size must be less than or equal to in_features." - - if group_size not in const.SUPPORTED_GROUP_SIZES: - raise ValueError( - f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}." - ) - - # Compress quantized weight to marlin 2:4 format - marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(q_w_24, scale_t, num_bits, group_size) - - return cls( - marlin_24_q_w_comp, marlin_24_s, zero_point, - meta, _layout, q_w_24.shape, - group_size, num_bits - ) - - def get_layout(self) -> Layout: - return self._layout - - def _apply_fn_to_data(self, fn): - self.int_data = fn(self.int_data) - self.scale = fn(self.scale) - self.zero_point = fn(self.zero_point) - self.meta = fn(self.meta) - return self - - -@register_layout(Float8Layout) -class Float8AQTTensorImpl(AQTTensorImpl): - """ - TensorImpl for float8 layout affine quantized tensor - - Note: technically we should not create a new layout for float8 we should merge this into - plain layout - """ - float8_data: torch.Tensor - scale: torch.Tensor - transposed: bool - - def __new__( - cls, - float8_data: torch.Tensor, - scale: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = float8_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else float8_data.layout - ) - kwargs["dtype"] = float8_data.dtype - kwargs["requires_grad"] = False - shape = float8_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - float8_data: torch.Tensor, - scale: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - self.float8_data = float8_data - self.scale = scale - self.transposed = transposed - self._layout = _layout - - def _apply_fn_to_data(self, fn): - """ Applys a fn to all tensor components stored on this class""" - return self.__class__( - fn(self.float8_data), - fn(self.scale), - self.transposed, - self._layout, - ) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.float8_data.to(kwargs["device"]), - self.scale.to(kwargs["device"]), - self.transposed, - self._layout, - ) - - def __tensor_flatten__(self): - return ["float8_data", "scale"], [self.transposed, self._layout] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"] - transposed, _layout, = tensor_attributes - return cls(float8_data, scale, transposed, _layout) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else 
kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - elif func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - elif func is aten.t.default: - """we don't need to repack the weight and just rely on external - shape being changed and record the status of transpose/no-transpose - """ - args[0].transposed = not args[0].transposed - return return_and_correct_aliasing(func, args, kwargs, args[0]) - elif func is aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == 0: - #TODO: scale replecation should be dependent on block size - if self.scale.ndim == 1: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) - ) - elif self.scale.ndim == 0: - return return_and_correct_aliasing( - func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout) - ) - else: - raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported") - elif dim == 1: - return return_and_correct_aliasing( - func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step).contiguous(), self.scale, None, self._layout) - ) - else: - raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") - else: - raise NotImplementedError( - f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - __torch_function__ = torch._C._disabled_torch_function_impl - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - return self.float8_data, self.scale, None - - def get_layout(self) -> Layout: - return self._layout - - @classmethod - def from_plain( - cls, - data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - """ Main entrypoint for constructing Float8TensorImpl""" - assert _is_float8_type(data.dtype), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" - assert isinstance(_layout, Float8Layout), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}" - return cls(data, scale, False, _layout) - - def __repr__(self): - float8_data, scale, _ = self.get_plain() - _layout = self.get_layout() - return (f"{self.__class__.__name__}(\n" - f"float8_data={float8_data},\n" - f"scale={scale},\n" - f"transposed={self.transposed}, " - f"_layout={_layout})") - - -@register_layout(TensorCoreTiledLayout) -class TensorCoreTiledAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, - used by tinygemm kernels `_weight_int4pack_mm` - - It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 4-d tensor of - dimension: [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2] - (unpacked Tensor shape is n * k) - where inner_k_tiles is an internal argument for packing function of tensor core tiled layout - that can affect the performance of the matmul kernel (defaults to 8) - - Note: we also pack scale and zero point together here for tinygemm kernel - - Note: technically tensor core tiled layout should be the layout for the underlying packed weight - (int Tensor) but 
since the scale and zero_point are also packed into the same tensor here which is not used - in plain layout, we just created a layout for AQT right now, this could be improved if we split out - int4 aqt into a separate tensor subclass - - fields: - packed_weight (torch.Tensor): the 4-d packed tensor in a tensor_core_tiled layout - scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor - """ - - def __new__( - cls, - packed_weight: torch.Tensor, - scale_and_zero: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = packed_weight.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else packed_weight.layout - ) - kwargs["dtype"] = packed_weight.dtype - kwargs["requires_grad"] = False - shape = packed_weight.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_weight: torch.Tensor, - scale_and_zero: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - self.packed_weight = packed_weight - self.scale_and_zero = scale_and_zero - self.transposed = False - self._layout = _layout - - def __tensor_flatten__(self): - return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"] - transposed, _layout, = tensor_attributes - return cls(packed_weight, scale_and_zero, transposed, _layout) - - @classmethod - def from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout - ): - - assert isinstance(_layout, TensorCoreTiledLayout) - - if TORCH_VERSION_AT_LEAST_2_5: - int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" - else: - assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, _layout.inner_k_tiles) - scale = scale.reshape(int_data.shape[0], -1) - zero_point = zero_point.reshape(int_data.shape[0], -1) - from torchao.quantization.utils import pack_tinygemm_scales_and_zeros - scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) - return cls(packed_weight, scale_and_zero, False, _layout) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - device = kwargs["device"] - # tensor core tiled layout supports both cpu and cuda but does not support the conversion - # between these two devices, in the future we should not use the same layout for - # cpu and cuda device: https://github.com/pytorch/ao/issues/1117 - if not is_device(torch.device(self.device).type, device): - raise ValueError(f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}") - return self.__class__( - self.packed_weight.to(device), - self.scale_and_zero.to(device), - self.transposed, - self._layout, - ) - - def _apply_fn_to_data(self, fn): - # self.packed_weight = fn(self.packed_weight) - # self.scale_and_zero = fn(self.scale_and_zero) - # return self - return self.__class__( - fn(self.packed_weight), - fn(self.scale_and_zero), - self.transposed, - 
self._layout, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - if func is aten.t.default: - """we don't need to repack the weight and just rely on external - shape being changed and record the status of transpose/no-transpose - """ - transposed = TensorCoreTiledAQTTensorImpl(args[0].packed_weight, args[0].scale_and_zero, not args[0].transposed, args[0]._layout) - return return_and_correct_aliasing(func, args, kwargs, transposed) - - if func is aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == 0: - int_data, scale, zero_point = self.get_plain() - int_data = aten.slice.Tensor(int_data, dim, start, end, step) - # this is to handle padding - int_data = self._layout.post_process(int_data) - sliced = self.from_plain(int_data, scale, zero_point, self._layout) - return return_and_correct_aliasing(func, args, kwargs, sliced) - elif dim == 1: - int_data, scale, zero_point = self.get_plain() - assert step == 1, "Only step == 1 is supported in slicing right now" - data_len = int_data.shape[dim] - scale_len = scale.shape[dim] - ratio = data_len / scale_len - start_scale = int(start / ratio) - end_scale = int(end / ratio) - - int_data = aten.slice.Tensor(int_data, dim, start, end, step) - # this is to handle padding - int_data = self._layout.post_process(int_data) - scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) - zero_point = aten.slice.Tensor(zero_point, dim, start_scale, end_scale, step) - sliced = self.from_plain(int_data, scale, zero_point, self._layout) - return sliced - else: - raise NotImplementedError(f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") - - raise NotImplementedError( - f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - __torch_function__ = torch._C._disabled_torch_function_impl - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - from torchao.quantization.quant_primitives import ( - ZeroPointDomain, - quantize_affine, - ) - from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros - scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) - - cur_shape = self.shape - assert len(cur_shape) == 4 - inner_k_tiles = cur_shape[-1] * 2 - original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) - eye_shape = original_shape[1] - groupsize = int(original_shape[1] / scale.shape[-2]) - block_size = (1, groupsize) - device = self.device - original_dtype = torch.bfloat16 - target_dtype = torch.int32 - quant_min = 0 - quant_max = 15 - zero_point_domain = ZeroPointDomain.FLOAT - assert len(block_size) == 2 and block_size[0] == 1 - dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero) - dequantized = dequantized.t().contiguous() - # TODO: move this to `unpack_tinygemm_scales_and_zeros`? 
- scale = scale.reshape(scale.shape[:-1]).contiguous() - zero = zero.reshape(zero.shape[:-1]).contiguous() - int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain) - return int_data, scale, zero - - def get_layout(self) -> Layout: - return self._layout - ##################################################### # torch functional and aten operator implementation # ##################################################### -def _aqt_is_int8(aqt): - """Check if an AffineQuantizedTensor is int8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.int8 and - (aqt.quant_min is None or aqt.quant_min == -128) and - (aqt.quant_max is None or aqt.quant_max == 127) - ) - -def _aqt_is_int8_reduced_range(aqt): - return ( - aqt.tensor_impl.dtype == torch.int8 and - aqt.quant_min == -127 and - (aqt.quant_max is None or aqt.quant_max == 127) - ) - -def _aqt_is_tensor_core_tile_uint4(aqt): - """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" - # TODO: use torch.uint4 - return ( - aqt.tensor_impl.dtype == torch.int32 and - aqt.quant_min == 0 and - aqt.quant_max == 15 - ) - - -implements = AffineQuantizedTensor.implements - -# following are a list of (dispatch_condition, implementation) functions that takes the following args: -# input_tensor: dimension is (M1, M2, ..., in_features) -# weight_tensor: dimension is (out_features, in_features) -# bias: dimension is (out_features,) -# so that these can be shared by F.linear, aten.mm, aten.addmm dispatches - -def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, PlainLayout) - ) - -def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): - # - # 1. do the matrix form of dot(X_i, W_j) - # - # - # 2. 
rescale the output - # - # in cases with large matrices, y_dot_int32 can grow sufficiently - # large that y_dot_int32 * a float16 scale is greater than the maximum - # value of a float 16, (which results in a value of inf even if multiplying - # by the other scale would bring it within the expected range) - - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - x_scales_dtype = x_scales.dtype - # Cast fp16 scale to float to avoid overflow in int_scaled_matmul - intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype - y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)) - y_dot_scaled = y_dot_scaled.to(x_scales_dtype) - - y = (y_dot_scaled * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] - ) - - # can downcast only at the very end - output_dtype = input_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y += bias - return y - - -def _linear_int8_act_int8_weight_semi_structured_sparse_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.is_cuda and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, SemiSparseLayout) - ) - -def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias): - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals_int8 = weight_tensor.tensor_impl.int_data - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm - y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( - w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, - ).t() - y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( - *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] - ) - output_dtype = input_tensor.dtype - # TODO: waiting for jesse's test/fix - y = y.to(output_dtype).contiguous() - if bias is not None: - y += bias - return y - -def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.is_cuda and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, BlockSparseLayout) - ) - - -def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals = weight_tensor.tensor_impl - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - tmp_t = tmp.t() - - y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(), - w_vals.col_indices(), - w_vals.values(), - tmp_t, - w_scales, - x_scales.reshape(-1)) - y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) - y = y.reshape(*y_shape) - - # can downcast only at the very end - 
output_dtype = input_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y += bias - return y - - -def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): - return ( - # input is native bfloat16 tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.dtype == torch.bfloat16 and - # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_tensor_core_tile_uint4(weight_tensor) and - weight_tensor.dtype == torch.bfloat16 and - len(weight_tensor.shape) == 2 and - weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and - isinstance(weight_tensor._layout, TensorCoreTiledLayout) - ) - - -def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): - assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" - assert input_tensor.shape[-1] == weight_tensor.shape[1], ( - f"need input_tensor shape: {input_tensor.shape} final" - f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " - ) - - # TODO: check groupsize quantization - # avoid circular dep, TODO: move this to a common util.py - act_mat = input_tensor - # weight is packed from padded (out_features, in_features) weight tensor - # (same dimension requirement as F.linear weight) - packed_weight = weight_tensor.tensor_impl.packed_weight - scale_and_zero = weight_tensor.tensor_impl.scale_and_zero - - orig_act_size = act_mat.size() - orig_dtype = act_mat.dtype - - # reshape and pad activation - act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16) - pad_size = find_multiple(act_mat.shape[-1], 1024) - act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) - - # groupwise int4 quantization - groupsize = weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm(act_mat.contiguous(), packed_weight, groupsize, scale_and_zero) - - # remove out_feature padding - orig_out_features = weight_tensor.shape[-2] - y = y[:, :orig_out_features] - y = y.reshape(*orig_act_size[:-1], orig_out_features) - - - if bias is not None: - y += bias - return y.to(orig_dtype) - - -def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): - return ( - # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and - # weight is int8 per channel quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_int8(weight_tensor) and - len(weight_tensor.shape) == 2 and - len(weight_tensor.block_size) == 2 and - weight_tensor.block_size[0] == 1 and - weight_tensor.block_size[1] == weight_tensor.shape[1] and - weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor._layout, PlainLayout) - ) - -def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): - # TODO: enable cpu and mps efficient path - # is_cpu and is_mps only, some issue with is_contiguous() currently - # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) - - # per channel int8 weight only quantizated mm - w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() - scale = weight_tensor.tensor_impl.scale - orig_dtype = input_tensor.dtype - m = torch.mm( - input_tensor.reshape(-1, input_tensor.shape[-1]), - w_vals_int8_t.to(input_tensor.dtype), - ) - y = m * scale.to(m.dtype) - y = y.reshape(*input_tensor.shape[:-1], 
y.shape[-1]) - if bias is not None: - y += bias.to(m.dtype) - return y - -def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): - from torchao.dtypes.floatx import FloatxTensorCoreLayout - return ( - # input is native float32 tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and - input_tensor.dtype in (torch.float16, torch.bfloat16) and - # weight is floatx Tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor._layout, FloatxTensorCoreLayout) and - ( - # weight is using fp6 quantization - (weight_tensor._layout.ebits == 3 and - weight_tensor._layout.mbits == 2) or - (weight_tensor._layout.ebits == 2 and - weight_tensor._layout.mbits == 3) or - # weight is using fp5 quantization - (weight_tensor._layout.ebits == 2 and - weight_tensor._layout.mbits == 2) or - (weight_tensor._layout.ebits == 3 and - weight_tensor._layout.mbits == 1) - ) - ) - -def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): - from torchao.dtypes.floatx import _SPLIT_K_MAP - from torchao.ops import quant_llm_linear - - act = input_tensor - weight = weight_tensor - - out_dim, in_dim = weight.shape - act_reshaped = act.view(-1, in_dim) - - # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py - bsize = act_reshaped.shape[0] - splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 - - out = quant_llm_linear( - weight._layout.ebits, - weight._layout.mbits, - act_reshaped, - weight.tensor_impl.packed_floatx_data, - weight.tensor_impl.scale, - splitK=splitK, - ) - - if bias is not None: - out += bias - - return out.view(*act.shape[:-1], out_dim).to(act.dtype) - -def _linear_fp8_act_fp8_weight_check( - input_tensor: Union[torch.Tensor, AffineQuantizedTensor], - weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], - bias: Optional[torch.Tensor], -) -> bool: - def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: - return ( - isinstance(aqt, AffineQuantizedTensor) and - isinstance(aqt._layout, Float8Layout) - and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) - ) - return check_aqt(input_tensor) and check_aqt(weight_tensor) - - -def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): - """ Ensures input tensor is correctly formated for _scaled_mm """ - input_scale = input_scale.unsqueeze(-1) - - if input_scale.dim() > 2: - input_scale = input_scale.reshape(-1, input_scale.shape[-1]) - - return input_scale - -def _linear_fp8_act_fp8_weight_impl( - input_tensor: AffineQuantizedTensor, - weight_tensor: AffineQuantizedTensor, - bias: Optional[torch.Tensor], -): - """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" - scaled_mm_config = weight_tensor._layout.mm_config - out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) - - # Weight tensor preprocessing - w_tensor_impl = weight_tensor.tensor_impl - assert not w_tensor_impl.transposed, "Weight tensor must be contiguous" - w_data = w_tensor_impl.float8_data - w_scale = w_tensor_impl.scale - - # Input tensor preprocessing - inpt_data = input_tensor.tensor_impl.float8_data - input_scale = input_tensor.tensor_impl.scale - # Handle case where input tensor is more than 2D - inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) - - # Handle rowwise case - if 
_is_rowwise_scaled(weight_tensor): - assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size" - w_scale = w_scale.unsqueeze(-1).T - input_scale = preprocess_scale(input_scale, input_tensor.shape) - - # Preprocess data - inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) - - # Perform the computation - return addmm_float8_unwrapped_inference( - inpt_data, - input_scale, - w_data, - w_scale, - output_dtype=input_tensor.dtype, - bias=bias, - use_fast_accum=scaled_mm_config.use_fast_accum, - ).reshape(out_shape) - -def _linear_fp_act_fp8_weight_check( - input_tensor: Union[torch.Tensor, AffineQuantizedTensor], - weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], - bias: Optional[torch.Tensor], -) -> bool: - return ( - # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and - # weight is float8 quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor._layout, Float8Layout) - and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - and (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor)) - ) - -def _linear_fp_act_fp8_weight_impl( - input_tensor: torch.Tensor, - weight_tensor: AffineQuantizedTensor, - bias: Optional[torch.Tensor], -): - return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias) - -def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias): - return ( - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_tensor_core_tile_uint4(weight_tensor) and - input_tensor.dtype == torch.float16 and - len(weight_tensor.shape) == 2 and - weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor._layout, MarlinSparseLayout) - ) - -def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias): - from torchao.sparsity.marlin import marlin_24_workspace, const - from torchao.ops import marlin_24_gemm - - assert isinstance(weight_tensor, AffineQuantizedTensor) - - sparse_w_int4 = weight_tensor.tensor_impl.int_data - scale = weight_tensor.tensor_impl.scale - meta = weight_tensor.tensor_impl.meta - original_shape = weight_tensor.tensor_impl.original_shape - num_bits = weight_tensor.tensor_impl.num_bits - - # Folds batch dimension into the first dimension - input_2d = input_tensor.view(-1, input_tensor.shape[-1]) - - size_m = input_2d.shape[0] - size_n = scale.shape[1] - size_k = input_2d.shape[1] - workspace_24 = marlin_24_workspace(original_shape[1]) - - out = marlin_24_gemm( - input_2d, sparse_w_int4, meta, scale, - workspace_24, num_bits, size_m, size_n, size_k - ) - - # Unfold the batch dimension - out = out.reshape(input_tensor.shape[:-1] + (scale.shape[1],)) - - if bias is not None: - out += bias.to(out.dtype) - return out - - -def _register_aqt_quantized_linear_dispatches(): - for dispatch_condition, impl in [ - (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), - (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), - (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl), - (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl), - (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), - (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), - 
(_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), - (_linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl), - (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl), - ]: - register_aqt_quantized_linear_dispatch(dispatch_condition, impl) - -_register_aqt_quantized_linear_dispatches() - -@implements([torch.nn.functional.linear, aten.linear.default]) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") - - # using try/except here so that we can have a general fallback when input_tensor/weight_tensor - # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to - # make the branches easier to understand in `_quantized_linear_op` - try: - return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - except QuantizedLinearNotImplementedError as e: - # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: - raise e - - if isinstance(input_tensor, AffineQuantizedTensor): - input_tensor = input_tensor.dequantize() - if isinstance(weight_tensor, AffineQuantizedTensor): - weight_tensor = weight_tensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - -@implements(torch.nn.functional.embedding) -def _(func, types, args, kwargs): - # new_arg1 = args[1].dequantize() - # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs) - assert isinstance(args[1].tensor_impl, PlainAQTTensorImpl), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" - assert kwargs["padding_idx"] is None and kwargs["max_norm"] is None and not kwargs["scale_grad_by_freq"] and not kwargs["sparse"] and kwargs["norm_type"]==2.0 - idx = args[0] - int_data, scale, zero_point = args[1].tensor_impl.get_plain() - - sliced_data, sliced_scale, sliced_zero_point = int_data[idx], scale[idx], zero_point[idx] - # Block size is expecting 2 dimensions [1, group size] but - # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so - # we need to increase block size to correct dim - new_blocks = idx.dim()-1 - return dequantize_affine( - sliced_data, - new_blocks*[1]+list(args[1].block_size), - sliced_scale, - sliced_zero_point, - sliced_data.dtype, - args[1].quant_min, - args[1].quant_max, - args[1].zero_point_domain, - output_dtype=sliced_scale.dtype, - ) - -@implements(aten.addmm.default) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[1], - args[2], - args[0], - ) - if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") - - # using try/except here so that we can have a general fallback when input_tensor/weight_tensor - # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to - # make the branches easier to understand in `_quantized_linear_op` - try: - weight_tensor = weight_tensor.t() - return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - except 
QuantizedLinearNotImplementedError as e: - # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: - raise e - - if isinstance(input_tensor, AffineQuantizedTensor): - input_tensor = input_tensor.dequantize() - if isinstance(weight_tensor, AffineQuantizedTensor): - weight_tensor = weight_tensor.dequantize() - return func(bias, input_tensor, weight_tensor) - -@implements(aten.mm.default) -def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - None - ) - if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") - - try: - weight_tensor = weight_tensor.t() - return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) - except QuantizedLinearNotImplementedError as e: - # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: - raise e - - if isinstance(input_tensor, AffineQuantizedTensor): - input_tensor = input_tensor.dequantize() - if isinstance(weight_tensor, AffineQuantizedTensor): - weight_tensor = weight_tensor.dequantize() - return func(input_tensor, weight_tensor) - -@implements(aten.detach.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - -@implements(aten.clone.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - -@implements(aten._to_copy.default) -def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), - ) - -@implements(aten.t.default) -def _(func, types, args, kwargs): - block_size = args[0].block_size - assert len(block_size) == 2 - transposed_block_size = (block_size[1], block_size[0]) - tensor = args[0] - shape = tensor.shape[::-1] - new = tensor.__class__( - tensor.tensor_impl.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() - ) - return return_and_correct_aliasing(func, args, kwargs, new) - -@implements(aten.slice.Tensor) -def _(func, types, args, kwargs): - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - assert step == 1 - assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" - if end >= self.shape[dim]: - end = self.shape[dim] - shape = list(self.shape) - shape[dim] = end - start - block_size = self.block_size - assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}" - # with slice, some shape dimension might be smaller than block_size dimension, so - # we need to make sure there is no overflow - block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) - new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) - 
return return_and_correct_aliasing(func, args, kwargs, new) - -# this is needed for DTensor.from_local() and for flattening tensor -@implements(aten.view.default) -def _(func, types, args, kwargs): - self, shape = args - - if tuple(self.shape) == tuple(shape): - return self.__class__(self.tensor_impl, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) - - if len(shape) == 1 and shape[0] == -1: - assert len(self.block_size) == 2 and self.block_size[0] == 1 - block_size = (self.block_size[1],) - return self.__class__(self.tensor_impl, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) - - raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]") - - to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py new file mode 100644 index 0000000000..a3d995d653 --- /dev/null +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -0,0 +1,295 @@ +import logging +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.dtypes.affine_quantized_tensor import * +# from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor +from torchao.dtypes.floatx.float8_layout import ( + _linear_fp8_act_fp8_weight_check, + _linear_fp8_act_fp8_weight_impl, + _linear_fp_act_fp8_weight_check, + _linear_fp_act_fp8_weight_impl, +) +from torchao.dtypes.floatx.floatx_tensor_core_layout import ( + _linear_f16_bf16_act_floatx_weight_check, + _linear_f16_bf16_act_floatx_weight_impl, +) +from torchao.dtypes.uintx.block_sparse_layout import ( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, +) +from torchao.dtypes.uintx.marlin_sparse_layout import ( + _linear_fp_act_int4_weight_sparse_marlin_check, + _linear_fp_act_int4_weight_sparse_marlin_impl, +) +from torchao.dtypes.uintx.semi_sparse_layout import ( + _linear_int8_act_int8_weight_semi_structured_sparse_check, + _linear_int8_act_int8_weight_semi_structured_sparse_impl, +) +from torchao.dtypes.uintx.tensor_core_tiled_layout import ( + _linear_bf16_act_uint4_weight_check, + _linear_bf16_act_uint4_weight_impl, +) +from torchao.dtypes.uintx.uint8_layout import ( + _linear_int8_act_int8_weight_check, + _linear_int8_act_int8_weight_impl, + _linear_fp_act_int8_weight_check, + _linear_fp_act_int8_weight_impl, +) +from torchao.utils import ( + fill_defaults, +) + +logger = logging.getLogger(__name__) + + +aten = torch.ops.aten + + +_AQT_QLINEAR_DISPATCH_TABLE = {} + + +def register_aqt_quantized_linear_dispatch(dispatch_condition, impl): + """Register a dispatch for quantized linear op with dispatch_condition function and impl function + both takes three arguments: + input_tensor: dimension is (M1, M2, ..., in_features) + weight_tensor: dimension is (out_features, in_features) + bias: dimension is (out_features,) + so that these can be shared by F.linear, aten.mm, aten.addmm dispatches + + Args: + `dispatch_condition` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], bool]: the dispatch + condition for a specialized quantized linear implementation, e.g. 
bfloat16 activation + uint4 weight + `impl` (Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: the specialized + quantized linear implementation + """ + _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl + + +def deregister_aqt_quantized_linear_dispatch(dispatch_condition): + if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE: + del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] + else: + logger.warn(f"Attempting to remove non-existant dispatch condition {dispatch_condition}") + + +class QuantizedLinearNotImplementedError(NotImplementedError): + """ Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table """ + pass + + +@staticmethod +def _quantized_linear_op(input_tensor, weight_tensor, bias): + for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items(): + if dispatch_condition(input_tensor, weight_tensor, bias): + return impl(input_tensor, weight_tensor, bias) + raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op") + + +# Attach the _quantized_linear_op to the AffineQuantizedTensor class +AffineQuantizedTensor._quantized_linear_op = _quantized_linear_op + + + +# # following are a list of (dispatch_condition, implementation) functions that takes the following args: +# # input_tensor: dimension is (M1, M2, ..., in_features) +# # weight_tensor: dimension is (out_features, in_features) +# # bias: dimension is (out_features,) +# # so that these can be shared by F.linear, aten.mm, aten.addmm dispatches +def _register_aqt_quantized_linear_dispatches(): + for dispatch_condition, impl in [ + (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), + (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), + (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl), + (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl), + (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), + (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), + (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), + (_linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl), + (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl), + ]: + register_aqt_quantized_linear_dispatch(dispatch_condition, impl) + + +_register_aqt_quantized_linear_dispatches() + +implements = AffineQuantizedTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + if not input_tensor.is_floating_point(): + raise NotImplementedError(f"{func} is not implemented for non floating point input") + + # using try/except here so that we can have a general fallback when input_tensor/weight_tensor + # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to + # make the branches easier to understand in `_quantized_linear_op` + try: + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + except QuantizedLinearNotImplementedError as e: + # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` + if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, 
"quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: + raise e + + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +@implements(torch.nn.functional.embedding) +def _(func, types, args, kwargs): + # new_arg1 = args[1].dequantize() + # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs) + assert isinstance(args[1].tensor_impl, PlainAQTTensorImpl), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" + assert kwargs["padding_idx"] is None and kwargs["max_norm"] is None and not kwargs["scale_grad_by_freq"] and not kwargs["sparse"] and kwargs["norm_type"]==2.0 + idx = args[0] + int_data, scale, zero_point = args[1].tensor_impl.get_plain() + + sliced_data, sliced_scale, sliced_zero_point = int_data[idx], scale[idx], zero_point[idx] + # Block size is expecting 2 dimensions [1, group size] but + # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so + # we need to increase block size to correct dim + new_blocks = idx.dim()-1 + return dequantize_affine( + sliced_data, + new_blocks*[1]+list(args[1].block_size), + sliced_scale, + sliced_zero_point, + sliced_data.dtype, + args[1].quant_min, + args[1].quant_max, + args[1].zero_point_domain, + output_dtype=sliced_scale.dtype, + ) + + +@implements(aten.addmm.default) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[1], + args[2], + args[0], + ) + if not input_tensor.is_floating_point(): + raise NotImplementedError(f"{func} is not implemented for non floating point input") + + # using try/except here so that we can have a general fallback when input_tensor/weight_tensor + # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to + # make the branches easier to understand in `_quantized_linear_op` + try: + weight_tensor = weight_tensor.t() + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + except QuantizedLinearNotImplementedError as e: + # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` + if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: + raise e + + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return func(bias, input_tensor, weight_tensor) + + +@implements(aten.mm.default) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + None + ) + if not input_tensor.is_floating_point(): + raise NotImplementedError(f"{func} is not implemented for non floating point input") + + try: + weight_tensor = weight_tensor.t() + return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) + except QuantizedLinearNotImplementedError as e: + # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` + if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is 
not None: + raise e + + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + if isinstance(weight_tensor, AffineQuantizedTensor): + weight_tensor = weight_tensor.dequantize() + return func(input_tensor, weight_tensor) + + +@implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + + +@implements(aten.t.default) +def _(func, types, args, kwargs): + block_size = args[0].block_size + assert len(block_size) == 2 + transposed_block_size = (block_size[1], block_size[0]) + tensor = args[0] + shape = tensor.shape[::-1] + new = tensor.__class__( + tensor.tensor_impl.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + assert step == 1 + assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}" + if end >= self.shape[dim]: + end = self.shape[dim] + shape = list(self.shape) + shape[dim] = end - start + block_size = self.block_size + assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}" + # with slice, some shape dimension might be smaller than block_size dimension, so + # we need to make sure there is no overflow + block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) + new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return return_and_correct_aliasing(func, args, kwargs, new) + + +# this is needed for DTensor.from_local() and for flattening tensor +@implements(aten.view.default) +def _(func, types, args, kwargs): + self, shape = args + + if tuple(self.shape) == tuple(shape): + return self.__class__(self.tensor_impl, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + + if len(shape) == 1 and shape[0] == -1: + assert len(self.block_size) == 2 and self.block_size[0] == 1 + block_size = (self.block_size[1],) + return self.__class__(self.tensor_impl, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + + raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]") diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index d7559015f4..34b7ec1f91 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1 +1,2 @@ -from .floatx import FloatxTensorCoreLayout, FloatxTensorCoreAQTTensorImpl, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP +from .floatx_tensor_core_layout import FloatxTensorCoreLayout, FloatxTensorCoreAQTTensorImpl, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP +from 
.float8_layout import Float8AQTTensorImpl, Float8Layout diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py new file mode 100644 index 0000000000..1c3c046497 --- /dev/null +++ b/torchao/dtypes/floatx/float8_layout.py @@ -0,0 +1,261 @@ +import torch +from torchao.utils import _is_float8_type +from torchao.dtypes.utils import Layout, AQTTensorImpl +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout +) +from dataclasses import dataclass +from typing import Optional, Tuple, Union +from torchao.float8.inference import ( + preprocess_data, + Float8MMConfig, + addmm_float8_unwrapped_inference, + _is_rowwise_scaled +) +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, + is_traceable_wrapper_subclass, +) +aten = torch.ops.aten + + +@dataclass(frozen=True) +class Float8Layout(Layout): + mm_config: Optional[Float8MMConfig] = None + +@register_layout(Float8Layout) +class Float8AQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for float8 layout affine quantized tensor + + Note: technically we should not create a new layout for float8 we should merge this into + plain layout + """ + float8_data: torch.Tensor + scale: torch.Tensor + transposed: bool + + def __new__( + cls, + float8_data: torch.Tensor, + scale: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = float8_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else float8_data.layout + ) + kwargs["dtype"] = float8_data.dtype + kwargs["requires_grad"] = False + shape = float8_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + float8_data: torch.Tensor, + scale: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.float8_data = float8_data + self.scale = scale + self.transposed = transposed + self._layout = _layout + + def _apply_fn_to_data(self, fn): + """ Applys a fn to all tensor components stored on this class""" + return self.__class__( + fn(self.float8_data), + fn(self.scale), + self.transposed, + self._layout, + ) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.float8_data.to(kwargs["device"]), + self.scale.to(kwargs["device"]), + self.transposed, + self._layout, + ) + + def __tensor_flatten__(self): + return ["float8_data", "scale"], [self.transposed, self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"] + transposed, _layout, = tensor_attributes + return cls(float8_data, scale, transposed, _layout) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + elif func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + elif func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + args[0].transposed = not args[0].transposed + return return_and_correct_aliasing(func, args, kwargs, args[0]) + elif func is aten.slice.Tensor: + self, dim, start, end, step = 
fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + #TODO: scale replecation should be dependent on block size + if self.scale.ndim == 1: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) + ) + elif self.scale.ndim == 0: + return return_and_correct_aliasing( + func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout) + ) + else: + raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported") + elif dim == 1: + return return_and_correct_aliasing( + func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step).contiguous(), self.scale, None, self._layout) + ) + else: + raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + else: + raise NotImplementedError( + f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return self.float8_data, self.scale, None + + def get_layout(self) -> Layout: + return self._layout + + @classmethod + def from_plain( + cls, + data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + """ Main entrypoint for constructing Float8TensorImpl""" + assert _is_float8_type(data.dtype), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" + assert isinstance(_layout, Float8Layout), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}" + return cls(data, scale, False, _layout) + + def __repr__(self): + float8_data, scale, _ = self.get_plain() + _layout = self.get_layout() + return (f"{self.__class__.__name__}(\n" + f"float8_data={float8_data},\n" + f"scale={scale},\n" + f"transposed={self.transposed}, " + f"_layout={_layout})") + + +########################## +# Float8 Dispatch Kernels +########################## + +def _linear_fp8_act_fp8_weight_check( + input_tensor: Union[torch.Tensor, 'AffineQuantizedTensor'], + weight_tensor: Union[torch.Tensor, 'AffineQuantizedTensor'], + bias: Optional[torch.Tensor], +) -> bool: + def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: + return ( + isinstance(aqt, AffineQuantizedTensor) and + isinstance(aqt._layout, Float8Layout) + and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) + ) + return check_aqt(input_tensor) and check_aqt(weight_tensor) + + +def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): + """ Ensures input tensor is correctly formated for _scaled_mm """ + input_scale = input_scale.unsqueeze(-1) + + if input_scale.dim() > 2: + input_scale = input_scale.reshape(-1, input_scale.shape[-1]) + + return input_scale + +def _linear_fp8_act_fp8_weight_impl( + input_tensor: 'AffineQuantizedTensor', + weight_tensor: 'AffineQuantizedTensor', + bias: Optional[torch.Tensor], +): + """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" + scaled_mm_config = weight_tensor._layout.mm_config + out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape) + + # Weight tensor preprocessing + w_tensor_impl = weight_tensor.tensor_impl + assert not 
w_tensor_impl.transposed, "Weight tensor must be contiguous" + w_data = w_tensor_impl.float8_data + w_scale = w_tensor_impl.scale + + # Input tensor preprocessing + inpt_data = input_tensor.tensor_impl.float8_data + input_scale = input_tensor.tensor_impl.scale + # Handle case where input tensor is more than 2D + inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1]) + + # Handle rowwise case + if _is_rowwise_scaled(weight_tensor): + assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size" + w_scale = w_scale.unsqueeze(-1).T + input_scale = preprocess_scale(input_scale, input_tensor.shape) + + # Preprocess data + inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config) + + # Perform the computation + return addmm_float8_unwrapped_inference( + inpt_data, + input_scale, + w_data, + w_scale, + output_dtype=input_tensor.dtype, + bias=bias, + use_fast_accum=scaled_mm_config.use_fast_accum, + ).reshape(out_shape) + +def _linear_fp_act_fp8_weight_check( + input_tensor: Union[torch.Tensor, 'AffineQuantizedTensor'], + weight_tensor: Union[torch.Tensor, 'AffineQuantizedTensor'], + bias: Optional[torch.Tensor], +) -> bool: + return ( + # input is native float tensor + not is_traceable_wrapper_subclass(input_tensor) and + input_tensor.is_floating_point() and + # weight is float8 quantized affine quantized tensor + isinstance(weight_tensor, AffineQuantizedTensor) and + isinstance(weight_tensor._layout, Float8Layout) + and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] + and (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor)) + ) + +def _linear_fp_act_fp8_weight_impl( + input_tensor: torch.Tensor, + weight_tensor: 'AffineQuantizedTensor', + bias: Optional[torch.Tensor], +): + return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias) diff --git a/torchao/dtypes/floatx/floatx.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py similarity index 87% rename from torchao/dtypes/floatx/floatx.py rename to torchao/dtypes/floatx/floatx_tensor_core_layout.py index a4745e9315..cfdb566279 100644 --- a/torchao/dtypes/floatx/floatx.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -3,14 +3,20 @@ import torch from torch import Tensor -from torch.utils._python_dispatch import return_and_correct_aliasing +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, + is_traceable_wrapper_subclass, +) from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, _n_ones from torchao.dtypes.utils import ( Layout, + AQTTensorImpl, +) +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, ) -from torchao.quantization.quant_api import _get_linear_subclass_inserter from dataclasses import dataclass -from torchao.dtypes.affine_quantized_tensor import AQTTensorImpl, register_layout aten = torch.ops.aten @@ -360,6 +366,7 @@ class FloatxTensorCoreLayout(Layout): ebits: int mbits: int + @register_layout(FloatxTensorCoreLayout) class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), @@ -491,3 +498,56 @@ def __torch_dispatch__(cls, func, types, args, kwargs): def get_layout(self) -> Layout: return self._layout + + +def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): + from torchao.dtypes.floatx import FloatxTensorCoreLayout + return ( + # input is native float32 tensor + not 
is_traceable_wrapper_subclass(input_tensor) and + input_tensor.is_floating_point() and + input_tensor.dtype in (torch.float16, torch.bfloat16) and + # weight is floatx Tensor + isinstance(weight_tensor, AffineQuantizedTensor) and + isinstance(weight_tensor._layout, FloatxTensorCoreLayout) and + ( + # weight is using fp6 quantization + (weight_tensor._layout.ebits == 3 and + weight_tensor._layout.mbits == 2) or + (weight_tensor._layout.ebits == 2 and + weight_tensor._layout.mbits == 3) or + # weight is using fp5 quantization + (weight_tensor._layout.ebits == 2 and + weight_tensor._layout.mbits == 2) or + (weight_tensor._layout.ebits == 3 and + weight_tensor._layout.mbits == 1) + ) + ) + +def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): + from torchao.dtypes.floatx import _SPLIT_K_MAP + from torchao.ops import quant_llm_linear + + act = input_tensor + weight = weight_tensor + + out_dim, in_dim = weight.shape + act_reshaped = act.view(-1, in_dim) + + # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py + bsize = act_reshaped.shape[0] + splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 + + out = quant_llm_linear( + weight._layout.ebits, + weight._layout.mbits, + act_reshaped, + weight.tensor_impl.packed_floatx_data, + weight.tensor_impl.scale, + splitK=splitK, + ) + + if bias is not None: + out += bias + + return out.view(*act.shape[:-1], out_dim).to(act.dtype) diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index c44803f6d2..7b2c4e9028 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1 +1,6 @@ from .uintx import UintxTensor, UintxLayout, UintxAQTTensorImpl, to_uintx, _DTYPE_TO_BIT_WIDTH +from .uint4 import UInt4Tensor +from .block_sparse_layout import BlockSparseLayout +from .semi_sparse_layout import SemiSparseLayout +from .marlin_sparse_layout import MarlinSparseLayout +from .tensor_core_tiled_layout import TensorCoreTiledLayout diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py new file mode 100644 index 0000000000..4f6358fae5 --- /dev/null +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -0,0 +1,204 @@ +from dataclasses import dataclass +import logging +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) +from torchao.dtypes.utils import ( + Layout, + PlainLayout, +) +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, + PlainAQTTensorImpl +) +from torchao.dtypes.uintx.uint8 import _aqt_is_int8_reduced_range + +logger = logging.getLogger(__name__) + +aten = torch.ops.aten + +@dataclass(frozen=True) +class BlockSparseLayout(Layout): + blocksize: int = 64 + +@register_layout(BlockSparseLayout) +class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): + bsr_crow_indices: Optional[torch.Tensor] + bsr_col_indices: Optional[torch.Tensor] + bsr_values: Optional[torch.Tensor] + scale: Optional[torch.Tensor] + zero_point: Optional[torch.Tensor] + + __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"] + + @staticmethod + def __new__( # noqa: PYI034 + cls, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + zero_point: 
Optional[torch.Tensor], + _layout: Layout, + requires_grad: bool = False, + ): + if bsr_values is None: + raise ValueError("bsr values must be provided!") + else: + previous_tensor = bsr_values + + kwargs = { + "device": previous_tensor.device, + "dtype": previous_tensor.dtype, + "layout": previous_tensor.layout, + "requires_grad": requires_grad, + } + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( # noqa: PYI034 + self, + shape: torch.Size, + bsr_crow_indices: Optional[torch.Tensor], + bsr_col_indices: Optional[torch.Tensor], + bsr_values: Optional[torch.Tensor], + scale: Optional[torch.Tensor], + zero_point: Optional[torch.Tensor], + _layout: Layout, + requires_grad: bool = False, + ): + self.bsr_crow_indices = bsr_crow_indices + self.bsr_col_indices = bsr_col_indices + self.bsr_values = bsr_values + self.scale = scale + self.zero_point = zero_point + self._layout = _layout + + def __tensor_flatten__(self): + inner_tensors = list( + filter(lambda x: getattr(self, x) is not None, self.__slots__) + ) + tensor_meta = (self.shape, self._layout, self.requires_grad) + return inner_tensors, tensor_meta + + @classmethod + def __tensor_unflatten__( + cls, + inner_tensors, + tensor_meta: Tuple[torch.Size, bool], + outer_size, + outer_stride, + ) -> torch.Tensor: + shape, _layout, requires_grad = tensor_meta + return cls( + shape=shape, + bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None), + bsr_col_indices=inner_tensors.get("bsr_col_indices", None), + bsr_values=inner_tensors.get("bsr_values", None), + scale=inner_tensors.get("scale", None), + zero_point=inner_tensors.get("zero_point", None), + _layout=_layout, + requires_grad=requires_grad, + ) + + @classmethod + def from_plain(cls, int_data, scale, zero_point, _layout): + bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize) + return cls( + shape=int_data.shape, + bsr_crow_indices=bsr_tensor.crow_indices(), + bsr_col_indices=bsr_tensor.col_indices(), + bsr_values=bsr_tensor.values(), + scale=scale, + zero_point=zero_point, + _layout = _layout, + requires_grad=False, + ) + + def get_plain(self): + int_data_expanded = torch.ops.blocksparse.bsr_to_dense(self.crow_indices(), self.col_indices(), self.values(), self.shape[0], self.shape[1]) + return int_data_expanded, self.scale, self.zero_point + + def _apply_fn_to_data(self, func): + return self.__class__( + shape = self.shape, + bsr_crow_indices=func(self.bsr_crow_indices), + bsr_col_indices=func(self.bsr_col_indices), + bsr_values=func(self.bsr_values), + scale=self.scale, + zero_point=self.zero_point, + _layout=self._layout, + requires_grad=self.requires_grad, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + # Need the following for bsr specific functions + if func is aten.crow_indices.default: + return args[0].bsr_crow_indices.detach() + + if func is aten.col_indices.default: + return args[0].bsr_col_indices.detach() + + if func is aten.values.default: + return args[0].bsr_values.detach() + + if func is aten._nnz.default: + return args[0].bsr_values.shape[0] + + raise NotImplementedError( + f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not 
supported" + ) + + + +def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) and + _aqt_is_int8_reduced_range(input_tensor) and + isinstance(weight_tensor, AffineQuantizedTensor) and + weight_tensor.is_cuda and + input_tensor.dtype == weight_tensor.dtype and + isinstance(input_tensor._layout, PlainLayout) and + isinstance(weight_tensor._layout, BlockSparseLayout) + ) + + +def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias): + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals = weight_tensor.tensor_impl + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + tmp_t = tmp.t() + + y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(), + w_vals.col_indices(), + w_vals.values(), + tmp_t, + w_scales, + x_scales.reshape(-1)) + y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) + y = y.reshape(*y_shape) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py new file mode 100644 index 0000000000..5483e4d7a3 --- /dev/null +++ b/torchao/dtypes/uintx/marlin_sparse_layout.py @@ -0,0 +1,240 @@ +from dataclasses import dataclass +from torchao.dtypes.utils import Layout, AQTTensorImpl +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout +) +import torch +from torchao.dtypes.uintx.tensor_core_tiled_layout import _aqt_is_tensor_core_tile_uint4 + + +def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias): + return ( + isinstance(weight_tensor, AffineQuantizedTensor) and + _aqt_is_tensor_core_tile_uint4(weight_tensor) and + input_tensor.dtype == torch.float16 and + len(weight_tensor.shape) == 2 and + weight_tensor.zero_point_domain == ZeroPointDomain.INT and + isinstance(weight_tensor._layout, MarlinSparseLayout) + ) + +def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias): + from torchao.sparsity.marlin import marlin_24_workspace, const + from torchao.ops import marlin_24_gemm + + assert isinstance(weight_tensor, AffineQuantizedTensor) + + sparse_w_int4 = weight_tensor.tensor_impl.int_data + scale = weight_tensor.tensor_impl.scale + meta = weight_tensor.tensor_impl.meta + original_shape = weight_tensor.tensor_impl.original_shape + num_bits = weight_tensor.tensor_impl.num_bits + + # Folds batch dimension into the first dimension + input_2d = input_tensor.view(-1, input_tensor.shape[-1]) + + size_m = input_2d.shape[0] + size_n = scale.shape[1] + size_k = input_2d.shape[1] + workspace_24 = marlin_24_workspace(original_shape[1]) + + out = marlin_24_gemm( + input_2d, sparse_w_int4, meta, scale, + workspace_24, num_bits, size_m, size_n, size_k + ) + + # Unfold the batch dimension + out = out.reshape(input_tensor.shape[:-1] + (scale.shape[1],)) + + if bias is not None: + out += bias.to(out.dtype) + return out + +@dataclass(frozen=True) +class MarlinSparseLayout(Layout): + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. 
+ - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format + - 2º: tensor is injected with 2:4 sparsity + - 3º: transposes it again because the quantization process will compute the scales for dim=-1 + + Args: + input (torch.Tensor): the input tensor to preprocess + + Returns: + torch.Tensor: the preprocessed tensor + """ + from torchao.sparsity.marlin import inject_24 # avoid circular import + input_t = input.t() + w_24, _ = inject_24(input_t, *input_t.shape) + return w_24.t() + +@register_layout(MarlinSparseLayout) +class MarlinSparseAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for sparse_marlin_24 layout for affine quantized tensor. + + Can be used with 4 bits and 8 bits quantization. + + Original marlin documentation and information: + https://github.com/IST-DASLab/marlin/tree/master + + Sparse marlin documentation and information: + https://github.com/IST-DASLab/Sparse-Marlin?tab=readme-ov-file + + fields: + original_shape (torch.Size): the original shape of the tensor. used to unpack the tensor to the original shape + group_size (int): the group size used to pack the tensor + num_bits (int): the number of bits used to quantize the tensor + """ + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + meta: torch.Tensor, + _layout: Layout, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + meta: torch.Tensor, + _layout: Layout, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + self.int_data = int_data + self.scale = scale + self.zero_point = zero_point + self.meta = meta + self._layout = _layout + self.original_shape = original_shape + self.group_size = group_size + self.num_bits = num_bits + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"MarlinSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + return ["int_data", "scale", "zero_point", "meta"], [self._layout, self.original_shape, self.group_size, self.num_bits] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data = tensor_data_dict["int_data"] + scale = tensor_data_dict["scale"] + zero_point = tensor_data_dict["zero_point"] + meta = tensor_data_dict["meta"] + _layout, original_shape, group_size, num_bits = tensor_attributes + return cls(int_data, scale, zero_point, meta, _layout, original_shape, group_size, num_bits) + + def get_plain(self): + from torchao.sparsity.marlin import unpack_from_marlin_24 # avoid circular import + int_data_expanded, scales_expanded = unpack_from_marlin_24( + self.int_data, + self.scale, + self.meta, + self.original_shape, + self.group_size, + self.num_bits, + ) + int_data_expanded_t = int_data_expanded.t() + scales_expanded_t = 
scales_expanded.t()
+        return int_data_expanded_t, scales_expanded_t, self.zero_point
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        _layout: Layout,
+    ):
+        from torchao.sparsity.marlin import pack_to_marlin_24, const  # avoid circular import
+        assert isinstance(_layout, MarlinSparseLayout)
+
+        # Linear layers are (in_features, out_features) but the int_data that is reaching this point
+        # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
+        q_w_24 = int_data.t()
+        scale_t = scale.t()
+
+        if not torch.cuda.get_device_capability()[0] >= 8:
+            raise ValueError(
+                f'Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel.'
+            )
+
+        if q_w_24.dtype != torch.int32:
+            raise ValueError("Only `torch.int32` weights are supported.")
+
+        in_features, out_features = q_w_24.shape
+        if in_features % 128 != 0 or out_features % 256 != 0:
+            raise ValueError(
+                "`in_features` must be divisible by 128 and `out_features` by 256."
+            )
+
+        # NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8
+        # will require a bit more work to get our current quantization flow to work with it.
+        # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
+        num_bits = 4 if torch.max(q_w_24) < 16 else -1
+        if num_bits not in [4]:
+            raise ValueError(
+                f"Only {[4]} bits are supported, got {num_bits}."
+            )
+
+        group_size = in_features // scale_t.shape[0]
+        if group_size == 0:
+            group_size = in_features
+        assert group_size <= in_features, "Group size must be less than or equal to in_features."
+
+        if group_size not in const.SUPPORTED_GROUP_SIZES:
+            raise ValueError(
+                f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}."
+ ) + + # Compress quantized weight to marlin 2:4 format + marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(q_w_24, scale_t, num_bits, group_size) + + return cls( + marlin_24_q_w_comp, marlin_24_s, zero_point, + meta, _layout, q_w_24.shape, + group_size, num_bits + ) + + def get_layout(self) -> Layout: + return self._layout + + def _apply_fn_to_data(self, fn): + self.int_data = fn(self.int_data) + self.scale = fn(self.scale) + self.zero_point = fn(self.zero_point) + self.meta = fn(self.meta) + return self diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py new file mode 100644 index 0000000000..31252701b5 --- /dev/null +++ b/torchao/dtypes/uintx/semi_sparse_layout.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass +from torchao.dtypes.utils import Layout, PlainLayout +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, + PlainAQTTensorImpl +) +import torch +from typing import Optional +from torchao.dtypes.uintx.uint8 import _aqt_is_int8_reduced_range + +def _linear_int8_act_int8_weight_semi_structured_sparse_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) and + _aqt_is_int8_reduced_range(input_tensor) and + isinstance(weight_tensor, AffineQuantizedTensor) and + weight_tensor.is_cuda and + input_tensor.dtype == weight_tensor.dtype and + isinstance(input_tensor._layout, PlainLayout) and + isinstance(weight_tensor._layout, SemiSparseLayout) + ) + + +def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias): + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals_int8 = weight_tensor.tensor_impl.int_data + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm + y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( + w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, + ).t() + y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( + *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] + ) + output_dtype = input_tensor.dtype + # TODO: waiting for jesse's test/fix + y = y.to(output_dtype).contiguous() + if bias is not None: + y += bias + return y + +@dataclass(frozen=True) +class SemiSparseLayout(Layout): + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + # prune to 2:4 if not already + temp = input.detach() + pruning_inds = temp.abs().view(-1, 4).argsort(dim=1)[:, :2] + temp.view(-1, 4).scatter_(1, pruning_inds, value=0) + return temp + + + +@register_layout(SemiSparseLayout) +class SemiSparseAQTTensorImpl(PlainAQTTensorImpl): + """ + TensorImpl for semi_sparse_cusparselt layout for affine quantized tensor + """ + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"SparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def get_plain(self): + # Currently we don't have cuSPARSELt expansion routines, so we matmul by + # the identity matrix to get the original dense matrix. This is slow though. 
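+        # Note on the arithmetic below: the 16/10 factor assumes the cuSPARSELt-compressed
+        # int8 buffer is 10/16 of the dense byte size (8/16 for the two kept values out of
+        # every four, plus 2/16 for the 2:4 metadata), so the dense column count can be
+        # recovered from the compressed numel and the row count (scale.shape[0]).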
+ cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) + int_data_expanded = torch._cslt_sparse_mm(self.int_data, + torch.eye(cols, + dtype=self.int_data.dtype, + device=self.int_data.device).t()) + return int_data_expanded, self.scale, self.zero_point + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, SemiSparseLayout) + int_data_compressed = torch._cslt_compress(int_data) + return cls(int_data_compressed, scale, zero_point, _layout) diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py new file mode 100644 index 0000000000..1f6bb92179 --- /dev/null +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -0,0 +1,323 @@ +import torch +from torchao.utils import find_multiple, TORCH_VERSION_AT_LEAST_2_5 +from torchao.dtypes.utils import Layout, AQTTensorImpl +from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor, register_layout +from dataclasses import dataclass +from typing import Optional, Tuple +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, + is_traceable_wrapper_subclass, +) +from torchao.quantization.quant_primitives import ( + ZeroPointDomain +) + +aten = torch.ops.aten + +def _aqt_is_tensor_core_tile_uint4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + # TODO: use torch.uint4 + return ( + aqt.tensor_impl.dtype == torch.int32 and + aqt.quant_min == 0 and + aqt.quant_max == 15 + ) + +def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): + return ( + # input is native bfloat16 tensor + not is_traceable_wrapper_subclass(input_tensor) and + input_tensor.dtype == torch.bfloat16 and + # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor + isinstance(weight_tensor, AffineQuantizedTensor) and + _aqt_is_tensor_core_tile_uint4(weight_tensor) and + weight_tensor.dtype == torch.bfloat16 and + len(weight_tensor.shape) == 2 and + weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and + isinstance(weight_tensor._layout, TensorCoreTiledLayout) + ) + + +def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): + assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + # TODO: check groupsize quantization + # avoid circular dep, TODO: move this to a common util.py + act_mat = input_tensor + # weight is packed from padded (out_features, in_features) weight tensor + # (same dimension requirement as F.linear weight) + packed_weight = weight_tensor.tensor_impl.packed_weight + scale_and_zero = weight_tensor.tensor_impl.scale_and_zero + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape and pad activation + act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16) + pad_size = find_multiple(act_mat.shape[-1], 1024) + act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1])) + + # groupwise int4 quantization + groupsize = weight_tensor.block_size[1] + y = torch.ops.aten._weight_int4pack_mm(act_mat.contiguous(), packed_weight, groupsize, scale_and_zero) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, 
:orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + + if bias is not None: + y += bias + return y.to(orig_dtype) + +@dataclass(frozen=True) +class TensorCoreTiledLayout(Layout): + """ + inner_k_tiles is an internal argument for packing function of tensor core tiled layout + that can affect the performance of the matmul kernel + """ + inner_k_tiles: int = 8 + + def pre_process(self, input: torch.Tensor) -> torch.Tensor: + orig_out_features, orig_in_features = input.shape + in_features = find_multiple(orig_in_features, 1024) + out_features = find_multiple(orig_out_features, 8) + input = torch.nn.functional.pad( + input, + (0, in_features - orig_in_features, 0, out_features - orig_out_features), + ) + return input + + def pre_process_static(self, input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, block_size: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + input = self.pre_process(input) + orig_qparam_shape = scale.shape + new_qparam_shape, reduction_dims = _get_reduction_params(block_size, input.size()) + for dim in reduction_dims: + new_qparam_shape.pop(dim) + change_in_qparam_shape = [new_dim_size-orig_dim_size for new_dim_size, orig_dim_size in zip(new_qparam_shape, orig_qparam_shape)] + padding_changes=[] + for dim_change in change_in_qparam_shape: + padding_changes = [0, dim_change] + padding_changes + scale = torch.nn.functional.pad(scale, padding_changes) + zero_point = torch.nn.functional.pad(zero_point, padding_changes) + return input, scale, zero_point + + def post_process(self, input: torch.Tensor) -> torch.Tensor: + orig_out_features, orig_in_features = input.shape + in_features = find_multiple(orig_in_features, 1024) + out_features = find_multiple(orig_out_features, 8) + input = torch.nn.functional.pad( + input, + (0, in_features - orig_in_features, 0, out_features - orig_out_features), + ) + return input + + def extra_repr(self): + return f"inner_k_tiles={self.inner_k_tiles}" + + +@register_layout(TensorCoreTiledLayout) +class TensorCoreTiledAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only, + used by tinygemm kernels `_weight_int4pack_mm` + + It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 4-d tensor of + dimension: [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2] + (unpacked Tensor shape is n * k) + where inner_k_tiles is an internal argument for packing function of tensor core tiled layout + that can affect the performance of the matmul kernel (defaults to 8) + + Note: we also pack scale and zero point together here for tinygemm kernel + + Note: technically tensor core tiled layout should be the layout for the underlying packed weight + (int Tensor) but since the scale and zero_point are also packed into the same tensor here which is not used + in plain layout, we just created a layout for AQT right now, this could be improved if we split out + int4 aqt into a separate tensor subclass + + fields: + packed_weight (torch.Tensor): the 4-d packed tensor in a tensor_core_tiled layout + scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point tensor to quantized tensor and zero_point Tensor + """ + + def __new__( + cls, + packed_weight: torch.Tensor, + scale_and_zero: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else 
packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scale_and_zero: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scale_and_zero = scale_and_zero + self.transposed = False + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"] + transposed, _layout, = tensor_attributes + return cls(packed_weight, scale_and_zero, transposed, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout + ): + + assert isinstance(_layout, TensorCoreTiledLayout) + + if TORCH_VERSION_AT_LEAST_2_5: + int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) + assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + else: + assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, _layout.inner_k_tiles) + scale = scale.reshape(int_data.shape[0], -1) + zero_point = zero_point.reshape(int_data.shape[0], -1) + from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) + return cls(packed_weight, scale_and_zero, False, _layout) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs["device"] + # tensor core tiled layout supports both cpu and cuda but does not support the conversion + # between these two devices, in the future we should not use the same layout for + # cpu and cuda device: https://github.com/pytorch/ao/issues/1117 + if not is_device(torch.device(self.device).type, device): + raise ValueError(f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}") + return self.__class__( + self.packed_weight.to(device), + self.scale_and_zero.to(device), + self.transposed, + self._layout, + ) + + def _apply_fn_to_data(self, fn): + # self.packed_weight = fn(self.packed_weight) + # self.scale_and_zero = fn(self.scale_and_zero) + # return self + return self.__class__( + fn(self.packed_weight), + fn(self.scale_and_zero), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + transposed = TensorCoreTiledAQTTensorImpl(args[0].packed_weight, args[0].scale_and_zero, not args[0].transposed, args[0]._layout) + return 
return_and_correct_aliasing(func, args, kwargs, transposed) + + if func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + int_data, scale, zero_point = self.get_plain() + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return return_and_correct_aliasing(func, args, kwargs, sliced) + elif dim == 1: + int_data, scale, zero_point = self.get_plain() + assert step == 1, "Only step == 1 is supported in slicing right now" + data_len = int_data.shape[dim] + scale_len = scale.shape[dim] + ratio = data_len / scale_len + start_scale = int(start / ratio) + end_scale = int(end / ratio) + + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + int_data = self._layout.post_process(int_data) + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) + zero_point = aten.slice.Tensor(zero_point, dim, start_scale, end_scale, step) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return sliced + else: + raise NotImplementedError(f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + + raise NotImplementedError( + f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from torchao.quantization.quant_primitives import ( + ZeroPointDomain, + quantize_affine, + ) + from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) + + cur_shape = self.shape + assert len(cur_shape) == 4 + inner_k_tiles = cur_shape[-1] * 2 + original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) + eye_shape = original_shape[1] + groupsize = int(original_shape[1] / scale.shape[-2]) + block_size = (1, groupsize) + device = self.device + original_dtype = torch.bfloat16 + target_dtype = torch.int32 + quant_min = 0 + quant_max = 15 + zero_point_domain = ZeroPointDomain.FLOAT + assert len(block_size) == 2 and block_size[0] == 1 + dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero) + dequantized = dequantized.t().contiguous() + # TODO: move this to `unpack_tinygemm_scales_and_zeros`? 
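+        # unpack_tinygemm_scales_and_zeros returns scale/zero with a trailing singleton
+        # dimension; the reshape below drops it so the qparams broadcast against the
+        # (1, groupsize) block_size passed to quantize_affine when re-quantizing.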
+ scale = scale.reshape(scale.shape[:-1]).contiguous() + zero = zero.reshape(zero.shape[:-1]).contiguous() + int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain) + return int_data, scale, zero + + def get_layout(self) -> Layout: + return self._layout diff --git a/torchao/dtypes/uint4.py b/torchao/dtypes/uintx/uint4.py similarity index 100% rename from torchao/dtypes/uint4.py rename to torchao/dtypes/uintx/uint4.py diff --git a/torchao/dtypes/uintx/uint8.py b/torchao/dtypes/uintx/uint8.py new file mode 100644 index 0000000000..8d53e93e74 --- /dev/null +++ b/torchao/dtypes/uintx/uint8.py @@ -0,0 +1,114 @@ +import torch +from torchao.dtypes.utils import PlainLayout +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, +) +from dataclasses import dataclass +from typing import Optional, Tuple, Union +from torchao.float8.inference import ( + preprocess_data, + Float8MMConfig, + addmm_float8_unwrapped_inference, + _is_rowwise_scaled +) +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, + is_traceable_wrapper_subclass, +) + +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is int8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.int8 and + (aqt.quant_min is None or aqt.quant_min == -128) and + (aqt.quant_max is None or aqt.quant_max == 127) + ) + +def _aqt_is_int8_reduced_range(aqt): + return ( + aqt.tensor_impl.dtype == torch.int8 and + aqt.quant_min == -127 and + (aqt.quant_max is None or aqt.quant_max == 127) + ) + + +def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): + return ( + # input is native float tensor + not is_traceable_wrapper_subclass(input_tensor) and + input_tensor.is_floating_point() and + # weight is int8 per channel quantized affine quantized tensor + isinstance(weight_tensor, AffineQuantizedTensor) and + _aqt_is_int8(weight_tensor) and + len(weight_tensor.shape) == 2 and + len(weight_tensor.block_size) == 2 and + weight_tensor.block_size[0] == 1 and + weight_tensor.block_size[1] == weight_tensor.shape[1] and + weight_tensor.zero_point_domain == ZeroPointDomain.INT and + isinstance(weight_tensor._layout, PlainLayout) + ) + + +def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): + # TODO: enable cpu and mps efficient path + # is_cpu and is_mps only, some issue with is_contiguous() currently + # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) + + # per channel int8 weight only quantizated mm + w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() + scale = weight_tensor.tensor_impl.scale + orig_dtype = input_tensor.dtype + m = torch.mm( + input_tensor.reshape(-1, input_tensor.shape[-1]), + w_vals_int8_t.to(input_tensor.dtype), + ) + y = m * scale.to(m.dtype) + y = y.reshape(*input_tensor.shape[:-1], y.shape[-1]) + if bias is not None: + y += bias.to(m.dtype) + return y + + +def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) and + _aqt_is_int8_reduced_range(input_tensor) and + isinstance(weight_tensor, AffineQuantizedTensor) and + input_tensor.dtype == weight_tensor.dtype and + isinstance(input_tensor._layout, PlainLayout) and + isinstance(weight_tensor._layout, PlainLayout) + ) + +def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): + # + # 1. do the matrix form of dot(X_i, W_j) + # + # + # 2. 
rescale the output + # + # in cases with large matrices, y_dot_int32 can grow sufficiently + # large that y_dot_int32 * a float16 scale is greater than the maximum + # value of a float 16, (which results in a value of inf even if multiplying + # by the other scale would bring it within the expected range) + + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + x_scales_dtype = x_scales.dtype + # Cast fp16 scale to float to avoid overflow in int_scaled_matmul + intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype + y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)) + y_dot_scaled = y_dot_scaled.to(x_scales_dtype) + + y = (y_dot_scaled * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] + ) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y diff --git a/torchao/dtypes/uintx/uint8_layout.py b/torchao/dtypes/uintx/uint8_layout.py new file mode 100644 index 0000000000..6be1bc25ee --- /dev/null +++ b/torchao/dtypes/uintx/uint8_layout.py @@ -0,0 +1,118 @@ +import torch +from torchao.dtypes.utils import PlainLayout +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, +) +from dataclasses import dataclass +from typing import Optional, Tuple, Union +from torchao.float8.inference import ( + preprocess_data, + Float8MMConfig, + addmm_float8_unwrapped_inference, + _is_rowwise_scaled +) +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, + is_traceable_wrapper_subclass, +) +from torchao.quantization.quant_primitives import ( + int_scaled_matmul, + ZeroPointDomain, +) + +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is int8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.int8 and + (aqt.quant_min is None or aqt.quant_min == -128) and + (aqt.quant_max is None or aqt.quant_max == 127) + ) + +def _aqt_is_int8_reduced_range(aqt): + return ( + aqt.tensor_impl.dtype == torch.int8 and + aqt.quant_min == -127 and + (aqt.quant_max is None or aqt.quant_max == 127) + ) + + +def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): + return ( + # input is native float tensor + not is_traceable_wrapper_subclass(input_tensor) and + input_tensor.is_floating_point() and + # weight is int8 per channel quantized affine quantized tensor + isinstance(weight_tensor, AffineQuantizedTensor) and + _aqt_is_int8(weight_tensor) and + len(weight_tensor.shape) == 2 and + len(weight_tensor.block_size) == 2 and + weight_tensor.block_size[0] == 1 and + weight_tensor.block_size[1] == weight_tensor.shape[1] and + weight_tensor.zero_point_domain == ZeroPointDomain.INT and + isinstance(weight_tensor._layout, PlainLayout) + ) + + +def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): + # TODO: enable cpu and mps efficient path + # is_cpu and is_mps only, some issue with is_contiguous() currently + # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) + + # per channel int8 weight only quantizated mm + w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() + scale = weight_tensor.tensor_impl.scale + orig_dtype = input_tensor.dtype + m = torch.mm( + 
input_tensor.reshape(-1, input_tensor.shape[-1]), + w_vals_int8_t.to(input_tensor.dtype), + ) + y = m * scale.to(m.dtype) + y = y.reshape(*input_tensor.shape[:-1], y.shape[-1]) + if bias is not None: + y += bias.to(m.dtype) + return y + + +def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) and + _aqt_is_int8_reduced_range(input_tensor) and + isinstance(weight_tensor, AffineQuantizedTensor) and + input_tensor.dtype == weight_tensor.dtype and + isinstance(input_tensor._layout, PlainLayout) and + isinstance(weight_tensor._layout, PlainLayout) + ) + +def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): + # + # 1. do the matrix form of dot(X_i, W_j) + # + # + # 2. rescale the output + # + # in cases with large matrices, y_dot_int32 can grow sufficiently + # large that y_dot_int32 * a float16 scale is greater than the maximum + # value of a float 16, (which results in a value of inf even if multiplying + # by the other scale would bring it within the expected range) + + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + x_scales_dtype = x_scales.dtype + # Cast fp16 scale to float to avoid overflow in int_scaled_matmul + intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype + y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)) + y_dot_scaled = y_dot_scaled.to(x_scales_dtype) + + y = (y_dot_scaled * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] + ) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y diff --git a/torchao/dtypes/uintx/uintx.py b/torchao/dtypes/uintx/uintx.py index a48faee8dc..4cfcd6696b 100644 --- a/torchao/dtypes/uintx/uintx.py +++ b/torchao/dtypes/uintx/uintx.py @@ -8,7 +8,10 @@ Layout, ) from torchao.utils import TorchAOBaseTensor -from torchao.dtypes.affine_quantized_tensor import PlainAQTTensorImpl, register_layout +from torchao.dtypes.affine_quantized_tensor import ( + register_layout, + PlainAQTTensorImpl +) from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 aten = torch.ops.aten diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 7c0dfd9dc8..01725f7805 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -1,6 +1,7 @@ import torch from typing import Union, Tuple from dataclasses import dataclass +from torchao.utils import TorchAOBaseTensor """ Base class for different layout, following the same design of PyTorch layout @@ -58,3 +59,42 @@ def get_out_shape(input_shape: Tuple[int], weight_shape: Tuple[int]) -> Tuple[in out_dim = weight_shape[0] inpt_dims = input_shape[:-1] return (*inpt_dims, out_dim) + + +############################### +# Base Tensor Impl Subclass # +############################### +class AQTTensorImpl(TorchAOBaseTensor): + """ + Base class for the tensor impl for `AffineQuantizedTensor` + + Note: This is not a user facing API, it's used by AffineQuantizedTensor to construct + the underlying implementation of a AQT based on layout + """ + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Get the plain (unpacked) Tensor for the tensor impl + + Returns data, scale and zero_point + Can be 
overwritten if other types of AQTTensorImpl has different numbers of plain tensors + """ + pass + + def get_layout(self) -> Layout: + pass + + @classmethod + def from_plain( + cls, + data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + """ Construct a TensorImpl from data, scale, zero_point and the _layout""" + pass + + def __repr__(self): + data, scale, zero_point = self.get_plain() + _layout = self.get_layout() + return f"{self.__class__.__name__}(data={str(data)}... , scale={str(scale)}... , zero_point={str(zero_point)}... , _layout={_layout})" + diff --git a/torchao/sparsity/utils.py b/torchao/sparsity/utils.py index 4b5164863f..011042828e 100644 --- a/torchao/sparsity/utils.py +++ b/torchao/sparsity/utils.py @@ -120,3 +120,4 @@ def mask_creator( mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) return mask + From 0cfd2b9a003d232f445f21aeb666ca3fb80fae20 Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 5 Nov 2024 18:06:08 -0800 Subject: [PATCH 02/11] Ruff format and lint (#1226) Ruff format and lint on some high traffic files --- ruff.toml | 4 +- torchao/dtypes/__init__.py | 3 +- torchao/dtypes/affine_quantized_tensor.py | 150 ++++++++++++--- .../floatx/floatx_tensor_core_layout.py | 176 ++++++++++++++---- torchao/dtypes/uintx/bitpacking.py | 116 ++++++++---- torchao/dtypes/uintx/uint4.py | 3 +- torchao/dtypes/uintx/uintx.py | 54 ++++-- torchao/dtypes/utils.py | 18 +- 8 files changed, 393 insertions(+), 131 deletions(-) diff --git a/ruff.toml b/ruff.toml index 1a4a5ff097..773497eb5c 100644 --- a/ruff.toml +++ b/ruff.toml @@ -11,6 +11,6 @@ include = [ "torchao/quantization/linear_activation_weight_observer.py", "test/quantization/test_observer.py", "test/dtypes/test_affine_quantized_float.py", - "torchao/quantization/weight_tensor_linear_activation_quantization.py" - + "torchao/quantization/weight_tensor_linear_activation_quantization.py", + "torchao/dtypes/**/*.py", ] diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 032948cd51..d7cd517650 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -14,6 +14,7 @@ from .affine_quantized_tensor_ops import * from .utils import ( Layout, + MarlinSparseLayout, PlainLayout, ) from .floatx import ( @@ -28,7 +29,7 @@ __all__ = [ "NF4Tensor", "to_nf4", - "UInt4Tensor" + "UInt4Tensor", "AffineQuantizedTensor", "to_affine_quantized_intx", "to_affine_quantized_intx_static", diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 31f775e92e..cb2527076d 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -116,9 +116,16 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor output_dtype = self.dtype from torchao.dtypes.floatx import FloatxTensorCoreLayout + if isinstance(self._layout, FloatxTensorCoreLayout): int_data, scale = self.tensor_impl.get_plain() - return dequantize_affine_floatx(int_data, scale, self._layout.ebits, self._layout.mbits, output_dtype=output_dtype) + return dequantize_affine_floatx( + int_data, + scale, + self._layout.ebits, + self._layout.mbits, + output_dtype=output_dtype, + ) else: data, scale, zero_point = self.tensor_impl.get_plain() dq = dequantize_affine( @@ -139,14 +146,23 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor return dq def __tensor_flatten__(self): - return ["tensor_impl"], 
[self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] + return ["tensor_impl"], [ + self.block_size, + self.shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + self.dtype, + ] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): tensor_impl = tensor_data_dict["tensor_impl"] - block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes + block_size, shape, quant_min, quant_max, zero_point_domain, dtype = ( + tensor_attributes + ) return cls( tensor_impl, block_size, @@ -179,20 +195,58 @@ def from_hp_to_intx( input_float = _layout.pre_process(input_float) if use_hqq: - assert zero_point_domain == ZeroPointDomain.FLOAT and mapping_type == MappingType.ASYMMETRIC and quant_min==0, "Invalid input parameters for HQQ quantization." + assert ( + zero_point_domain == ZeroPointDomain.FLOAT + and mapping_type == MappingType.ASYMMETRIC + and quant_min == 0 + ), "Invalid input parameters for HQQ quantization." nbits = int(math.log2(quant_max + 1)) - axis = 1 if (block_size[0]==1) else 0 + axis = 1 if (block_size[0] == 1) else 0 group_size = max(block_size) - compute_dtype = zero_point_dtype if (zero_point_dtype is not None) else input_float.dtype + compute_dtype = ( + zero_point_dtype + if (zero_point_dtype is not None) + else input_float.dtype + ) device = input_float.device - data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False) + data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq( + input_float, + nbits=nbits, + group_size=group_size, + axis=axis, + compute_dtype=compute_dtype, + device=device, + verbose=False, + raw_output=False, + ) data = data.to(target_dtype) else: - scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) + scale, zero_point = choose_qparams_affine( + input_float, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) # choose_qparams_affine is a custom op that does support returning optional Tensors. 
We thus set the zero_point to None if its domain is None if zero_point_domain is None: zero_point = None - data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) + data = quantize_affine( + input_float, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) # Note: output will be uint8 tensor for sub byte tensors for now data = _layout.post_process(data) @@ -205,7 +259,7 @@ def from_hp_to_intx( quant_min, quant_max, zero_point_domain, - dtype=input_float.dtype + dtype=input_float.dtype, ) @classmethod @@ -222,12 +276,27 @@ def from_hp_to_intx_static( _layout: Layout = PlainLayout(), ): if target_dtype not in FP8_TYPES: - assert zero_point_domain is not None, "zero_point_domain must be specified for non-fp8 types" - assert zero_point is not None, "zero_point must be specified for non-fp8 types" + assert ( + zero_point_domain is not None + ), "zero_point_domain must be specified for non-fp8 types" + assert ( + zero_point is not None + ), "zero_point must be specified for non-fp8 types" original_shape = input_float.shape - input_float, scale, zero_point = _layout.pre_process_static(input_float, scale, zero_point, block_size) + input_float, scale, zero_point = _layout.pre_process_static( + input_float, scale, zero_point, block_size + ) - int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) + int_data = quantize_affine( + input_float, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) int_data = _layout.post_process(int_data) @@ -252,7 +321,6 @@ def from_hp_to_floatx( _layout: Layout, scale_dtype: Optional[torch.dtype] = None, ): - if target_dtype in FP8_TYPES: return cls.from_hp_to_intx( input_float=input_float, @@ -270,7 +338,9 @@ def from_hp_to_floatx( use_hqq=False, ) else: - raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx") + raise NotImplementedError( + f"Unsupported dtype {target_dtype} for from_hp_to_floatx" + ) @classmethod def from_hp_to_floatx_static( @@ -281,7 +351,6 @@ def from_hp_to_floatx_static( target_dtype: torch.dtype, _layout: Layout, ): - if target_dtype in FP8_TYPES: return cls.from_hp_to_intx_static( input_float=input_float, @@ -295,7 +364,9 @@ def from_hp_to_floatx_static( _layout=_layout, ) else: - raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static") + raise NotImplementedError( + f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static" + ) @classmethod def from_hp_to_fpx( @@ -304,7 +375,10 @@ def from_hp_to_fpx( _layout: Layout, ): from torchao.dtypes.floatx import FloatxTensorCoreLayout - assert isinstance(_layout, FloatxTensorCoreLayout), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" + + assert isinstance( + _layout, FloatxTensorCoreLayout + ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" original_shape = input_float.shape input_float = _layout.pre_process(input_float) # per axis quantization, where axis = 1 @@ -319,12 +393,7 @@ def from_hp_to_fpx( tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) - return cls( - tensor_impl, - block_size, - original_shape, - dtype=input_float.dtype - ) + return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype) @property def _layout(self) -> Layout: @@ -388,6 +457,7 @@ 
class PlainAQTTensorImpl(AQTTensorImpl): scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor """ + def __new__( cls, int_data: torch.Tensor, @@ -424,8 +494,12 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] - _layout, = tensor_attributes + int_data, scale, zero_point = ( + tensor_data_dict["int_data"], + tensor_data_dict["scale"], + tensor_data_dict["zero_point"], + ) + (_layout,) = tensor_attributes return cls(int_data, scale, zero_point, _layout) def to(self, *args, **kwargs): @@ -470,13 +544,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs): self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: aten.slice.Tensor(x, dim, start, end, step) + ), ) elif dim == 1: - assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return PlainAQTTensorImpl(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self._layout) + assert ( + len(self.scale.shape) == 1 + ), f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" + return PlainAQTTensorImpl( + aten.slice.Tensor(self.int_data, dim, start, end, step), + self.scale.view(-1), + self.zero_point.view(-1), + self._layout, + ) else: - raise NotImplementedError(f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError( + f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) raise NotImplementedError( f"PlainAQTTensorImpl dispatch: attempting to run {func}, this is not supported" diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index cfdb566279..e801182559 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass from functools import reduce -from typing import Tuple, Optional +from typing import Optional, Tuple import torch from torch import Tensor @@ -24,11 +25,23 @@ def _pack(x: Tensor, n_bits: int) -> Tensor: - return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)]) + return reduce( + torch.bitwise_or, + [ + x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits) + for i in range(8 // n_bits) + ], + ) def _unpack(x: Tensor, n_bits: int) -> Tensor: - return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2) + return torch.stack( + [ + (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) + for i in range(8 // n_bits) + ], + dim=-1, + ).flatten(-2) # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 @@ -42,8 +55,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: if not undo: bit_order = 
{ - 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, - 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30], + 1: [ + 1, + 5, + 9, + 13, + 17, + 21, + 25, + 29, + 3, + 7, + 11, + 15, + 19, + 23, + 27, + 31, + 0, + 4, + 8, + 12, + 16, + 20, + 24, + 28, + 2, + 6, + 10, + 14, + 18, + 22, + 26, + 30, + ], 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], 4: [1, 5, 3, 7, 0, 4, 2, 6], }[n_bits] @@ -52,8 +97,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: # this is inverse of the above, obtained by running # [v.index(i) for i in range(len(v))] bit_order = { - 1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11, - 20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15], + 1: [ + 16, + 0, + 24, + 8, + 17, + 1, + 25, + 9, + 18, + 2, + 26, + 10, + 19, + 3, + 27, + 11, + 20, + 4, + 28, + 12, + 21, + 5, + 29, + 13, + 22, + 6, + 30, + 14, + 23, + 7, + 31, + 15, + ], 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], 4: [4, 0, 6, 2, 5, 1, 7, 3], }[n_bits] @@ -89,8 +166,12 @@ def _pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask tensor_ybit = _pack(tensor_ybit, y) - tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code - tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code + tensor_ybit = ( + tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) + ) # Pass 2 from original code + tensor_ybit = _bit_interleave( + tensor_ybit.flatten(), y + ) # Pass 3 from original code fragments.append(tensor_ybit) used_bits += y @@ -124,7 +205,9 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: return _pack_tc_floatx(tensor, nbits) -def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]: +def to_scaled_tc_floatx( + tensor: Tensor, ebits: int, mbits: int +) -> Tuple[Tensor, Tensor]: # _n_ones() is not compatible with torch.compile() due to << operator # https://github.com/pytorch/pytorch/issues/119152 # exp_bias = _n_ones(ebits - 1) @@ -132,7 +215,9 @@ def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, # workaround: global lookup table exp_bias = _ONES_TABLE[ebits - 1] - max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) + max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * ( + _ONES_TABLE[mbits + 1] / (2**mbits) + ) dtype = tensor.dtype tensor = tensor.float() @@ -159,8 +244,10 @@ def _unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = tensor[offset : offset + size_ybit] offset += size_ybit - tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 - tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2 + tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 + tensor_ybit = ( + tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) + ) # undo Pass 2 tensor_ybit = _unpack(tensor_ybit.flatten(), y) tensor_ybit = tensor_ybit << (nbits - used_bits - y) @@ -231,7 +318,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 7, 28672: 7, - 57344: 7 + 57344: 7, }, { # tokens: [65:128] 3072: 9, @@ -242,7 +329,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 7, 28672: 7, - 57344: 6 + 57344: 6, }, { # tokens: [129:192] 3072: 6, @@ -253,7 +340,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 
5, 14336: 5, 28672: 5, - 57344: 4 + 57344: 4, }, { # tokens: [193:256] 3072: 9, @@ -264,7 +351,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 4, 14336: 8, 28672: 6, - 57344: 4 + 57344: 4, }, { # tokens: [257:320] 3072: 7, @@ -275,7 +362,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 3, 28672: 3, - 57344: 4 + 57344: 4, }, { # tokens: [321:384] 3072: 3, @@ -286,7 +373,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 8, 14336: 3, 28672: 4, - 57344: 3 + 57344: 3, }, { # tokens: [385:448] 3072: 5, @@ -297,7 +384,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 3, 14336: 1, 28672: 1, - 57344: 3 + 57344: 3, }, { # tokens: [449:512] 3072: 2, @@ -308,7 +395,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 2, 14336: 6, 28672: 4, - 57344: 1 + 57344: 1, }, { # tokens: [513:576] 3072: 2, @@ -319,7 +406,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 3, 14336: 3, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [577:640] 3072: 5, @@ -330,7 +417,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 1, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [641:704] 3072: 3, @@ -341,7 +428,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 2, 14336: 1, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [705:768] 3072: 3, @@ -352,17 +439,18 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 1, 28672: 1, - 57344: 1 - } + 57344: 1, + }, ] # quantization api integrations + @dataclass(frozen=True) class FloatxTensorCoreLayout(Layout): - """Layout type for FloatxTensorCoreAQTTensorImpl - """ + """Layout type for FloatxTensorCoreAQTTensorImpl""" + ebits: int mbits: int @@ -390,6 +478,7 @@ class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) """ + def __new__( cls, packed_floatx_data: torch.Tensor, @@ -398,11 +487,16 @@ def __new__( ): assert packed_floatx_data.ndim == 2 assert packed_floatx_data.dtype == torch.uint8 - shape = (packed_floatx_data.shape[0], packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8) + shape = ( + packed_floatx_data.shape[0], + packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8, + ) kwargs = {} kwargs["device"] = packed_floatx_data.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else packed_floatx_data.layout + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_floatx_data.layout ) kwargs["dtype"] = packed_floatx_data.dtype kwargs["requires_grad"] = False @@ -425,12 +519,17 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_floatx_data, scale = tensor_data_dict["packed_floatx_data"], tensor_data_dict["scale"] - _layout, = tensor_attributes + packed_floatx_data, scale = ( + tensor_data_dict["packed_floatx_data"], + tensor_data_dict["scale"], + ) + (_layout,) = tensor_attributes return cls(packed_floatx_data, scale, _layout) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: - unpacked_floatx_data = 
unpack_tc_floatx(self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits) + unpacked_floatx_data = unpack_tc_floatx( + self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits + ) return unpacked_floatx_data, self.scale @classmethod @@ -449,7 +548,9 @@ def from_plain( bit, M is mantissa bit """ assert isinstance(_layout, FloatxTensorCoreLayout) - packed_floatx_data = pack_tc_floatx(unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits) + packed_floatx_data = pack_tc_floatx( + unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits + ) return cls(packed_floatx_data, scale, _layout) def __repr__(self): @@ -487,7 +588,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif func is aten._to_copy.default: return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))), + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: x.to(device=kwargs.pop("device", None)) + ), ) raise NotImplementedError( diff --git a/torchao/dtypes/uintx/bitpacking.py b/torchao/dtypes/uintx/bitpacking.py index 244ca437ef..fdff6ef9fb 100644 --- a/torchao/dtypes/uintx/bitpacking.py +++ b/torchao/dtypes/uintx/bitpacking.py @@ -1,22 +1,23 @@ -import torch -from typing import Optional, List from functools import reduce +from typing import List, Optional + +import torch # for selecting the shards from 8 bits maskbits = { 1: (0x01,), 2: (0x03,), 3: (0x03, 0x04), - 4: (0x0f,), - 5: (0x0f, 0x10), - 6: (0x0f, 0x30), - 7: (0x0f, 0x30, 0x40), + 4: (0x0F,), + 5: (0x0F, 0x10), + 6: (0x0F, 0x30), + 7: (0x0F, 0x30, 0x40), } unpack_mask = { - 1: (0x01,0x02,0x04,0x08, 0x10,0x20,0x40,0x80), - 2: (0x03,0x0c,0x30,0xc0), - 4: (0x0f,0xf0), + 1: (0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80), + 2: (0x03, 0x0C, 0x30, 0xC0), + 4: (0x0F, 0xF0), } # size of each shard @@ -41,6 +42,7 @@ 7: (0, 4, 6), } + # for shifting groups left but right if shift is negative def abs_lsh(data, shift): if shift == 0: @@ -61,9 +63,9 @@ def abs_rsh(data, shift): return data >> shift -def pack_cpu(data: torch.Tensor, - elem_size: int, - dim: Optional[int] = -1) -> List[torch.Tensor]: +def pack_cpu( + data: torch.Tensor, elem_size: int, dim: Optional[int] = -1 +) -> List[torch.Tensor]: """ Inputs: data: a tensor of sub byte elements in uint8 @@ -111,7 +113,10 @@ def pack_cpu(data: torch.Tensor, After pack, data went from 8 elements to 6: [[0, 105, 151, 37], [39, 146]] In general this means pack reduces input tensor size from n * 8 to n * elem_size """ - torch._assert(data.shape[dim] % 8 == 0, f"pack dimension size ({data.shape[dim]}) is not divisble by scale") + torch._assert( + data.shape[dim] % 8 == 0, + f"pack dimension size ({data.shape[dim]}) is not divisble by scale", + ) torch._assert(data.dtype == torch.uint8, "data must be uint8") output_shape = list(data.shape) @@ -131,9 +136,9 @@ def pack_cpu(data: torch.Tensor, return output -def unpack_cpu(data: List[torch.Tensor], - elem_size: int, - dim: Optional[int] = -1) -> torch.Tensor: +def unpack_cpu( + data: List[torch.Tensor], elem_size: int, dim: Optional[int] = -1 +) -> torch.Tensor: """ Unpacks small dtype elements from a larger dtype. 
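# Illustrative example, per the maskbits/shifts tables above: for elem_size=3 each
# uint8 value x is split into a 2-bit shard (x & 0x03) and a 1-bit shard
# ((x & 0x04) >> 2), and each shard is bit-packed 8 // n_bits values per byte, so
# 8 three-bit elements occupy 2 + 1 = 3 bytes. unpack_cpu below reverses this: it
# masks each packed byte with unpack_mask, shifts the pieces back into place with
# abs_rsh, and ORs the shards into the output to rebuild the original values.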
@@ -159,31 +164,37 @@ def unpack_cpu(data: List[torch.Tensor], for j in range(scale): output_narrow = output.narrow(dim, j * group_size, group_size) group = data[i] & unpack_mask[bit_size][j] - shift_amt = j * bit_size - rel_pos - output_narrow.copy_(torch.bitwise_or(output_narrow, abs_rsh(group, j * bit_size - rel_pos))) + output_narrow.copy_( + torch.bitwise_or(output_narrow, abs_rsh(group, j * bit_size - rel_pos)) + ) return output + # these are faster on the GPU + def _pack(data, elem_size, scale, dim): - ''' + """ Inner for loop from above pack function - ''' + """ packed_shape = list(data.shape) packed_shape[dim] = packed_shape[dim] // scale packed = torch.zeros(packed_shape, dtype=data.dtype, device=data.device) for i in range(scale): - narrow_slice = data.narrow(dim, data.shape[dim]*i//scale, data.shape[dim] // scale) + narrow_slice = data.narrow( + dim, data.shape[dim] * i // scale, data.shape[dim] // scale + ) packed |= narrow_slice << (elem_size * i) return packed + def _unpack(data, element_size, scale, dim): - ''' + """ Inner for loop from above unpack function - ''' + """ unpacked_shape = list(data.shape) unpacked_shape[dim] *= scale @@ -193,30 +204,57 @@ def _unpack(data, element_size, scale, dim): for i in range(scale): shift_amt = element_size * i - chunk = unpacked_data.narrow(dim, unpacked_data.shape[dim]*i//scale, unpacked_data.shape[dim] // scale).copy_((data >> shift_amt) & nbits) + unpacked_data.narrow( + dim, + unpacked_data.shape[dim] * i // scale, + unpacked_data.shape[dim] // scale, + ).copy_((data >> shift_amt) & nbits) return unpacked_data -def pack(data: torch.Tensor, - elem_size: int, - dim: Optional[int] = -1) -> List[torch.Tensor]: - ''' +def pack( + data: torch.Tensor, elem_size: int, dim: Optional[int] = -1 +) -> List[torch.Tensor]: + """ a less branching but more compute version so better for gpu - ''' - torch._assert(data.shape[dim] % 8 == 0, f"pack dimension size ({data.shape[dim]}) is not divisble by scale") + """ + torch._assert( + data.shape[dim] % 8 == 0, + f"pack dimension size ({data.shape[dim]}) is not divisble by scale", + ) torch._assert(data.dtype == torch.uint8, "data must be uint8") container_size = 8 - shards = [(data & maskbits[elem_size][i]) >> shifts[elem_size][i] for i in range(len(maskbits[elem_size]))] - return tuple([_pack(shards[i], numbits[elem_size][i], container_size//numbits[elem_size][i], dim) for i in range(len(maskbits[elem_size]))]) - -def unpack(data: List[torch.Tensor], - elem_size: int, - dim: Optional[int] = 0) -> torch.Tensor: - ''' + shards = [ + (data & maskbits[elem_size][i]) >> shifts[elem_size][i] + for i in range(len(maskbits[elem_size])) + ] + return tuple( + [ + _pack( + shards[i], + numbits[elem_size][i], + container_size // numbits[elem_size][i], + dim, + ) + for i in range(len(maskbits[elem_size])) + ] + ) + + +def unpack( + data: List[torch.Tensor], elem_size: int, dim: Optional[int] = 0 +) -> torch.Tensor: + """ a less branching but more compute version so better for gpu - ''' + """ container_size = 8 # unpack each 4,2,1 bit shard and unshift them back to the correct position - data = [_unpack(data[i], numbits[elem_size][i], container_size // numbits[elem_size][i], dim) << shifts[elem_size][i] for i in range(len(data))] + data = [ + _unpack( + data[i], numbits[elem_size][i], container_size // numbits[elem_size][i], dim + ) + << shifts[elem_size][i] + for i in range(len(data)) + ] return reduce(torch.bitwise_or, data) diff --git a/torchao/dtypes/uintx/uint4.py b/torchao/dtypes/uintx/uint4.py index 
fc6eb2646c..204aefcf3c 100644 --- a/torchao/dtypes/uintx/uint4.py +++ b/torchao/dtypes/uintx/uint4.py @@ -1,7 +1,7 @@ import torch import torch._prims_common as utils import torch.utils._pytree as pytree -from torch.library import impl, Library +from torch.library import Library, impl def down_size(size): @@ -105,7 +105,6 @@ def __new__(cls, elem, **kwargs): ) def __init__(self, elem, **kwargs): - self.elem = elem @classmethod diff --git a/torchao/dtypes/uintx/uintx.py b/torchao/dtypes/uintx/uintx.py index 4cfcd6696b..b47862a7e1 100644 --- a/torchao/dtypes/uintx/uintx.py +++ b/torchao/dtypes/uintx/uintx.py @@ -1,9 +1,10 @@ -from typing import Tuple, List from dataclasses import dataclass -import torch +from typing import List, Tuple +import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from .bitpacking import pack, unpack + +from torchao.dtypes.affine_quantized_tensor import PlainAQTTensorImpl, register_layout from torchao.dtypes.utils import ( Layout, ) @@ -46,6 +47,7 @@ class UintxTensor(TorchAOBaseTensor): bit_width (int): number of bits for each element pack_dim: (int) dimension to pack along """ + bits_to_shard = { 1: ["int1_shard"], 2: ["int2_shard"], @@ -55,6 +57,7 @@ class UintxTensor(TorchAOBaseTensor): 6: ["int4_shard", "int2_shard"], 7: ["int4_shard", "int2_shard", "int1_shard"], } + def __new__( cls, shards: List[torch.Tensor], @@ -84,24 +87,28 @@ def __init__( self.pack_dim = pack_dim def get_shards(self): - return [getattr(self,i) for i in self.__class__.bits_to_shard[self.bit_width]] + return [getattr(self, i) for i in self.__class__.bits_to_shard[self.bit_width]] def __repr__(self): return f"Int{self.bit_width}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_width, dim = self.pack_dim)})" def __tensor_flatten__(self): - return self.__class__.bits_to_shard[self.bit_width], [self.packed_shape, self.bit_width, self.pack_dim] + return self.__class__.bits_to_shard[self.bit_width], [ + self.packed_shape, + self.bit_width, + self.pack_dim, + ] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - shards = list(tensor_data_dict.values()) + shards = list(tensor_data_dict.values()) packed_shape, bit_width, pack_dim = tensor_attributes return cls(shards, packed_shape, bit_width, pack_dim) def get_plain(self): - return unpack(self.get_shards(), self.bit_width, dim = self.pack_dim) + return unpack(self.get_shards(), self.bit_width, dim=self.pack_dim) # temporary until kernels on packed tensors are created def apply_transformation(self, fn): @@ -113,18 +120,21 @@ def apply_transformation(self, fn): # temporary until kernels on packed tensors are created def apply_fn_to_shards(self, fn): new_shards = [fn(shard) for shard in self.get_shards()] - return self.__class__(new_shards, self.packed_shape, self.bit_width, self.pack_dim) + return self.__class__( + new_shards, self.packed_shape, self.bit_width, self.pack_dim + ) @classmethod def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1): - assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), "Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}" + assert ( + dtype in _DTYPE_TO_BIT_WIDTH.keys() + ), "Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}" bit_width = _DTYPE_TO_BIT_WIDTH[dtype] shards = pack(int_data, bit_width, dim=pack_dim) shape = list(int_data.shape) shape[pack_dim] = shape[pack_dim] * bit_width // 8 return cls(shards, int_data.shape, bit_width, pack_dim) - def _get_to_kwargs(self, 
*args, **kwargs): device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) device = self.device if device is None else device @@ -153,42 +163,52 @@ def to(self, *args, **kwargs): return super().to(*args, **kwargs) - implements = UintxTensor.implements + @implements(aten.detach.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0].apply_fn_to_shards(torch.detach) ) + @implements(aten.view.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0].apply_transformation(lambda x: x.view(*args[1:])) ) + @implements(aten._to_copy.default) def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0] - ) + return return_and_correct_aliasing(func, args, kwargs, args[0]) + @implements(aten.sub.Tensor) def _(func, types, args, kwargs): return return_and_correct_aliasing( - func, args, kwargs, args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)) + func, + args, + kwargs, + args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)), ) + @implements(aten.mul.Tensor) def _(func, types, args, kwargs): return return_and_correct_aliasing( - func, args, kwargs, args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)) + func, + args, + kwargs, + args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)), ) + # quantization api integrations to_uintx = UintxTensor.from_uint8 + @dataclass(frozen=True) class UintxLayout(Layout): dtype: torch.dtype @@ -197,9 +217,9 @@ class UintxLayout(Layout): def post_process(self, input: torch.Tensor) -> torch.Tensor: return to_uintx(input, self.dtype, self.pack_dim) + @register_layout(UintxLayout) class UintxAQTTensorImpl(PlainAQTTensorImpl): - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.int_data.get_plain(), self.scale, self.zero_point diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 01725f7805..1704fdb61f 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -1,5 +1,3 @@ -import torch -from typing import Union, Tuple from dataclasses import dataclass from torchao.utils import TorchAOBaseTensor @@ -21,6 +19,8 @@ behaviors when running the same operator, e.g. transpose, quantized_linear. 
This is the same as layout in PyTorch native Tensor """ + + @dataclass(frozen=True) class Layout: def pre_process(self, input: torch.Tensor) -> torch.Tensor: @@ -29,7 +29,13 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: def post_process(self, input: torch.Tensor) -> torch.Tensor: return input - def pre_process_static(self, input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, block_size: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def pre_process_static( + self, + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: Tuple[int, ...], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.pre_process(input), scale, zero_point def __repr__(self): @@ -38,16 +44,21 @@ def __repr__(self): def extra_repr(self) -> str: return "" + """ Plain Layout, the most basic Layout, also has no extra metadata, will typically be the default """ + + @dataclass(frozen=True) class PlainLayout(Layout): pass + def is_device(target_device_str: str, device: Union[str, torch.device]): return torch.device(device).type == target_device_str + def get_out_shape(input_shape: Tuple[int], weight_shape: Tuple[int]) -> Tuple[int, int]: """Returns the unflattened shape of the input tensor. Args: @@ -97,4 +108,3 @@ def __repr__(self): data, scale, zero_point = self.get_plain() _layout = self.get_layout() return f"{self.__class__.__name__}(data={str(data)}... , scale={str(scale)}... , zero_point={str(zero_point)}... , _layout={_layout})" - From 5d16ff183306c242bf66e858950288df9293f6ad Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 5 Nov 2024 18:06:37 -0800 Subject: [PATCH 03/11] Update pre-commit to match CI/CD (#1227) Update pre-commit to match CI/CD --- .pre-commit-config.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 48a5d301b3..d8c7ef4adf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,6 +15,9 @@ repos: hooks: # Run the linter. - id: ruff - args: [--fix] + args: + - --fix + - --select + - F,I # Run the formatter. 
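+      # (the linter above runs only the F (Pyflakes) and I (import-sorting) rule sets;
+      #  ruff-format below handles code formatting)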
- id: ruff-format From 43cb8d18b45c1a8189ebde3e9653091fe839494a Mon Sep 17 00:00:00 2001 From: Driss Guessous <32754868+drisspg@users.noreply.github.com> Date: Tue, 5 Nov 2024 19:10:28 -0800 Subject: [PATCH 04/11] Fix for weights-only load (#1228) stack-info: PR: https://github.com/pytorch/ao/pull/1228, branch: drisspg/stack/19 --- ruff.toml | 3 + test/prototype/test_low_bit_optim.py | 100 +++++++++++++----- torchao/prototype/low_bit_optim/__init__.py | 11 ++ torchao/prototype/low_bit_optim/adam.py | 30 ++++-- .../prototype/low_bit_optim/cpu_offload.py | 10 +- .../prototype/low_bit_optim/quant_utils.py | 13 ++- .../prototype/low_bit_optim/subclass_4bit.py | 46 ++++++-- .../prototype/low_bit_optim/subclass_8bit.py | 42 ++++++-- .../prototype/low_bit_optim/subclass_fp8.py | 30 ++++-- 9 files changed, 218 insertions(+), 67 deletions(-) diff --git a/ruff.toml b/ruff.toml index 773497eb5c..40a0680ae4 100644 --- a/ruff.toml +++ b/ruff.toml @@ -13,4 +13,7 @@ include = [ "test/dtypes/test_affine_quantized_float.py", "torchao/quantization/weight_tensor_linear_activation_quantization.py", "torchao/dtypes/**/*.py", + "torchao/prototype/low_bit_optim/**.py", + "test/prototype/low_bit_optim/**.py", + ] diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 39f97896bf..f0b608b47d 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -19,7 +19,11 @@ quantize_4bit_with_qmap, _fp32_to_bf16_sr, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_6 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_3, + TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_6, +) try: import bitsandbytes as bnb @@ -85,7 +89,9 @@ def test_bf16_stochastic_round(self, device, compile): x_rep = x.view(-1, 1).repeat(1, 100_000) if compile: - x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)(x_rep) + x_rep_bf16 = torch.compile(_fp32_to_bf16_sr, fullgraph=True, dynamic=False)( + x_rep + ) else: x_rep_bf16 = _fp32_to_bf16_sr(x_rep) @@ -96,8 +102,13 @@ def test_bf16_stochastic_round(self, device, compile): class TestOptim(TestCase): - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") - @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"]) + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3" + ) + @parametrize( + "optim_name", + ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"], + ) @parametrize("dtype", [torch.float32, torch.bfloat16]) @parametrize("device", _DEVICES) def test_optim_smoke(self, optim_name, dtype, device): @@ -141,19 +152,28 @@ def test_optim_smoke(self, optim_name, dtype, device): torch.testing.assert_close(p2, p1) @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available") - @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA") - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") + @pytest.mark.skipif( + not torch.cuda.is_available(), + reason="bitsandbytes 8-bit Adam only works for CUDA", + ) + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3" + ) @parametrize("optim_name", ["Adam8bit", "AdamW8bit"]) def test_optim_8bit_correctness(self, optim_name): device = "cuda" - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + model1 = 
nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to( + device + ) model2 = copy.deepcopy(model1) # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0 block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048 optim1 = getattr(bnb.optim, optim_name)(model1.parameters()) - optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size) + optim2 = getattr(low_bit_optim, optim_name)( + model2.parameters(), block_size=block_size + ) for _ in range(2): x = torch.randn(4, 32, device=device) @@ -173,12 +193,18 @@ def test_optim_8bit_correctness(self, optim_name): # this will not run in CI because we can't install lpmm @pytest.mark.skipif(lpmm is None, reason="lpmm is not available") - @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA") - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3") + @pytest.mark.skipif( + not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA" + ) + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3" + ) @parametrize("optim_name", ["Adam4bit", "AdamW4bit"]) def test_optim_4bit_correctness(self, optim_name): device = "cuda" - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to( + device + ) model2 = copy.deepcopy(model1) # lpmm doesn't have Adam. use AdamW with no weight decay instead. @@ -206,17 +232,25 @@ def test_optim_4bit_correctness(self, optim_name): for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA") + @pytest.mark.skipif( + not torch.cuda.is_available(), reason="optim CPU offload requires CUDA" + ) @parametrize("offload_grad,grad_accum", [(False, 1), (False, 2), (True, 1)]) def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): device = "cuda" - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) - model1[0].requires_grad_(False) # make sure it can work in the presence of non-trainable params + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to( + device + ) + model1[0].requires_grad_( + False + ) # make sure it can work in the presence of non-trainable params model2 = copy.deepcopy(model1) optim1 = torch.optim.AdamW(model1.parameters()) optim2 = low_bit_optim.CPUOffloadOptimizer( - model2.parameters(), torch.optim.AdamW, offload_gradients=offload_grad, + model2.parameters(), + torch.optim.AdamW, + offload_gradients=offload_grad, ) for _ in range(2): @@ -234,11 +268,17 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): for p1, p2 in zip(model1.parameters(), model2.parameters()): torch.testing.assert_close(p2, p1) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="optim CPU offload requires CUDA") + @pytest.mark.skipif( + not torch.cuda.is_available(), reason="optim CPU offload requires CUDA" + ) def test_optim_cpu_offload_save_load(self): device = "cuda" - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) - optim1 = low_bit_optim.CPUOffloadOptimizer(model1.parameters(), torch.optim.AdamW) + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to( + device + ) + optim1 = 
low_bit_optim.CPUOffloadOptimizer( + model1.parameters(), torch.optim.AdamW + ) for _ in range(2): x = torch.randn(4, 32, device=device) @@ -253,7 +293,9 @@ def test_optim_cpu_offload_save_load(self): # resume training model2 = copy.deepcopy(model1) - optim2 = low_bit_optim.CPUOffloadOptimizer(model2.parameters(), torch.optim.AdamW) + optim2 = low_bit_optim.CPUOffloadOptimizer( + model2.parameters(), torch.optim.AdamW + ) optim2.load_state_dict(state_dict) for _ in range(2): @@ -273,13 +315,17 @@ def test_optim_cpu_offload_save_load(self): def test_optim_bf16_stochastic_round_correctness(self): device = "cuda" if torch.cuda.is_available() else "cpu" torch.manual_seed(2024) - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to( + device + ) model2 = copy.deepcopy(model1).bfloat16() # small LR so that weight update is small # when bf16_stochastic_round=False, the test will fail after 1 iteration optim1 = torch.optim.AdamW(model1.parameters(), lr=1e-5) - optim2 = low_bit_optim._AdamW(model2.parameters(), lr=1e-5, bf16_stochastic_round=True) + optim2 = low_bit_optim._AdamW( + model2.parameters(), lr=1e-5, bf16_stochastic_round=True + ) # overfit on this sample x = torch.randn(4, 32, device=device) @@ -299,7 +345,9 @@ def test_optim_bf16_stochastic_round_correctness(self): optim2.step() optim2.zero_grad() - torch.testing.assert_close(loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}") + torch.testing.assert_close( + loss1, loss2, msg=lambda msg: f"Iteration {idx}. {msg}" + ) class TestFSDP2(FSDPTest): @@ -307,7 +355,9 @@ class TestFSDP2(FSDPTest): def world_size(self) -> int: return 2 - @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required.") + @pytest.mark.skipif( + not TORCH_VERSION_AT_LEAST_2_6, reason="PyTorch>=2.6 is required." 
+ ) @skip_if_lt_x_gpu(2) def test_fsdp2(self): optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit] @@ -363,7 +413,9 @@ def _test_fsdp2(self, optim_cls): base_loss.backward() for param in base_model.parameters(): if param.grad is not None: - torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG) + torch.distributed.all_reduce( + param.grad, op=torch.distributed.ReduceOp.AVG + ) base_optim.step() self.assertEqual(fsdp_loss, base_loss) diff --git a/torchao/prototype/low_bit_optim/__init__.py b/torchao/prototype/low_bit_optim/__init__.py index 4ad75d4abf..9a22507b4c 100644 --- a/torchao/prototype/low_bit_optim/__init__.py +++ b/torchao/prototype/low_bit_optim/__init__.py @@ -1,2 +1,13 @@ from .adam import Adam4bit, Adam8bit, AdamFp8, AdamW4bit, AdamW8bit, AdamWFp8, _AdamW from .cpu_offload import CPUOffloadOptimizer + +__all__ = [ + "Adam4bit", + "Adam8bit", + "AdamFp8", + "AdamW4bit", + "AdamW8bit", + "AdamWFp8", + "_AdamW", + "CPUOffloadOptimizer", +] diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 19e0640334..1c3718972b 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -2,18 +2,28 @@ import torch from torch import Tensor -from torch.optim import Optimizer from torch.distributed._tensor import DTensor +from torch.optim import Optimizer -from .subclass_8bit import OptimState8bit +from .quant_utils import _fp32_to_bf16_sr from .subclass_4bit import OptimState4bit +from .subclass_8bit import OptimState8bit from .subclass_fp8 import OptimStateFp8 -from .quant_utils import _fp32_to_bf16_sr class _AdamBase(Optimizer): def __init__( - self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size, bf16_stochastic_round, is_adamw + self, + params, + lr, + betas, + eps, + weight_decay, + amsgrad, + *, + block_size, + bf16_stochastic_round, + is_adamw, ) -> None: if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -23,7 +33,13 @@ def __init__( raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=torch.tensor(lr), betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) + defaults = dict( + lr=torch.tensor(lr), + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + ) super().__init__(params, defaults) self.block_size = block_size self.bf16_stochastic_round = bf16_stochastic_round @@ -45,7 +61,9 @@ def _new_buffer(self, p: Tensor, signed: bool): if p.numel() >= 4096 and p.numel() % self.block_size == 0: if isinstance(p, DTensor): out = DTensor.from_local( - local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size), + local_tensor=self._subclass_zeros( + p.to_local(), signed, self.block_size + ), device_mesh=p.device_mesh, placements=p.placements, run_check=False, diff --git a/torchao/prototype/low_bit_optim/cpu_offload.py b/torchao/prototype/low_bit_optim/cpu_offload.py index 6a8671082c..c69932aa4c 100644 --- a/torchao/prototype/low_bit_optim/cpu_offload.py +++ b/torchao/prototype/low_bit_optim/cpu_offload.py @@ -25,7 +25,11 @@ def __init__( kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`. 
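+
+        Example (illustrative):
+            optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, offload_gradients=True, lr=1e-4)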
""" # default to fused CPU AdamW - if optimizer_class is torch.optim.AdamW and TORCH_VERSION_AT_LEAST_2_4 and "fused" not in kwargs: + if ( + optimizer_class is torch.optim.AdamW + and TORCH_VERSION_AT_LEAST_2_4 + and "fused" not in kwargs + ): kwargs.update(fused=True) param_groups = list(params) @@ -77,7 +81,9 @@ def backward_hook(p_cuda): self.param_cuda2cpu_map[p_cuda] = p_cpu p_cuda.register_post_accumulate_grad_hook(backward_hook) - self.optim_dict[p_cuda] = optimizer_class([{"params": p_cpu, **param_group}], **kwargs) + self.optim_dict[p_cuda] = optimizer_class( + [{"params": p_cpu, **param_group}], **kwargs + ) @torch.no_grad() def step(self, closure=None): diff --git a/torchao/prototype/low_bit_optim/quant_utils.py b/torchao/prototype/low_bit_optim/quant_utils.py index 556a2f290c..628c8a742e 100644 --- a/torchao/prototype/low_bit_optim/quant_utils.py +++ b/torchao/prototype/low_bit_optim/quant_utils.py @@ -122,14 +122,17 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor: # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16 # # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16 - rand_16bit = torch.randint(0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32) + rand_16bit = torch.randint( + 0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32 + ) x_f32_bits = x_f32.view(torch.int32) - x_fraction = x_f32_bits & 0xFFFF # lower 16 bits - x_bf16_towards_zero = x_f32_bits & 0xFFFF0000 # upper 16 bits + x_fraction = x_f32_bits & 0xFFFF # lower 16 bits + x_bf16_towards_zero = x_f32_bits & 0xFFFF0000 # upper 16 bits x_f32_bits = torch.where( - rand_16bit < x_fraction, # this is True with the probability of p_fraction - x_bf16_towards_zero + 0x10000, # this might overflow, which will result in UB due to signed integer + rand_16bit < x_fraction, # this is True with the probability of p_fraction + x_bf16_towards_zero + + 0x10000, # this might overflow, which will result in UB due to signed integer x_bf16_towards_zero, ) # alternative, slightly faster diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py index 759d816a6e..e493b978fe 100644 --- a/torchao/prototype/low_bit_optim/subclass_4bit.py +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -3,10 +3,19 @@ import torch from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4 -from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, + TorchAOBaseTensor, +) +from .quant_utils import ( + create_dynamic_map, + dequant_with_qmap, + quantize_4bit_with_qmap, + scale_tensor, +) aten = torch.ops.aten c10d_functional = torch.ops.c10d_functional @@ -55,8 +64,12 @@ def __tensor_flatten__(self): return self.tensor_attrs, [self.signed, self._shape] @classmethod - def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): - return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes) + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None + ): + return cls( + *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes + ) def dequantize(self, output_dtype=None): codes = torch.stack([self.codes >> 4, self.codes & 0b1111], 
dim=-1) # unpack @@ -85,6 +98,7 @@ def __repr__(self): # in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when # dtype is the same but device is different. thus, we must override .to() method instead. if not TORCH_VERSION_AT_LEAST_2_4: + def _to(self, *args, **kwargs): # ignore other args/kwargs device = kwargs.pop("device", None) @@ -158,16 +172,20 @@ def _(func, types, args, kwargs): if len(shape) == 1 and shape[0] == -1: return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),)) - raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]") + raise ValueError( + f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]" + ) # this is needed for DTensor.full_tensor() -@OptimState4bit.implements([ - c10d_functional.all_gather_into_tensor.default, - _c10d_functional.all_gather_into_tensor.default, - c10d_functional.wait_tensor.default, - _c10d_functional.wait_tensor.default, -]) +@OptimState4bit.implements( + [ + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, + ] +) def _(func, types, args, kwargs): x = args[0] if not isinstance(x, OptimState4bit): @@ -181,3 +199,9 @@ def _(func, types, args, kwargs): # assume tensors from all ranks have the same signedness return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape) + + +if TORCH_VERSION_AT_LEAST_2_5: + from torch.serialization import add_safe_globals + + add_safe_globals([OptimState4bit]) diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index f5374a3480..d23d159645 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -1,10 +1,19 @@ import torch from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_4 -from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_4, + TORCH_VERSION_AT_LEAST_2_5, + TorchAOBaseTensor, +) +from .quant_utils import ( + create_dynamic_map, + dequant_with_qmap, + quantize_8bit_with_qmap, + scale_tensor, +) aten = torch.ops.aten c10d_functional = torch.ops.c10d_functional @@ -46,8 +55,12 @@ def __tensor_flatten__(self): return self.tensor_attrs, [self.signed] @classmethod - def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): - return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes) + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None + ): + return cls( + *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes + ) def dequantize(self, output_dtype=None): float_data = dequant_with_qmap(self.codes, self.qmap, self.scale) @@ -72,6 +85,7 @@ def __repr__(self): # in pre-2.4, calling .to(device, dtype) will not dispatch aten._to_copy.default when # dtype is the same but device is different. thus, we must override .to() method instead. 
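# Illustrative check (a minimal sketch, not part of the patch): _fp32_to_bf16_sr
# in quant_utils.py above rounds up with probability fraction / 2**16, so the
# result is unbiased in expectation, unlike a plain deterministic bf16 cast.
# Assuming the prototype import path:
#
#     import torch
#     from torchao.prototype.low_bit_optim.quant_utils import _fp32_to_bf16_sr
#
#     x = torch.full((1,), 1.0 + 2**-9)  # a quarter of a bf16 ulp above 1.0
#     mean = torch.stack([_fp32_to_bf16_sr(x).float() for _ in range(10_000)]).mean()
#     print(mean)                  # averages to ~1.0 + 2**-9
#     print(x.bfloat16().float())  # deterministic cast always gives 1.0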
if not TORCH_VERSION_AT_LEAST_2_4: + def _to(self, *args, **kwargs): # ignore other args/kwargs device = kwargs.pop("device", None) @@ -136,12 +150,14 @@ def _(func, types, args, kwargs): # this is needed for DTensor.full_tensor() -@OptimState8bit.implements([ - c10d_functional.all_gather_into_tensor.default, - _c10d_functional.all_gather_into_tensor.default, - c10d_functional.wait_tensor.default, - _c10d_functional.wait_tensor.default, -]) +@OptimState8bit.implements( + [ + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, + ] +) def _(func, types, args, kwargs): x = args[0] if not isinstance(x, OptimState8bit): @@ -154,3 +170,9 @@ def _(func, types, args, kwargs): x.qmap.clone(), x.signed, ) + + +if TORCH_VERSION_AT_LEAST_2_5: + from torch.serialization import add_safe_globals + + add_safe_globals([OptimState8bit]) diff --git a/torchao/prototype/low_bit_optim/subclass_fp8.py b/torchao/prototype/low_bit_optim/subclass_fp8.py index eabe8b5051..d95b0c2661 100644 --- a/torchao/prototype/low_bit_optim/subclass_fp8.py +++ b/torchao/prototype/low_bit_optim/subclass_fp8.py @@ -1,8 +1,8 @@ import torch from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.utils import TorchAOBaseTensor +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TorchAOBaseTensor aten = torch.ops.aten c10d_functional = torch.ops.c10d_functional @@ -51,8 +51,12 @@ def __tensor_flatten__(self): return self.tensor_attrs, [] @classmethod - def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): - return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes) + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None + ): + return cls( + *[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes + ) def dequantize(self, output_dtype=None): float_data = self.codes.float() @@ -121,12 +125,14 @@ def _(func, types, args, kwargs): # this is needed for DTensor.full_tensor() -@OptimStateFp8.implements([ - c10d_functional.all_gather_into_tensor.default, - _c10d_functional.all_gather_into_tensor.default, - c10d_functional.wait_tensor.default, - _c10d_functional.wait_tensor.default, -]) +@OptimStateFp8.implements( + [ + c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor.default, + c10d_functional.wait_tensor.default, + _c10d_functional.wait_tensor.default, + ] +) def _(func, types, args, kwargs): x = args[0] if not isinstance(x, OptimStateFp8): @@ -137,3 +143,9 @@ def _(func, types, args, kwargs): func(x.codes, *args[1:], **kwargs), func(x.scale, *args[1:], **kwargs), ) + + +if TORCH_VERSION_AT_LEAST_2_5: + from torch.serialization import add_safe_globals + + add_safe_globals([OptimStateFp8]) From 4921bb316a6ccef52a70a78d90640941733af084 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 7 Nov 2024 13:36:42 +0800 Subject: [PATCH 05/11] Add module-swap UX for INT8 mixed-precision training (#1179) * add module swap UX * update * fix typing. 
add small notes * try NF4 support * fix * fix unpacking * fix * update nf4 integration * update backward pass --- .../quantized_training/pretrain_llama2.py | 3 + test/prototype/test_quantized_training.py | 23 +++-- .../prototype/quantized_training/__init__.py | 1 + .../int8_mixed_precision.py | 95 ++++++++++++++----- 4 files changed, 87 insertions(+), 35 deletions(-) diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index eed90fe9f6..0e6b79f60c 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -160,6 +160,9 @@ def insert_rmsnorm(module: torch.nn.Module): elif args.quantize == "int8_mixed_precision": quantize_(model.layers, int8_mixed_precision_training(), set_inductor_config=False) + elif args.quantize == "int8_mixed_precision_module_swap": + quantize_(model.layers, int8_mixed_precision_training(module_swap=True), set_inductor_config=False) + elif args.quantize == "bitnet": quantize_(model.layers, bitnet_training(), set_inductor_config=False) diff --git a/test/prototype/test_quantized_training.py b/test/prototype/test_quantized_training.py index 7a8d522362..c8bf5d574e 100644 --- a/test/prototype/test_quantized_training.py +++ b/test/prototype/test_quantized_training.py @@ -159,8 +159,9 @@ def test_int8_weight_only_training(self, compile, device): Int8MixedPrecisionTrainingConfig(grad_weight=False), ], ) + @parametrize("module_swap", [False, True]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_int8_mixed_precision_training(self, compile, config): + def test_int8_mixed_precision_training(self, compile, config, module_swap): _reset() bsize = 64 embed_dim = 64 @@ -168,7 +169,8 @@ def test_int8_mixed_precision_training(self, compile, config): linear = nn.Linear(embed_dim, embed_dim, device=device) linear_int8mp = copy.deepcopy(linear) - quantize_(linear_int8mp, int8_mixed_precision_training(config), set_inductor_config=False) + apply_func = int8_mixed_precision_training(config, module_swap=module_swap) + quantize_(linear_int8mp, apply_func, set_inductor_config=False) if compile: linear.compile() @@ -269,9 +271,10 @@ def test_fsdp2_correctness(self): # quantize_fn, mp_policy, tolerance test_args = [ # high tolerance due to stochastic rounding - (int8_weight_only_quantized_training, mp_policy, 0.05), - (int8_mixed_precision_training, mp_policy, 1e-6), - (bitnet_training, mp_policy, 1e-5), + (int8_weight_only_quantized_training(), mp_policy, 0.05), + (int8_mixed_precision_training(), mp_policy, 1e-6), + (int8_mixed_precision_training(module_swap=True), mp_policy, 1e-6), + (bitnet_training(), mp_policy, 1e-5), ] # FSDP2 mixed-precision requires https://github.com/pytorch/pytorch/pull/136129 @@ -284,9 +287,9 @@ def test_fsdp2_correctness(self): bf16_mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) extra_args = [ - (int8_weight_only_quantized_training, bf16_mp_policy, 1e-2), - (int8_mixed_precision_training, bf16_mp_policy, 1e-2), - (bitnet_training, bf16_mp_policy, 1e-2), + (int8_weight_only_quantized_training(), bf16_mp_policy, 1e-2), + (int8_mixed_precision_training(), bf16_mp_policy, 1e-2), + (bitnet_training(), bf16_mp_policy, 1e-2), ] test_args.extend(extra_args) @@ -312,8 +315,8 @@ def _run_subtest(self, args): base_model = Transformer(model_args).cuda() fsdp_model = copy.deepcopy(base_model) - quantize_(base_model.layers, quantize_fn(), set_inductor_config=False) - quantize_(fsdp_model.layers, 
quantize_fn(), set_inductor_config=False) + quantize_(base_model.layers, quantize_fn, set_inductor_config=False) + quantize_(fsdp_model.layers, quantize_fn, set_inductor_config=False) for layer in fsdp_model.layers: fully_shard(layer, mp_policy=mp_policy) diff --git a/torchao/prototype/quantized_training/__init__.py b/torchao/prototype/quantized_training/__init__.py index c3c9b7cfaf..f888c00aa4 100644 --- a/torchao/prototype/quantized_training/__init__.py +++ b/torchao/prototype/quantized_training/__init__.py @@ -6,6 +6,7 @@ ) from .int8_mixed_precision import ( Int8MixedPrecisionTrainingConfig, + Int8MixedPrecisionTrainingLinear, Int8MixedPrecisionTrainingLinearWeight, int8_mixed_precision_training, ) diff --git a/torchao/prototype/quantized_training/int8_mixed_precision.py b/torchao/prototype/quantized_training/int8_mixed_precision.py index b344f9aa13..640057cb00 100644 --- a/torchao/prototype/quantized_training/int8_mixed_precision.py +++ b/torchao/prototype/quantized_training/int8_mixed_precision.py @@ -1,8 +1,8 @@ -from typing import Any, NamedTuple, Optional, Tuple +from typing import Any, NamedTuple, Optional, Tuple, Union import torch import torch.utils._pytree as pytree -from torch import Tensor +from torch import Tensor, nn from torch.utils._triton import has_triton from torchao.quantization.quant_api import _get_linear_subclass_inserter @@ -75,7 +75,7 @@ def to_original(self): def __torch_dispatch__(cls, func, types, args, kwargs): config = None - def unwrap(x: cls): + def unwrap(x): nonlocal config if config is None: config = x.config @@ -151,7 +151,16 @@ def _(func, types, args, kwargs): if torch.is_autocast_enabled("cuda"): dtype = torch.get_autocast_gpu_dtype() args = tuple(x.to(dtype) if x is not None else x for x in args) - return _Int8MixedPrecisionTrainingLinear.apply(*args, **kwargs) + return _Int8MixedPrecisionTrainingLinearFunction.apply(*args, **kwargs) + + +class Int8MixedPrecisionTrainingLinear(nn.Linear): + def __init__(self, *args, config: Int8MixedPrecisionTrainingConfig, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.config = config + + def forward(self, input: Tensor) -> Tensor: + return _Int8MixedPrecisionTrainingLinearFunction.apply(input, self.weight, self.bias, self.config) def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: @@ -184,26 +193,46 @@ def _dynamic_int8_mm(A: Tensor, B: Tensor) -> Tensor: return out.view(*A.shape[:-1], out.shape[-1]) -class _Int8MixedPrecisionTrainingLinear(torch.autograd.Function): +@torch.compiler.allow_in_graph # this is required for module-swap, but not for tensor subclass +class _Int8MixedPrecisionTrainingLinearFunction(torch.autograd.Function): @staticmethod - def forward(input: Tensor, weight: Int8MixedPrecisionTrainingLinearWeight, bias: Optional[Tensor]): - if weight.config.output: - out = _dynamic_int8_mm(input, weight._data.T) + def forward( + ctx, + input: Tensor, + weight: Union[Int8MixedPrecisionTrainingLinearWeight, Tensor], + bias: Optional[Tensor], + config: Optional[Int8MixedPrecisionTrainingConfig] = None, + ): + # unpack tensor subclass and dequant if necessary. + # NOTE: we have to do this inside autograd.Function so that autograd works correctly. + if isinstance(weight, Int8MixedPrecisionTrainingLinearWeight): + config = weight.config # override `config` input argument + weight = weight._data + + ctx.config = config + ctx.save_for_backward(input, weight) + ctx.bias = bias is not None + + # for NF4Tensor, this will dequantize the tensor. 
+ # NOTE: not all quantized tensor subclasses implement .to() this way. + # e.g. AffineQuantizedTensor.to(dtype=dtype) returns the same AQT tensor. + # casting weight dtype may also introduce unintended behavior. + # e.g. FP32 activations and BF16 weight (both plain tensors), which should raise an error, + # but now we cast BF16 weight to FP32 instead (and return results in FP32). + weight = weight.to(input.dtype) + + if config.output: + out = _dynamic_int8_mm(input, weight.T) else: - out = input @ weight._data.T + out = input @ weight.T out = out + bias if bias is not None else out return out - @staticmethod - def setup_context(ctx, inputs, output): - input, weight, bias = inputs - ctx.config = weight.config - ctx.save_for_backward(input, weight._data) - ctx.bias = bias is not None - @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors + weight = weight.to(input.dtype) # dequant NF4 + grad_input = grad_weight = grad_bias = None if ctx.needs_input_grad[0]: @@ -224,12 +253,28 @@ def backward(ctx, grad_output): if ctx.needs_input_grad[2] and ctx.bias: grad_bias = grad_output.sum(0) - return grad_input, grad_weight, grad_bias - - -def int8_mixed_precision_training(config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG): - return _get_linear_subclass_inserter( - Int8MixedPrecisionTrainingLinearWeight, - config=config, - allow_requires_grad=True, - ) + return grad_input, grad_weight, grad_bias, None + + +def int8_mixed_precision_training( + config: Int8MixedPrecisionTrainingConfig = _DEFAULT_CONFIG, + *, + module_swap: bool = False, +): + # TODO: skip small layers that don't have perf gain. + if module_swap: + # module swap implementation + def convert_linear(linear: nn.Linear): + linear.__class__ = Int8MixedPrecisionTrainingLinear + linear.config = config + return linear + + return convert_linear + + else: + # tensor subclass implementation + return _get_linear_subclass_inserter( + Int8MixedPrecisionTrainingLinearWeight, + config=config, + allow_requires_grad=True, + ) From 235052f61efbd5e15cfb55ddf699bd4a0715f6f5 Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 6 Nov 2024 08:53:35 -0800 Subject: [PATCH 06/11] Refactored files --- .../floatx/floatx_tensor_core_layout.py | 176 ++++-------------- 1 file changed, 35 insertions(+), 141 deletions(-) diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index e801182559..cfdb566279 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -1,6 +1,5 @@ -from dataclasses import dataclass from functools import reduce -from typing import Optional, Tuple +from typing import Tuple, Optional import torch from torch import Tensor @@ -25,23 +24,11 @@ def _pack(x: Tensor, n_bits: int) -> Tensor: - return reduce( - torch.bitwise_or, - [ - x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits) - for i in range(8 // n_bits) - ], - ) + return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)]) def _unpack(x: Tensor, n_bits: int) -> Tensor: - return torch.stack( - [ - (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) - for i in range(8 // n_bits) - ], - dim=-1, - ).flatten(-2) + return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2) # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 @@ -55,40 +42,8 @@ 
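# Illustrative round trip (not part of the patch): _pack / _unpack above fit
# 8 // n_bits sub-byte values into each uint8, most-significant bits first.
# For n_bits=2, four 2-bit values share one byte:
#
#     x = torch.tensor([[1, 2, 3, 0]], dtype=torch.uint8)
#     packed = _pack(x, 2)                 # tensor([[108]]) == 0b01_10_11_00
#     assert torch.equal(_unpack(packed, 2), x)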
def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: if not undo: bit_order = { - 1: [ - 1, - 5, - 9, - 13, - 17, - 21, - 25, - 29, - 3, - 7, - 11, - 15, - 19, - 23, - 27, - 31, - 0, - 4, - 8, - 12, - 16, - 20, - 24, - 28, - 2, - 6, - 10, - 14, - 18, - 22, - 26, - 30, - ], + 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, + 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30], 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], 4: [1, 5, 3, 7, 0, 4, 2, 6], }[n_bits] @@ -97,40 +52,8 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: # this is inverse of the above, obtained by running # [v.index(i) for i in range(len(v))] bit_order = { - 1: [ - 16, - 0, - 24, - 8, - 17, - 1, - 25, - 9, - 18, - 2, - 26, - 10, - 19, - 3, - 27, - 11, - 20, - 4, - 28, - 12, - 21, - 5, - 29, - 13, - 22, - 6, - 30, - 14, - 23, - 7, - 31, - 15, - ], + 1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11, + 20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15], 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], 4: [4, 0, 6, 2, 5, 1, 7, 3], }[n_bits] @@ -166,12 +89,8 @@ def _pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask tensor_ybit = _pack(tensor_ybit, y) - tensor_ybit = ( - tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) - ) # Pass 2 from original code - tensor_ybit = _bit_interleave( - tensor_ybit.flatten(), y - ) # Pass 3 from original code + tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code + tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code fragments.append(tensor_ybit) used_bits += y @@ -205,9 +124,7 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: return _pack_tc_floatx(tensor, nbits) -def to_scaled_tc_floatx( - tensor: Tensor, ebits: int, mbits: int -) -> Tuple[Tensor, Tensor]: +def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]: # _n_ones() is not compatible with torch.compile() due to << operator # https://github.com/pytorch/pytorch/issues/119152 # exp_bias = _n_ones(ebits - 1) @@ -215,9 +132,7 @@ def to_scaled_tc_floatx( # workaround: global lookup table exp_bias = _ONES_TABLE[ebits - 1] - max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * ( - _ONES_TABLE[mbits + 1] / (2**mbits) - ) + max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) dtype = tensor.dtype tensor = tensor.float() @@ -244,10 +159,8 @@ def _unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = tensor[offset : offset + size_ybit] offset += size_ybit - tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 - tensor_ybit = ( - tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) - ) # undo Pass 2 + tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 + tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2 tensor_ybit = _unpack(tensor_ybit.flatten(), y) tensor_ybit = tensor_ybit << (nbits - used_bits - y) @@ -318,7 +231,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 7, 28672: 7, - 57344: 7, + 57344: 7 }, { # tokens: [65:128] 3072: 9, @@ -329,7 +242,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 7, 28672: 7, - 57344: 6, + 57344: 6 }, { # tokens: [129:192] 3072: 6, @@ -340,7 +253,7 @@ def from_scaled_tc_floatx(tensor: Tensor, 
ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 5, 28672: 5, - 57344: 4, + 57344: 4 }, { # tokens: [193:256] 3072: 9, @@ -351,7 +264,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 4, 14336: 8, 28672: 6, - 57344: 4, + 57344: 4 }, { # tokens: [257:320] 3072: 7, @@ -362,7 +275,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 3, 28672: 3, - 57344: 4, + 57344: 4 }, { # tokens: [321:384] 3072: 3, @@ -373,7 +286,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 8, 14336: 3, 28672: 4, - 57344: 3, + 57344: 3 }, { # tokens: [385:448] 3072: 5, @@ -384,7 +297,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 3, 14336: 1, 28672: 1, - 57344: 3, + 57344: 3 }, { # tokens: [449:512] 3072: 2, @@ -395,7 +308,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 2, 14336: 6, 28672: 4, - 57344: 1, + 57344: 1 }, { # tokens: [513:576] 3072: 2, @@ -406,7 +319,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 3, 14336: 3, 28672: 1, - 57344: 1, + 57344: 1 }, { # tokens: [577:640] 3072: 5, @@ -417,7 +330,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 1, 28672: 1, - 57344: 1, + 57344: 1 }, { # tokens: [641:704] 3072: 3, @@ -428,7 +341,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 2, 14336: 1, 28672: 1, - 57344: 1, + 57344: 1 }, { # tokens: [705:768] 3072: 3, @@ -439,18 +352,17 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 1, 28672: 1, - 57344: 1, - }, + 57344: 1 + } ] # quantization api integrations - @dataclass(frozen=True) class FloatxTensorCoreLayout(Layout): - """Layout type for FloatxTensorCoreAQTTensorImpl""" - + """Layout type for FloatxTensorCoreAQTTensorImpl + """ ebits: int mbits: int @@ -478,7 +390,6 @@ class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) """ - def __new__( cls, packed_floatx_data: torch.Tensor, @@ -487,16 +398,11 @@ def __new__( ): assert packed_floatx_data.ndim == 2 assert packed_floatx_data.dtype == torch.uint8 - shape = ( - packed_floatx_data.shape[0], - packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8, - ) + shape = (packed_floatx_data.shape[0], packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8) kwargs = {} kwargs["device"] = packed_floatx_data.device kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_floatx_data.layout + kwargs.get("layout") if kwargs.get("layout", False) else packed_floatx_data.layout ) kwargs["dtype"] = packed_floatx_data.dtype kwargs["requires_grad"] = False @@ -519,17 +425,12 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_floatx_data, scale = ( - tensor_data_dict["packed_floatx_data"], - tensor_data_dict["scale"], - ) - (_layout,) = tensor_attributes + packed_floatx_data, scale = tensor_data_dict["packed_floatx_data"], tensor_data_dict["scale"] + _layout, = tensor_attributes return cls(packed_floatx_data, scale, _layout) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: - 
unpacked_floatx_data = unpack_tc_floatx( - self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits - ) + unpacked_floatx_data = unpack_tc_floatx(self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits) return unpacked_floatx_data, self.scale @classmethod @@ -548,9 +449,7 @@ def from_plain( bit, M is mantissa bit """ assert isinstance(_layout, FloatxTensorCoreLayout) - packed_floatx_data = pack_tc_floatx( - unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits - ) + packed_floatx_data = pack_tc_floatx(unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits) return cls(packed_floatx_data, scale, _layout) def __repr__(self): @@ -588,12 +487,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif func is aten._to_copy.default: return return_and_correct_aliasing( - func, - args, - kwargs, - args[0]._apply_fn_to_data( - lambda x: x.to(device=kwargs.pop("device", None)) - ), + func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))), ) raise NotImplementedError( From 7bc89bc6446f5ada458af9e200fdce86c3f387cd Mon Sep 17 00:00:00 2001 From: jainapurva Date: Wed, 6 Nov 2024 23:39:39 -0800 Subject: [PATCH 07/11] Updated with ruff check --- torchao/dtypes/__init__.py | 27 +- torchao/dtypes/affine_quantized_tensor.py | 148 +--------- torchao/dtypes/affine_quantized_tensor_ops.py | 155 +++++++--- torchao/dtypes/floatx/__init__.py | 18 +- torchao/dtypes/floatx/float8_layout.py | 122 +++++--- .../floatx/floatx_tensor_core_layout.py | 207 ++++++++++---- torchao/dtypes/uintx/__init__.py | 26 +- torchao/dtypes/uintx/block_sparse_layout.py | 55 ++-- torchao/dtypes/uintx/marlin_sparse_layout.py | 95 +++++-- torchao/dtypes/uintx/plain_layout.py | 264 ++++++++++++++++++ torchao/dtypes/uintx/semi_sparse_layout.py | 49 ++-- .../dtypes/uintx/tensor_core_tiled_layout.py | 136 ++++++--- torchao/dtypes/uintx/uint8.py | 114 -------- torchao/dtypes/uintx/uint8_layout.py | 118 -------- .../uintx/{uintx.py => uintx_layout.py} | 9 +- torchao/dtypes/utils.py | 5 +- torchao/quantization/quant_api.py | 2 +- 17 files changed, 937 insertions(+), 613 deletions(-) create mode 100644 torchao/dtypes/uintx/plain_layout.py delete mode 100644 torchao/dtypes/uintx/uint8.py delete mode 100644 torchao/dtypes/uintx/uint8_layout.py rename torchao/dtypes/uintx/{uintx.py => uintx_layout.py} (97%) diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index d7cd517650..e22bcb1253 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,4 +1,5 @@ from .nf4tensor import NF4Tensor, to_nf4 + # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor from .uintx import UInt4Tensor from .affine_quantized_tensor import ( @@ -9,12 +10,11 @@ to_affine_quantized_fpx, to_affine_quantized_floatx, to_affine_quantized_floatx_static, - PlainAQTTensorImpl, ) -from .affine_quantized_tensor_ops import * + +from . 
import affine_quantized_tensor_ops from .utils import ( Layout, - MarlinSparseLayout, PlainLayout, ) from .floatx import ( @@ -22,10 +22,20 @@ Float8AQTTensorImpl, ) from .uintx import ( + UintxTensor, + UintxLayout, + UintxAQTTensorImpl, + to_uintx, + _DTYPE_TO_BIT_WIDTH, + _BIT_WIDTH_TO_DTYPE, + UInt4Tensor, SemiSparseLayout, TensorCoreTiledLayout, MarlinSparseLayout, + PlainAQTTensorImpl, + BlockSparseLayout, ) + __all__ = [ "NF4Tensor", "to_nf4", @@ -43,4 +53,15 @@ "Float8Layout", "Float8AQTTensorImpl", "MarlinSparseLayout", + "PlainAQTTensorImpl", + "affine_quantized_tensor_ops", + "BlockSparseLayout", + "to_uintx", + "UintxTensor", + "UintxLayout", + "UintxAQTTensorImpl", + "_DTYPE_TO_BIT_WIDTH", + "_BIT_WIDTH_TO_DTYPE", + "Uint4Tensor", + "PlainAQTTensorImpl", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index cb2527076d..54c0c5a9c7 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1,10 +1,8 @@ -from dataclasses import dataclass import logging import math from typing import Optional, Tuple, Union import torch -from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.dtypes.utils import Layout, PlainLayout from torchao.quantization.quant_primitives import ( FP8_TYPES, @@ -29,6 +27,7 @@ logger = logging.getLogger(__name__) aten = torch.ops.aten + ############################## # Tensor Subclass Definition # ############################## @@ -445,151 +444,6 @@ def _apply_fn_to_data(self, fn): register_layout = AffineQuantizedTensor.register_layout get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor - -@register_layout(PlainLayout) -class PlainAQTTensorImpl(AQTTensorImpl): - """ - TensorImpl for plain layout for affine quantized tensor, it stores int_data, scale, zero_point - tensors directly as plain tensors. 
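# Illustrative example (not part of the patch): the three plain tensors relate
# through the usual affine mapping, dequant = (int_data - zero_point) * scale,
# applied per quantization block (integer zero-point domain). For one int8
# block with block_size (1, 4):
#
#     int_data = [10, -3, 7, 0],  scale = 0.05,  zero_point = -2
#     dequant  = [(10 + 2) * 0.05, (-3 + 2) * 0.05, (7 + 2) * 0.05, (0 + 2) * 0.05]
#              = [0.60, -0.05, 0.45, 0.10]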
- - fields: - int_data (torch.Tensor): the quantized integer data Tensor - scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor - zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor - """ - - def __new__( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - kwargs["dtype"] = int_data.dtype - kwargs["requires_grad"] = False - shape = int_data.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - _layout: Layout, - ): - self.int_data = int_data - self.scale = scale - self.zero_point = zero_point - self._layout = _layout - - def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point"], [self._layout] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data, scale, zero_point = ( - tensor_data_dict["int_data"], - tensor_data_dict["scale"], - tensor_data_dict["zero_point"], - ) - (_layout,) = tensor_attributes - return cls(int_data, scale, zero_point, _layout) - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.int_data.to(kwargs["device"]), - self.scale.to(kwargs["device"]), - self.zero_point.to(kwargs["device"]), - self._layout, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.int_data), - fn(self.scale), - fn(self.zero_point), - self._layout, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - elif func is aten.t.default: - tensor = args[0] - new = tensor.__class__( - tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor._layout - ) - return return_and_correct_aliasing(func, args, kwargs, new) - - elif func is aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim == 0: - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0]._apply_fn_to_data( - lambda x: aten.slice.Tensor(x, dim, start, end, step) - ), - ) - elif dim == 1: - assert ( - len(self.scale.shape) == 1 - ), f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return PlainAQTTensorImpl( - aten.slice.Tensor(self.int_data, dim, start, end, step), - self.scale.view(-1), - self.zero_point.view(-1), - self._layout, - ) - else: - raise NotImplementedError( - f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" - ) - - raise NotImplementedError( - f"PlainAQTTensorImpl dispatch: attempting to run {func}, this is not supported" - ) - - __torch_function__ = torch._C._disabled_torch_function_impl - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return self.int_data, self.scale, self.zero_point - - def get_layout(self) -> Layout: - return self._layout - - @classmethod - def 
from_plain( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert isinstance(_layout, PlainLayout) - return cls(int_data, scale, zero_point, _layout) - - ##################################################### # torch functional and aten operator implementation # ##################################################### diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index a3d995d653..ea62a77065 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -1,9 +1,12 @@ import logging -from typing import Optional, Tuple import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.dtypes.affine_quantized_tensor import * +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + dequantize_affine, +) +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl # from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor from torchao.dtypes.floatx.float8_layout import ( _linear_fp8_act_fp8_weight_check, @@ -31,7 +34,7 @@ _linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl, ) -from torchao.dtypes.uintx.uint8_layout import ( +from torchao.dtypes.uintx.plain_layout import ( _linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl, _linear_fp_act_int8_weight_check, @@ -71,11 +74,14 @@ def deregister_aqt_quantized_linear_dispatch(dispatch_condition): if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE: del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] else: - logger.warn(f"Attempting to remove non-existant dispatch condition {dispatch_condition}") + logger.warn( + f"Attempting to remove non-existant dispatch condition {dispatch_condition}" + ) class QuantizedLinearNotImplementedError(NotImplementedError): - """ Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table """ + """Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table""" + pass @@ -84,14 +90,15 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias): for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items(): if dispatch_condition(input_tensor, weight_tensor, bias): return impl(input_tensor, weight_tensor, bias) - raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op") + raise QuantizedLinearNotImplementedError( + "No specialized dispatch found for quantized linear op" + ) # Attach the _quantized_linear_op to the AffineQuantizedTensor class AffineQuantizedTensor._quantized_linear_op = _quantized_linear_op - # # following are a list of (dispatch_condition, implementation) functions that takes the following args: # # input_tensor: dimension is (M1, M2, ..., in_features) # # weight_tensor: dimension is (out_features, in_features) @@ -100,14 +107,26 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias): def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), - (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), - (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl), + ( + _linear_int8_act_int8_weight_semi_structured_sparse_check, + _linear_int8_act_int8_weight_semi_structured_sparse_impl, + ), + 
( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, + ), (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl), (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), - (_linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl), - (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl), + ( + _linear_f16_bf16_act_floatx_weight_check, + _linear_f16_bf16_act_floatx_weight_impl, + ), + ( + _linear_fp_act_int4_weight_sparse_marlin_check, + _linear_fp_act_int4_weight_sparse_marlin_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) @@ -125,7 +144,9 @@ def _(func, types, args, kwargs): args[2] if len(args) > 2 else None, ) if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) # using try/except here so that we can have a general fallback when input_tensor/weight_tensor # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to @@ -134,7 +155,11 @@ def _(func, types, args, kwargs): return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor._layout, "quantized_linear_impl") + and weight_tensor._layout.quantized_linear_impl is not None + ): raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -148,19 +173,31 @@ def _(func, types, args, kwargs): def _(func, types, args, kwargs): # new_arg1 = args[1].dequantize() # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs) - assert isinstance(args[1].tensor_impl, PlainAQTTensorImpl), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" - assert kwargs["padding_idx"] is None and kwargs["max_norm"] is None and not kwargs["scale_grad_by_freq"] and not kwargs["sparse"] and kwargs["norm_type"]==2.0 + assert isinstance( + args[1].tensor_impl, PlainAQTTensorImpl + ), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" + assert ( + kwargs["padding_idx"] is None + and kwargs["max_norm"] is None + and not kwargs["scale_grad_by_freq"] + and not kwargs["sparse"] + and kwargs["norm_type"] == 2.0 + ) idx = args[0] int_data, scale, zero_point = args[1].tensor_impl.get_plain() - - sliced_data, sliced_scale, sliced_zero_point = int_data[idx], scale[idx], zero_point[idx] + + sliced_data, sliced_scale, sliced_zero_point = ( + int_data[idx], + scale[idx], + zero_point[idx], + ) # Block size is expecting 2 dimensions [1, group size] but - # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so + # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so # we need to increase block size to correct dim - new_blocks = idx.dim()-1 + new_blocks = idx.dim() - 1 return 
dequantize_affine( sliced_data, - new_blocks*[1]+list(args[1].block_size), + new_blocks * [1] + list(args[1].block_size), sliced_scale, sliced_zero_point, sliced_data.dtype, @@ -179,7 +216,9 @@ def _(func, types, args, kwargs): args[0], ) if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) # using try/except here so that we can have a general fallback when input_tensor/weight_tensor # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to @@ -189,7 +228,11 @@ def _(func, types, args, kwargs): return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor._layout, "quantized_linear_impl") + and weight_tensor._layout.quantized_linear_impl is not None + ): raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -201,20 +244,22 @@ def _(func, types, args, kwargs): @implements(aten.mm.default) def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - None - ) + input_tensor, weight_tensor, bias = (args[0], args[1], None) if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) try: weight_tensor = weight_tensor.t() return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor._layout, "quantized_linear_impl") + and weight_tensor._layout.quantized_linear_impl is not None + ): raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -256,7 +301,14 @@ def _(func, types, args, kwargs): tensor = args[0] shape = tensor.shape[::-1] new = tensor.__class__( - tensor.tensor_impl.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() + tensor.tensor_impl.t(), + transposed_block_size, + shape, + tensor.quant_min, + tensor.quant_max, + tensor.zero_point_domain, + dtype=tensor.dtype, + strides=tensor.stride(), ) return return_and_correct_aliasing(func, args, kwargs, new) @@ -271,11 +323,22 @@ def _(func, types, args, kwargs): shape = list(self.shape) shape[dim] = end - start block_size = self.block_size - assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}" + assert ( + len(block_size) == 2 + ), f"Slice only works for 2d block_size right now, got: {block_size}" # with slice, some shape dimension might be smaller than block_size dimension, so # we need to make sure 
there is no overflow block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) - new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + new = self.__class__( + aten.slice.Tensor(self.tensor_impl, dim, start, end, step), + block_size, + shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) return return_and_correct_aliasing(func, args, kwargs, new) @@ -285,11 +348,31 @@ def _(func, types, args, kwargs): self, shape = args if tuple(self.shape) == tuple(shape): - return self.__class__(self.tensor_impl, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return self.__class__( + self.tensor_impl, + self.block_size, + self.shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) if len(shape) == 1 and shape[0] == -1: assert len(self.block_size) == 2 and self.block_size[0] == 1 block_size = (self.block_size[1],) - return self.__class__(self.tensor_impl, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) - - raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]") + return self.__class__( + self.tensor_impl, + block_size, + (self.numel(),), + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) + + raise ValueError( + f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]" + ) diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 34b7ec1f91..ddfa9e3669 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,2 +1,18 @@ -from .floatx_tensor_core_layout import FloatxTensorCoreLayout, FloatxTensorCoreAQTTensorImpl, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP +from .floatx_tensor_core_layout import ( + FloatxTensorCoreLayout, + FloatxTensorCoreAQTTensorImpl, + to_scaled_tc_floatx, + from_scaled_tc_floatx, + _SPLIT_K_MAP, +) from .float8_layout import Float8AQTTensorImpl, Float8Layout + +__all__ = [ + "FloatxTensorCoreLayout", + "FloatxTensorCoreAQTTensorImpl", + "to_scaled_tc_floatx", + "from_scaled_tc_floatx", + "_SPLIT_K_MAP", + "Float8AQTTensorImpl", + "Float8Layout", +] diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py index 1c3c046497..bf3f96dca3 100644 --- a/torchao/dtypes/floatx/float8_layout.py +++ b/torchao/dtypes/floatx/float8_layout.py @@ -1,9 +1,9 @@ import torch -from torchao.utils import _is_float8_type -from torchao.dtypes.utils import Layout, AQTTensorImpl +from torchao.utils import _is_float8_type, fill_defaults +from torchao.dtypes.utils import Layout, AQTTensorImpl, get_out_shape from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, - register_layout + register_layout, ) from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -11,12 +11,13 @@ preprocess_data, Float8MMConfig, addmm_float8_unwrapped_inference, - _is_rowwise_scaled + _is_rowwise_scaled, ) from torch.utils._python_dispatch import ( return_and_correct_aliasing, is_traceable_wrapper_subclass, ) + aten = torch.ops.aten @@ -24,6 +25,7 @@ class Float8Layout(Layout): mm_config: Optional[Float8MMConfig] = 
None + @register_layout(Float8Layout) class Float8AQTTensorImpl(AQTTensorImpl): """ @@ -32,6 +34,7 @@ class Float8AQTTensorImpl(AQTTensorImpl): Note: technically we should not create a new layout for float8 we should merge this into plain layout """ + float8_data: torch.Tensor scale: torch.Tensor transposed: bool @@ -66,7 +69,7 @@ def __init__( self._layout = _layout def _apply_fn_to_data(self, fn): - """ Applys a fn to all tensor components stored on this class""" + """Applys a fn to all tensor components stored on this class""" return self.__class__( fn(self.float8_data), fn(self.scale), @@ -91,7 +94,10 @@ def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"] - transposed, _layout, = tensor_attributes + ( + transposed, + _layout, + ) = tensor_attributes return cls(float8_data, scale, transposed, _layout) @classmethod @@ -115,23 +121,50 @@ def __torch_dispatch__(cls, func, types, args, kwargs): elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: - #TODO: scale replecation should be dependent on block size + # TODO: scale replecation should be dependent on block size if self.scale.ndim == 1: return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: aten.slice.Tensor(x, dim, start, end, step) + ), ) elif self.scale.ndim == 0: return return_and_correct_aliasing( - func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout) + func, + args, + kwargs, + Float8AQTTensorImpl( + aten.slice.Tensor(self.float8_data, dim, start, end, step), + self.scale, + None, + self._layout, + ), ) else: - raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported") + raise NotImplementedError( + f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported" + ) elif dim == 1: return return_and_correct_aliasing( - func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step).contiguous(), self.scale, None, self._layout) + func, + args, + kwargs, + Float8AQTTensorImpl( + aten.slice.Tensor( + self.float8_data, dim, start, end, step + ).contiguous(), + self.scale, + None, + self._layout, + ), ) else: - raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError( + f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) else: raise NotImplementedError( f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported" @@ -153,42 +186,50 @@ def from_plain( zero_point: Optional[torch.Tensor], _layout: Layout, ): - """ Main entrypoint for constructing Float8TensorImpl""" - assert _is_float8_type(data.dtype), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" - assert isinstance(_layout, Float8Layout), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}" + """Main entrypoint for constructing Float8TensorImpl""" + assert _is_float8_type( + data.dtype + ), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" + assert isinstance( + _layout, Float8Layout 
+ ), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}" return cls(data, scale, False, _layout) def __repr__(self): float8_data, scale, _ = self.get_plain() _layout = self.get_layout() - return (f"{self.__class__.__name__}(\n" - f"float8_data={float8_data},\n" - f"scale={scale},\n" - f"transposed={self.transposed}, " - f"_layout={_layout})") + return ( + f"{self.__class__.__name__}(\n" + f"float8_data={float8_data},\n" + f"scale={scale},\n" + f"transposed={self.transposed}, " + f"_layout={_layout})" + ) ########################## # Float8 Dispatch Kernels ########################## + def _linear_fp8_act_fp8_weight_check( - input_tensor: Union[torch.Tensor, 'AffineQuantizedTensor'], - weight_tensor: Union[torch.Tensor, 'AffineQuantizedTensor'], + input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"], + weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"], bias: Optional[torch.Tensor], ) -> bool: def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( - isinstance(aqt, AffineQuantizedTensor) and - isinstance(aqt._layout, Float8Layout) + isinstance(aqt, AffineQuantizedTensor) + and isinstance(aqt._layout, Float8Layout) and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) ) + return check_aqt(input_tensor) and check_aqt(weight_tensor) def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): - """ Ensures input tensor is correctly formated for _scaled_mm """ + """Ensures input tensor is correctly formated for _scaled_mm""" input_scale = input_scale.unsqueeze(-1) if input_scale.dim() > 2: @@ -196,9 +237,10 @@ def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): return input_scale + def _linear_fp8_act_fp8_weight_impl( - input_tensor: 'AffineQuantizedTensor', - weight_tensor: 'AffineQuantizedTensor', + input_tensor: "AffineQuantizedTensor", + weight_tensor: "AffineQuantizedTensor", bias: Optional[torch.Tensor], ): """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm""" @@ -219,7 +261,9 @@ def _linear_fp8_act_fp8_weight_impl( # Handle rowwise case if _is_rowwise_scaled(weight_tensor): - assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size" + assert _is_rowwise_scaled( + input_tensor + ), "Input tensor must be rowwise block size" w_scale = w_scale.unsqueeze(-1).T input_scale = preprocess_scale(input_scale, input_tensor.shape) @@ -237,25 +281,31 @@ def _linear_fp8_act_fp8_weight_impl( use_fast_accum=scaled_mm_config.use_fast_accum, ).reshape(out_shape) + def _linear_fp_act_fp8_weight_check( - input_tensor: Union[torch.Tensor, 'AffineQuantizedTensor'], - weight_tensor: Union[torch.Tensor, 'AffineQuantizedTensor'], + input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"], + weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"], bias: Optional[torch.Tensor], ) -> bool: return ( # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and # weight is float8 quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor._layout, Float8Layout) + isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, Float8Layout) and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - and (weight_tensor.shape == 
weight_tensor.block_size or _is_rowwise_scaled(weight_tensor)) + and ( + weight_tensor.shape == weight_tensor.block_size + or _is_rowwise_scaled(weight_tensor) + ) ) + def _linear_fp_act_fp8_weight_impl( input_tensor: torch.Tensor, - weight_tensor: 'AffineQuantizedTensor', + weight_tensor: "AffineQuantizedTensor", bias: Optional[torch.Tensor], ): return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias) diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py index cfdb566279..b23010878e 100644 --- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py +++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py @@ -7,7 +7,11 @@ return_and_correct_aliasing, is_traceable_wrapper_subclass, ) -from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, _n_ones +from torchao.prototype.custom_fp_utils import ( + _f32_to_floatx_unpacked, + _floatx_unpacked_to_f32, + _n_ones, +) from torchao.dtypes.utils import ( Layout, AQTTensorImpl, @@ -24,11 +28,23 @@ def _pack(x: Tensor, n_bits: int) -> Tensor: - return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)]) + return reduce( + torch.bitwise_or, + [ + x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits) + for i in range(8 // n_bits) + ], + ) def _unpack(x: Tensor, n_bits: int) -> Tensor: - return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2) + return torch.stack( + [ + (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) + for i in range(8 // n_bits) + ], + dim=-1, + ).flatten(-2) # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 @@ -42,8 +58,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: if not undo: bit_order = { - 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, - 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30], + 1: [ + 1, + 5, + 9, + 13, + 17, + 21, + 25, + 29, + 3, + 7, + 11, + 15, + 19, + 23, + 27, + 31, + 0, + 4, + 8, + 12, + 16, + 20, + 24, + 28, + 2, + 6, + 10, + 14, + 18, + 22, + 26, + 30, + ], 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], 4: [1, 5, 3, 7, 0, 4, 2, 6], }[n_bits] @@ -52,8 +100,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: # this is inverse of the above, obtained by running # [v.index(i) for i in range(len(v))] bit_order = { - 1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11, - 20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15], + 1: [ + 16, + 0, + 24, + 8, + 17, + 1, + 25, + 9, + 18, + 2, + 26, + 10, + 19, + 3, + 27, + 11, + 20, + 4, + 28, + 12, + 21, + 5, + 29, + 13, + 22, + 6, + 30, + 14, + 23, + 7, + 31, + 15, + ], 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], 4: [4, 0, 6, 2, 5, 1, 7, 3], }[n_bits] @@ -89,8 +169,12 @@ def _pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask tensor_ybit = _pack(tensor_ybit, y) - tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code - tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code + tensor_ybit = ( + tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) + ) # Pass 2 from original code + tensor_ybit = _bit_interleave( + tensor_ybit.flatten(), y + ) # Pass 3 from original code 
fragments.append(tensor_ybit) used_bits += y @@ -124,7 +208,9 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: return _pack_tc_floatx(tensor, nbits) -def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]: +def to_scaled_tc_floatx( + tensor: Tensor, ebits: int, mbits: int +) -> Tuple[Tensor, Tensor]: # _n_ones() is not compatible with torch.compile() due to << operator # https://github.com/pytorch/pytorch/issues/119152 # exp_bias = _n_ones(ebits - 1) @@ -132,7 +218,9 @@ def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, # workaround: global lookup table exp_bias = _ONES_TABLE[ebits - 1] - max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) + max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * ( + _ONES_TABLE[mbits + 1] / (2**mbits) + ) dtype = tensor.dtype tensor = tensor.float() @@ -159,8 +247,10 @@ def _unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = tensor[offset : offset + size_ybit] offset += size_ybit - tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 - tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2 + tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 + tensor_ybit = ( + tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) + ) # undo Pass 2 tensor_ybit = _unpack(tensor_ybit.flatten(), y) tensor_ybit = tensor_ybit << (nbits - used_bits - y) @@ -231,7 +321,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 7, 28672: 7, - 57344: 7 + 57344: 7, }, { # tokens: [65:128] 3072: 9, @@ -242,7 +332,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 7, 28672: 7, - 57344: 6 + 57344: 6, }, { # tokens: [129:192] 3072: 6, @@ -253,7 +343,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 5, 28672: 5, - 57344: 4 + 57344: 4, }, { # tokens: [193:256] 3072: 9, @@ -264,7 +354,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 4, 14336: 8, 28672: 6, - 57344: 4 + 57344: 4, }, { # tokens: [257:320] 3072: 7, @@ -275,7 +365,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 3, 28672: 3, - 57344: 4 + 57344: 4, }, { # tokens: [321:384] 3072: 3, @@ -286,7 +376,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 8, 14336: 3, 28672: 4, - 57344: 3 + 57344: 3, }, { # tokens: [385:448] 3072: 5, @@ -297,7 +387,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 3, 14336: 1, 28672: 1, - 57344: 3 + 57344: 3, }, { # tokens: [449:512] 3072: 2, @@ -308,7 +398,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 2, 14336: 6, 28672: 4, - 57344: 1 + 57344: 1, }, { # tokens: [513:576] 3072: 2, @@ -319,7 +409,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 3, 14336: 3, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [577:640] 3072: 5, @@ -330,7 +420,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 1, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [641:704] 3072: 3, @@ -341,7 +431,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 2, 14336: 1, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [705:768] 3072: 3, @@ 
-352,17 +442,18 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 1, 28672: 1, - 57344: 1 - } + 57344: 1, + }, ] # quantization api integrations + @dataclass(frozen=True) class FloatxTensorCoreLayout(Layout): - """Layout type for FloatxTensorCoreAQTTensorImpl - """ + """Layout type for FloatxTensorCoreAQTTensorImpl""" + ebits: int mbits: int @@ -390,6 +481,7 @@ class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) """ + def __new__( cls, packed_floatx_data: torch.Tensor, @@ -398,11 +490,16 @@ def __new__( ): assert packed_floatx_data.ndim == 2 assert packed_floatx_data.dtype == torch.uint8 - shape = (packed_floatx_data.shape[0], packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8) + shape = ( + packed_floatx_data.shape[0], + packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8, + ) kwargs = {} kwargs["device"] = packed_floatx_data.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else packed_floatx_data.layout + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_floatx_data.layout ) kwargs["dtype"] = packed_floatx_data.dtype kwargs["requires_grad"] = False @@ -425,12 +522,17 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_floatx_data, scale = tensor_data_dict["packed_floatx_data"], tensor_data_dict["scale"] - _layout, = tensor_attributes + packed_floatx_data, scale = ( + tensor_data_dict["packed_floatx_data"], + tensor_data_dict["scale"], + ) + (_layout,) = tensor_attributes return cls(packed_floatx_data, scale, _layout) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: - unpacked_floatx_data = unpack_tc_floatx(self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits) + unpacked_floatx_data = unpack_tc_floatx( + self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits + ) return unpacked_floatx_data, self.scale @classmethod @@ -449,7 +551,9 @@ def from_plain( bit, M is mantissa bit """ assert isinstance(_layout, FloatxTensorCoreLayout) - packed_floatx_data = pack_tc_floatx(unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits) + packed_floatx_data = pack_tc_floatx( + unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits + ) return cls(packed_floatx_data, scale, _layout) def __repr__(self): @@ -487,7 +591,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif func is aten._to_copy.default: return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))), + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: x.to(device=kwargs.pop("device", None)) + ), ) raise NotImplementedError( @@ -502,28 +611,28 @@ def get_layout(self) -> Layout: def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): from torchao.dtypes.floatx import FloatxTensorCoreLayout + return ( # input is native float32 tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and - input_tensor.dtype in (torch.float16, torch.bfloat16) and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and # weight is floatx Tensor - 
isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor._layout, FloatxTensorCoreLayout) and - ( + isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, FloatxTensorCoreLayout) + and ( # weight is using fp6 quantization - (weight_tensor._layout.ebits == 3 and - weight_tensor._layout.mbits == 2) or - (weight_tensor._layout.ebits == 2 and - weight_tensor._layout.mbits == 3) or + (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 2) + or (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 3) + or # weight is using fp5 quantization - (weight_tensor._layout.ebits == 2 and - weight_tensor._layout.mbits == 2) or - (weight_tensor._layout.ebits == 3 and - weight_tensor._layout.mbits == 1) + (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 2) + or (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 1) ) ) + def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): from torchao.dtypes.floatx import _SPLIT_K_MAP from torchao.ops import quant_llm_linear diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 7b2c4e9028..e9eca3a011 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,6 +1,30 @@ -from .uintx import UintxTensor, UintxLayout, UintxAQTTensorImpl, to_uintx, _DTYPE_TO_BIT_WIDTH +from .uintx_layout import ( + UintxTensor, + UintxLayout, + UintxAQTTensorImpl, + to_uintx, + _DTYPE_TO_BIT_WIDTH, + _BIT_WIDTH_TO_DTYPE, +) from .uint4 import UInt4Tensor from .block_sparse_layout import BlockSparseLayout from .semi_sparse_layout import SemiSparseLayout from .marlin_sparse_layout import MarlinSparseLayout from .tensor_core_tiled_layout import TensorCoreTiledLayout +from .plain_layout import PlainAQTTensorImpl + + +__all__ = [ + "UintxTensor", + "UintxLayout", + "UintxAQTTensorImpl", + "to_uintx", + "UInt4Tensor", + "BlockSparseLayout", + "SemiSparseLayout", + "MarlinSparseLayout", + "TensorCoreTiledLayout", + "_DTYPE_TO_BIT_WIDTH", + "_BIT_WIDTH_TO_DTYPE", + "PlainAQTTensorImpl", +] diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py index 4f6358fae5..8355149cf1 100644 --- a/torchao/dtypes/uintx/block_sparse_layout.py +++ b/torchao/dtypes/uintx/block_sparse_layout.py @@ -13,18 +13,20 @@ from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, register_layout, - PlainAQTTensorImpl ) -from torchao.dtypes.uintx.uint8 import _aqt_is_int8_reduced_range +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl +from torchao.dtypes.uintx.plain_layout import _aqt_is_int8_reduced_range logger = logging.getLogger(__name__) aten = torch.ops.aten + @dataclass(frozen=True) class BlockSparseLayout(Layout): blocksize: int = 64 + @register_layout(BlockSparseLayout) class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): bsr_crow_indices: Optional[torch.Tensor] @@ -33,7 +35,13 @@ class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): scale: Optional[torch.Tensor] zero_point: Optional[torch.Tensor] - __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"] + __slots__ = [ + "bsr_crow_indices", + "bsr_col_indices", + "bsr_values", + "scale", + "zero_point", + ] @staticmethod def __new__( # noqa: PYI034 @@ -115,17 +123,23 @@ def from_plain(cls, int_data, scale, zero_point, _layout): bsr_values=bsr_tensor.values(), scale=scale, zero_point=zero_point, - _layout = _layout, + _layout=_layout, 
requires_grad=False, ) def get_plain(self): - int_data_expanded = torch.ops.blocksparse.bsr_to_dense(self.crow_indices(), self.col_indices(), self.values(), self.shape[0], self.shape[1]) + int_data_expanded = torch.ops.blocksparse.bsr_to_dense( + self.crow_indices(), + self.col_indices(), + self.values(), + self.shape[0], + self.shape[1], + ) return int_data_expanded, self.scale, self.zero_point def _apply_fn_to_data(self, func): return self.__class__( - shape = self.shape, + shape=self.shape, bsr_crow_indices=func(self.bsr_crow_indices), bsr_col_indices=func(self.bsr_col_indices), bsr_values=func(self.bsr_values), @@ -166,16 +180,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) - def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.is_cuda and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, BlockSparseLayout) + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, BlockSparseLayout) ) @@ -187,12 +200,14 @@ def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) tmp_t = tmp.t() - y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(), - w_vals.col_indices(), - w_vals.values(), - tmp_t, - w_scales, - x_scales.reshape(-1)) + y = torch.ops.blocksparse.int_addmm( + w_vals.crow_indices(), + w_vals.col_indices(), + w_vals.values(), + tmp_t, + w_scales, + x_scales.reshape(-1), + ) y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) y = y.reshape(*y_shape) diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py index 5483e4d7a3..cac2c70f5c 100644 --- a/torchao/dtypes/uintx/marlin_sparse_layout.py +++ b/torchao/dtypes/uintx/marlin_sparse_layout.py @@ -2,24 +2,33 @@ from torchao.dtypes.utils import Layout, AQTTensorImpl from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, - register_layout + register_layout, ) import torch from torchao.dtypes.uintx.tensor_core_tiled_layout import _aqt_is_tensor_core_tile_uint4 +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) +from torchao.quantization.quant_primitives import ( + ZeroPointDomain, +) + +aten = torch.ops.aten def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias): return ( - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_tensor_core_tile_uint4(weight_tensor) and - input_tensor.dtype == torch.float16 and - len(weight_tensor.shape) == 2 and - weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor._layout, MarlinSparseLayout) + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_tensor_core_tile_uint4(weight_tensor) + and input_tensor.dtype == torch.float16 + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.INT + and isinstance(weight_tensor._layout, MarlinSparseLayout) ) + def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias): - from 
torchao.sparsity.marlin import marlin_24_workspace, const + from torchao.sparsity.marlin import marlin_24_workspace from torchao.ops import marlin_24_gemm assert isinstance(weight_tensor, AffineQuantizedTensor) @@ -39,8 +48,15 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b workspace_24 = marlin_24_workspace(original_shape[1]) out = marlin_24_gemm( - input_2d, sparse_w_int4, meta, scale, - workspace_24, num_bits, size_m, size_n, size_k + input_2d, + sparse_w_int4, + meta, + scale, + workspace_24, + num_bits, + size_m, + size_n, + size_k, ) # Unfold the batch dimension @@ -50,9 +66,9 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b out += bias.to(out.dtype) return out + @dataclass(frozen=True) class MarlinSparseLayout(Layout): - def pre_process(self, input: torch.Tensor) -> torch.Tensor: """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. - 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format @@ -66,10 +82,12 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: the preprocessed tensor """ from torchao.sparsity.marlin import inject_24 # avoid circular import + input_t = input.t() w_24, _ = inject_24(input_t, *input_t.shape) return w_24.t() + @register_layout(MarlinSparseLayout) class MarlinSparseAQTTensorImpl(AQTTensorImpl): """ @@ -88,6 +106,7 @@ class MarlinSparseAQTTensorImpl(AQTTensorImpl): group_size (int): the group size used to pack the tensor num_bits (int): the number of bits used to quantize the tensor """ + @staticmethod def __new__( cls, @@ -144,7 +163,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point", "meta"], [self._layout, self.original_shape, self.group_size, self.num_bits] + return ["int_data", "scale", "zero_point", "meta"], [ + self._layout, + self.original_shape, + self.group_size, + self.num_bits, + ] @classmethod def __tensor_unflatten__( @@ -155,10 +179,22 @@ def __tensor_unflatten__( zero_point = tensor_data_dict["zero_point"] meta = tensor_data_dict["meta"] _layout, original_shape, group_size, num_bits = tensor_attributes - return cls(int_data, scale, zero_point, meta, _layout, original_shape, group_size, num_bits) + return cls( + int_data, + scale, + zero_point, + meta, + _layout, + original_shape, + group_size, + num_bits, + ) def get_plain(self): - from torchao.sparsity.marlin import unpack_from_marlin_24 # avoid circular import + from torchao.sparsity.marlin import ( + unpack_from_marlin_24, + ) # avoid circular import + int_data_expanded, scales_expanded = unpack_from_marlin_24( self.int_data, self.scale, @@ -179,7 +215,11 @@ def from_plain( zero_point: torch.Tensor, _layout: Layout, ): - from torchao.sparsity.marlin import pack_to_marlin_24, const # avoid circular import + from torchao.sparsity.marlin import ( + pack_to_marlin_24, + const, + ) # avoid circular import + assert isinstance(_layout, MarlinSparseLayout) # Linear layers are (in_features, out_features) but the int_data that is reaching this point @@ -189,7 +229,7 @@ def from_plain( if not torch.cuda.get_device_capability()[0] >= 8: raise ValueError( - f'Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel.' 
+ f"Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." ) if q_w_24.dtype != torch.int32: @@ -206,14 +246,14 @@ def from_plain( # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main num_bits = 4 if torch.max(q_w_24) < 16 else -1 if num_bits not in [4]: - raise ValueError( - f"Only {[4]} bits are supported, got {num_bits}." - ) + raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") group_size = in_features // scale_t.shape[0] if group_size == 0: group_size = in_features - assert group_size <= in_features, "Group size must be less than or equal to in_features." + assert ( + group_size <= in_features + ), "Group size must be less than or equal to in_features." if group_size not in const.SUPPORTED_GROUP_SIZES: raise ValueError( @@ -221,12 +261,19 @@ def from_plain( ) # Compress quantized weight to marlin 2:4 format - marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(q_w_24, scale_t, num_bits, group_size) + marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24( + q_w_24, scale_t, num_bits, group_size + ) return cls( - marlin_24_q_w_comp, marlin_24_s, zero_point, - meta, _layout, q_w_24.shape, - group_size, num_bits + marlin_24_q_w_comp, + marlin_24_s, + zero_point, + meta, + _layout, + q_w_24.shape, + group_size, + num_bits, ) def get_layout(self) -> Layout: diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py new file mode 100644 index 0000000000..9ce448edaa --- /dev/null +++ b/torchao/dtypes/uintx/plain_layout.py @@ -0,0 +1,264 @@ +import torch +from torchao.dtypes.utils import PlainLayout, AQTTensorImpl, Layout +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.utils import fill_defaults +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) +from torchao.quantization.quant_primitives import ( + int_scaled_matmul, + ZeroPointDomain, +) +from typing import Optional, Tuple + +aten = torch.ops.aten + + +@register_layout(PlainLayout) +class PlainAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl for plain layout for affine quantized tensor, it stores int_data, scale, zero_point + tensors directly as plain tensors. 
+ + fields: + int_data (torch.Tensor): the quantized integer data Tensor + scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor + zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor + """ + + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + _layout: Layout, + ): + self.int_data = int_data + self.scale = scale + self.zero_point = zero_point + self._layout = _layout + + def __tensor_flatten__(self): + return ["int_data", "scale", "zero_point"], [self._layout] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data, scale, zero_point = ( + tensor_data_dict["int_data"], + tensor_data_dict["scale"], + tensor_data_dict["zero_point"], + ) + (_layout,) = tensor_attributes + return cls(int_data, scale, zero_point, _layout) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.int_data.to(kwargs["device"]), + self.scale.to(kwargs["device"]), + self.zero_point.to(kwargs["device"]), + self._layout, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.int_data), + fn(self.scale), + fn(self.zero_point), + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + elif func is aten.t.default: + tensor = args[0] + new = tensor.__class__( + tensor.int_data.t(), tensor.scale, tensor.zero_point, tensor._layout + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + elif func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: aten.slice.Tensor(x, dim, start, end, step) + ), + ) + elif dim == 1: + assert ( + len(self.scale.shape) == 1 + ), f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" + return PlainAQTTensorImpl( + aten.slice.Tensor(self.int_data, dim, start, end, step), + self.scale.view(-1), + self.zero_point.view(-1), + self._layout, + ) + else: + raise NotImplementedError( + f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + raise NotImplementedError( + f"PlainAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return self.int_data, self.scale, self.zero_point + + def get_layout(self) -> Layout: + return self._layout + + @classmethod + def 
from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, PlainLayout) + return cls(int_data, scale, zero_point, _layout) + + +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is int8 quantized Tensor""" + return ( + aqt.tensor_impl.dtype == torch.int8 + and (aqt.quant_min is None or aqt.quant_min == -128) + and (aqt.quant_max is None or aqt.quant_max == 127) + ) + + +def _aqt_is_int8_reduced_range(aqt): + return ( + aqt.tensor_impl.dtype == torch.int8 + and aqt.quant_min == -127 + and (aqt.quant_max is None or aqt.quant_max == 127) + ) + + +def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): + return ( + # input is native float tensor + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and + # weight is int8 per channel quantized affine quantized tensor + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_int8(weight_tensor) + and len(weight_tensor.shape) == 2 + and len(weight_tensor.block_size) == 2 + and weight_tensor.block_size[0] == 1 + and weight_tensor.block_size[1] == weight_tensor.shape[1] + and weight_tensor.zero_point_domain == ZeroPointDomain.INT + and isinstance(weight_tensor._layout, PlainLayout) + ) + + +def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): + # TODO: enable cpu and mps efficient path + # is_cpu and is_mps only, some issue with is_contiguous() currently + # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) + + # per channel int8 weight only quantizated mm + w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() + scale = weight_tensor.tensor_impl.scale + m = torch.mm( + input_tensor.reshape(-1, input_tensor.shape[-1]), + w_vals_int8_t.to(input_tensor.dtype), + ) + y = m * scale.to(m.dtype) + y = y.reshape(*input_tensor.shape[:-1], y.shape[-1]) + if bias is not None: + y += bias.to(m.dtype) + return y + + +def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, PlainLayout) + ) + + +def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): + # + # 1. do the matrix form of dot(X_i, W_j) + # + # + # 2. 
rescale the output + # + # in cases with large matrices, y_dot_int32 can grow sufficiently + # large that y_dot_int32 * a float16 scale is greater than the maximum + # value of a float 16, (which results in a value of inf even if multiplying + # by the other scale would bring it within the expected range) + + x_vals_int8 = input_tensor.tensor_impl.int_data + x_scales = input_tensor.tensor_impl.scale + w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() + w_scales = weight_tensor.tensor_impl.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + x_scales_dtype = x_scales.dtype + # Cast fp16 scale to float to avoid overflow in int_scaled_matmul + intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype + y_dot_scaled = int_scaled_matmul( + tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype) + ) + y_dot_scaled = y_dot_scaled.to(x_scales_dtype) + + y = (y_dot_scaled * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] + ) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py index 31252701b5..1ac66d4fb2 100644 --- a/torchao/dtypes/uintx/semi_sparse_layout.py +++ b/torchao/dtypes/uintx/semi_sparse_layout.py @@ -3,25 +3,35 @@ from torchao.dtypes.affine_quantized_tensor import ( AffineQuantizedTensor, register_layout, - PlainAQTTensorImpl ) import torch from typing import Optional -from torchao.dtypes.uintx.uint8 import _aqt_is_int8_reduced_range +from torchao.dtypes.uintx.plain_layout import _aqt_is_int8_reduced_range +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) +from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl + +aten = torch.ops.aten -def _linear_int8_act_int8_weight_semi_structured_sparse_check(input_tensor, weight_tensor, bias): + +def _linear_int8_act_int8_weight_semi_structured_sparse_check( + input_tensor, weight_tensor, bias +): return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.is_cuda and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, SemiSparseLayout) + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, SemiSparseLayout) ) -def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias): +def _linear_int8_act_int8_weight_semi_structured_sparse_impl( + input_tensor, weight_tensor, bias +): x_vals_int8 = input_tensor.tensor_impl.int_data x_scales = input_tensor.tensor_impl.scale w_vals_int8 = weight_tensor.tensor_impl.int_data @@ -29,7 +39,10 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( - w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, + w_vals_int8, + tmp.t(), + alpha=w_scales.to(torch.float32), + 
out_dtype=torch.bfloat16, ).t() y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] @@ -41,9 +54,9 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh y += bias return y + @dataclass(frozen=True) class SemiSparseLayout(Layout): - def pre_process(self, input: torch.Tensor) -> torch.Tensor: # prune to 2:4 if not already temp = input.detach() @@ -52,12 +65,12 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: return temp - @register_layout(SemiSparseLayout) class SemiSparseAQTTensorImpl(PlainAQTTensorImpl): """ TensorImpl for semi_sparse_cusparselt layout for affine quantized tensor """ + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs @@ -75,10 +88,10 @@ def get_plain(self): # Currently we don't have cuSPARSELt expansion routines, so we matmul by # the identity matrix to get the original dense matrix. This is slow though. cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) - int_data_expanded = torch._cslt_sparse_mm(self.int_data, - torch.eye(cols, - dtype=self.int_data.dtype, - device=self.int_data.device).t()) + int_data_expanded = torch._cslt_sparse_mm( + self.int_data, + torch.eye(cols, dtype=self.int_data.dtype, device=self.int_data.device).t(), + ) return int_data_expanded, self.scale, self.zero_point @classmethod diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py index 1f6bb92179..f6dfb9a4d2 100644 --- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py +++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py @@ -1,45 +1,51 @@ import torch -from torchao.utils import find_multiple, TORCH_VERSION_AT_LEAST_2_5 -from torchao.dtypes.utils import Layout, AQTTensorImpl -from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor, register_layout +from torchao.utils import find_multiple, TORCH_VERSION_AT_LEAST_2_5, fill_defaults +from torchao.dtypes.utils import Layout, AQTTensorImpl, is_device +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) from dataclasses import dataclass from typing import Optional, Tuple from torch.utils._python_dispatch import ( return_and_correct_aliasing, is_traceable_wrapper_subclass, ) -from torchao.quantization.quant_primitives import ( - ZeroPointDomain -) +from torchao.quantization.quant_primitives import ZeroPointDomain, _get_reduction_params aten = torch.ops.aten + def _aqt_is_tensor_core_tile_uint4(aqt): """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" # TODO: use torch.uint4 return ( - aqt.tensor_impl.dtype == torch.int32 and - aqt.quant_min == 0 and - aqt.quant_max == 15 + aqt.tensor_impl.dtype == torch.int32 + and aqt.quant_min == 0 + and aqt.quant_max == 15 ) + def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): return ( # input is native bfloat16 tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.dtype == torch.bfloat16 and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.dtype == torch.bfloat16 + and # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_tensor_core_tile_uint4(weight_tensor) and - weight_tensor.dtype == torch.bfloat16 and - len(weight_tensor.shape) == 2 and - weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and - 
isinstance(weight_tensor._layout, TensorCoreTiledLayout) + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_tensor_core_tile_uint4(weight_tensor) + and weight_tensor.dtype == torch.bfloat16 + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT + and isinstance(weight_tensor._layout, TensorCoreTiledLayout) ) def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): - assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + assert ( + weight_tensor.block_size[0] == 1 + ), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" assert input_tensor.shape[-1] == weight_tensor.shape[1], ( f"need input_tensor shape: {input_tensor.shape} final" f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " @@ -63,24 +69,27 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): # groupwise int4 quantization groupsize = weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm(act_mat.contiguous(), packed_weight, groupsize, scale_and_zero) + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] y = y[:, :orig_out_features] y = y.reshape(*orig_act_size[:-1], orig_out_features) - if bias is not None: y += bias return y.to(orig_dtype) + @dataclass(frozen=True) class TensorCoreTiledLayout(Layout): """ inner_k_tiles is an internal argument for packing function of tensor core tiled layout that can affect the performance of the matmul kernel """ + inner_k_tiles: int = 8 def pre_process(self, input: torch.Tensor) -> torch.Tensor: @@ -93,14 +102,25 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: ) return input - def pre_process_static(self, input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, block_size: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def pre_process_static( + self, + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: Tuple[int, ...], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: input = self.pre_process(input) orig_qparam_shape = scale.shape - new_qparam_shape, reduction_dims = _get_reduction_params(block_size, input.size()) + new_qparam_shape, reduction_dims = _get_reduction_params( + block_size, input.size() + ) for dim in reduction_dims: new_qparam_shape.pop(dim) - change_in_qparam_shape = [new_dim_size-orig_dim_size for new_dim_size, orig_dim_size in zip(new_qparam_shape, orig_qparam_shape)] - padding_changes=[] + change_in_qparam_shape = [ + new_dim_size - orig_dim_size + for new_dim_size, orig_dim_size in zip(new_qparam_shape, orig_qparam_shape) + ] + padding_changes = [] for dim_change in change_in_qparam_shape: padding_changes = [0, dim_change] + padding_changes scale = torch.nn.functional.pad(scale, padding_changes) @@ -155,7 +175,9 @@ def __new__( kwargs = {} kwargs["device"] = packed_weight.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else packed_weight.layout + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout ) kwargs["dtype"] = packed_weight.dtype kwargs["requires_grad"] = False @@ -181,8 +203,14 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_weight, scale_and_zero = 
tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"] - transposed, _layout, = tensor_attributes + packed_weight, scale_and_zero = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scale_and_zero"], + ) + ( + transposed, + _layout, + ) = tensor_attributes return cls(packed_weight, scale_and_zero, transposed, _layout) @classmethod @@ -191,20 +219,26 @@ def from_plain( int_data: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor], - _layout: Layout + _layout: Layout, ): - assert isinstance(_layout, TensorCoreTiledLayout) if TORCH_VERSION_AT_LEAST_2_5: int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + assert ( + int_data.dtype == torch.uint8 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" else: - assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, _layout.inner_k_tiles) + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data, _layout.inner_k_tiles + ) scale = scale.reshape(int_data.shape[0], -1) zero_point = zero_point.reshape(int_data.shape[0], -1) from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) return cls(packed_weight, scale_and_zero, False, _layout) @@ -215,7 +249,9 @@ def to(self, *args, **kwargs): # between these two devices, in the future we should not use the same layout for # cpu and cuda device: https://github.com/pytorch/ao/issues/1117 if not is_device(torch.device(self.device).type, device): - raise ValueError(f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}") + raise ValueError( + f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}" + ) return self.__class__( self.packed_weight.to(device), self.scale_and_zero.to(device), @@ -252,7 +288,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): """we don't need to repack the weight and just rely on external shape being changed and record the status of transpose/no-transpose """ - transposed = TensorCoreTiledAQTTensorImpl(args[0].packed_weight, args[0].scale_and_zero, not args[0].transposed, args[0]._layout) + transposed = TensorCoreTiledAQTTensorImpl( + args[0].packed_weight, + args[0].scale_and_zero, + not args[0].transposed, + args[0]._layout, + ) return return_and_correct_aliasing(func, args, kwargs, transposed) if func is aten.slice.Tensor: @@ -277,11 +318,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs): # this is to handle padding int_data = self._layout.post_process(int_data) scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) - zero_point = aten.slice.Tensor(zero_point, dim, start_scale, end_scale, step) + zero_point = aten.slice.Tensor( + zero_point, dim, start_scale, end_scale, step + ) sliced = self.from_plain(int_data, scale, zero_point, self._layout) return sliced else: - raise NotImplementedError(f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError( + f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is 
not supported" + ) raise NotImplementedError( f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" @@ -295,6 +340,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: quantize_affine, ) from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) cur_shape = self.shape @@ -311,12 +357,26 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: quant_max = 15 zero_point_domain = ZeroPointDomain.FLOAT assert len(block_size) == 2 and block_size[0] == 1 - dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero) + dequantized = torch.ops.aten._weight_int4pack_mm( + torch.eye(eye_shape, device=device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale_and_zero, + ) dequantized = dequantized.t().contiguous() # TODO: move this to `unpack_tinygemm_scales_and_zeros`? scale = scale.reshape(scale.shape[:-1]).contiguous() zero = zero.reshape(zero.shape[:-1]).contiguous() - int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain) + int_data = quantize_affine( + dequantized, + block_size, + scale, + zero, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) return int_data, scale, zero def get_layout(self) -> Layout: diff --git a/torchao/dtypes/uintx/uint8.py b/torchao/dtypes/uintx/uint8.py deleted file mode 100644 index 8d53e93e74..0000000000 --- a/torchao/dtypes/uintx/uint8.py +++ /dev/null @@ -1,114 +0,0 @@ -import torch -from torchao.dtypes.utils import PlainLayout -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, -) -from dataclasses import dataclass -from typing import Optional, Tuple, Union -from torchao.float8.inference import ( - preprocess_data, - Float8MMConfig, - addmm_float8_unwrapped_inference, - _is_rowwise_scaled -) -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, - is_traceable_wrapper_subclass, -) - -def _aqt_is_int8(aqt): - """Check if an AffineQuantizedTensor is int8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.int8 and - (aqt.quant_min is None or aqt.quant_min == -128) and - (aqt.quant_max is None or aqt.quant_max == 127) - ) - -def _aqt_is_int8_reduced_range(aqt): - return ( - aqt.tensor_impl.dtype == torch.int8 and - aqt.quant_min == -127 and - (aqt.quant_max is None or aqt.quant_max == 127) - ) - - -def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): - return ( - # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and - # weight is int8 per channel quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_int8(weight_tensor) and - len(weight_tensor.shape) == 2 and - len(weight_tensor.block_size) == 2 and - weight_tensor.block_size[0] == 1 and - weight_tensor.block_size[1] == weight_tensor.shape[1] and - weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor._layout, PlainLayout) - ) - - -def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): - # TODO: enable cpu and mps efficient path - # is_cpu and is_mps only, some issue with is_contiguous() currently - # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) - - # 
per channel int8 weight only quantizated mm - w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() - scale = weight_tensor.tensor_impl.scale - orig_dtype = input_tensor.dtype - m = torch.mm( - input_tensor.reshape(-1, input_tensor.shape[-1]), - w_vals_int8_t.to(input_tensor.dtype), - ) - y = m * scale.to(m.dtype) - y = y.reshape(*input_tensor.shape[:-1], y.shape[-1]) - if bias is not None: - y += bias.to(m.dtype) - return y - - -def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, PlainLayout) - ) - -def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): - # - # 1. do the matrix form of dot(X_i, W_j) - # - # - # 2. rescale the output - # - # in cases with large matrices, y_dot_int32 can grow sufficiently - # large that y_dot_int32 * a float16 scale is greater than the maximum - # value of a float 16, (which results in a value of inf even if multiplying - # by the other scale would bring it within the expected range) - - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - x_scales_dtype = x_scales.dtype - # Cast fp16 scale to float to avoid overflow in int_scaled_matmul - intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype - y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)) - y_dot_scaled = y_dot_scaled.to(x_scales_dtype) - - y = (y_dot_scaled * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] - ) - - # can downcast only at the very end - output_dtype = input_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y += bias - return y diff --git a/torchao/dtypes/uintx/uint8_layout.py b/torchao/dtypes/uintx/uint8_layout.py deleted file mode 100644 index 6be1bc25ee..0000000000 --- a/torchao/dtypes/uintx/uint8_layout.py +++ /dev/null @@ -1,118 +0,0 @@ -import torch -from torchao.dtypes.utils import PlainLayout -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, -) -from dataclasses import dataclass -from typing import Optional, Tuple, Union -from torchao.float8.inference import ( - preprocess_data, - Float8MMConfig, - addmm_float8_unwrapped_inference, - _is_rowwise_scaled -) -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, - is_traceable_wrapper_subclass, -) -from torchao.quantization.quant_primitives import ( - int_scaled_matmul, - ZeroPointDomain, -) - -def _aqt_is_int8(aqt): - """Check if an AffineQuantizedTensor is int8 quantized Tensor""" - return ( - aqt.tensor_impl.dtype == torch.int8 and - (aqt.quant_min is None or aqt.quant_min == -128) and - (aqt.quant_max is None or aqt.quant_max == 127) - ) - -def _aqt_is_int8_reduced_range(aqt): - return ( - aqt.tensor_impl.dtype == torch.int8 and - aqt.quant_min == -127 and - (aqt.quant_max is None or aqt.quant_max == 127) - ) - - -def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): - return ( - # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and 
- # weight is int8 per channel quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_int8(weight_tensor) and - len(weight_tensor.shape) == 2 and - len(weight_tensor.block_size) == 2 and - weight_tensor.block_size[0] == 1 and - weight_tensor.block_size[1] == weight_tensor.shape[1] and - weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor._layout, PlainLayout) - ) - - -def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): - # TODO: enable cpu and mps efficient path - # is_cpu and is_mps only, some issue with is_contiguous() currently - # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_tensor.tensor_impl.scale) - - # per channel int8 weight only quantizated mm - w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() - scale = weight_tensor.tensor_impl.scale - orig_dtype = input_tensor.dtype - m = torch.mm( - input_tensor.reshape(-1, input_tensor.shape[-1]), - w_vals_int8_t.to(input_tensor.dtype), - ) - y = m * scale.to(m.dtype) - y = y.reshape(*input_tensor.shape[:-1], y.shape[-1]) - if bias is not None: - y += bias.to(m.dtype) - return y - - -def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): - return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, PlainLayout) - ) - -def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): - # - # 1. do the matrix form of dot(X_i, W_j) - # - # - # 2. rescale the output - # - # in cases with large matrices, y_dot_int32 can grow sufficiently - # large that y_dot_int32 * a float16 scale is greater than the maximum - # value of a float 16, (which results in a value of inf even if multiplying - # by the other scale would bring it within the expected range) - - x_vals_int8 = input_tensor.tensor_impl.int_data - x_scales = input_tensor.tensor_impl.scale - w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t() - w_scales = weight_tensor.tensor_impl.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - x_scales_dtype = x_scales.dtype - # Cast fp16 scale to float to avoid overflow in int_scaled_matmul - intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype - y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)) - y_dot_scaled = y_dot_scaled.to(x_scales_dtype) - - y = (y_dot_scaled * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] - ) - - # can downcast only at the very end - output_dtype = input_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y += bias - return y diff --git a/torchao/dtypes/uintx/uintx.py b/torchao/dtypes/uintx/uintx_layout.py similarity index 97% rename from torchao/dtypes/uintx/uintx.py rename to torchao/dtypes/uintx/uintx_layout.py index b47862a7e1..11bf2f88c9 100644 --- a/torchao/dtypes/uintx/uintx.py +++ b/torchao/dtypes/uintx/uintx_layout.py @@ -3,16 +3,13 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing - -from torchao.dtypes.affine_quantized_tensor import PlainAQTTensorImpl, register_layout +from .bitpacking import pack, unpack +from torchao.dtypes.affine_quantized_tensor import register_layout +from torchao.dtypes.uintx.plain_layout import 
PlainAQTTensorImpl
 from torchao.dtypes.utils import (
     Layout,
 )
 from torchao.utils import TorchAOBaseTensor
-from torchao.dtypes.affine_quantized_tensor import (
-    register_layout,
-    PlainAQTTensorImpl
-)
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
 
 aten = torch.ops.aten
diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py
index 1704fdb61f..38976176d8 100644
--- a/torchao/dtypes/utils.py
+++ b/torchao/dtypes/utils.py
@@ -1,5 +1,7 @@
 from dataclasses import dataclass
 from torchao.utils import TorchAOBaseTensor
+import torch
+from typing import Tuple, Union
 
 """
 Base class for different layout, following the same design of PyTorch layout
@@ -82,6 +84,7 @@ class AQTTensorImpl(TorchAOBaseTensor):
     Note: This is not a user facing API, it's used by AffineQuantizedTensor to construct
     the underlying implementation of a AQT based on layout
     """
+
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Get the plain (unpacked) Tensor for the tensor impl
 
@@ -101,7 +104,7 @@ def from_plain(
         zero_point: torch.Tensor,
         _layout: Layout,
     ):
-        """ Construct a TensorImpl from data, scale, zero_point and the _layout"""
+        """Construct a TensorImpl from data, scale, zero_point and the _layout"""
         pass
 
     def __repr__(self):
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index e021556ed3..611d11287e 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -23,7 +23,7 @@
 from typing import Any, Callable, Union, Dict, Optional, Literal, Tuple
 import types
 
-from torchao.dtypes.uintx.uintx import UintxLayout
+from torchao.dtypes.uintx.uintx_layout import UintxLayout
 from torchao.dtypes import (
     to_affine_quantized_intx,
     to_affine_quantized_floatx,

From 84b4d38ee8c4c792bcbbd9ea7fba29d2071c9d4e Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Mon, 11 Nov 2024 18:24:19 -0800
Subject: [PATCH 08/11] Reformatted AQT

---
 test/dtypes/test_floatx.py                    |  2 +-
 test/dtypes/test_uint4.py                     |  2 +-
 torchao/dtypes/__init__.py                    | 45 +++++++++----------
 torchao/dtypes/affine_quantized_tensor.py     | 10 +++--
 torchao/dtypes/affine_quantized_tensor_ops.py | 16 ++++---
 torchao/dtypes/floatx/__init__.py             |  8 ++--
 torchao/dtypes/floatx/float8_layout.py        | 22 ++++-----
 .../floatx/floatx_tensor_core_layout.py       | 24 +++++-----
 torchao/dtypes/uintx/__init__.py              | 42 ++++++++++++++++-
 torchao/dtypes/uintx/block_sparse_layout.py   | 17 ++++---
 torchao/dtypes/uintx/marlin_sparse_layout.py  | 16 ++++---
 torchao/dtypes/uintx/plain_layout.py          | 18 ++++----
 torchao/dtypes/uintx/semi_sparse_layout.py    | 18 +++++---
 .../dtypes/uintx/tensor_core_tiled_layout.py  | 18 ++++----
 .../uintx/{uint4.py => uint4_layout.py}       |  0
 torchao/dtypes/uintx/uintx_layout.py          |  7 +--
 torchao/dtypes/utils.py                       |  6 ++-
 torchao/quantization/quant_api.py             |  2 +-
 18 files changed, 163 insertions(+), 110 deletions(-)
 rename torchao/dtypes/uintx/{uint4.py => uint4_layout.py} (100%)

diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py
index 3e65ea6ab8..13add69a0a 100644
--- a/test/dtypes/test_floatx.py
+++ b/test/dtypes/test_floatx.py
@@ -15,7 +15,7 @@
     to_scaled_tc_floatx,
     from_scaled_tc_floatx,
 )
-from torchao.dtypes.floatx.floatx import _pack_tc_floatx, _pack_tc_fp6
+from torchao.dtypes.floatx.floatx_tensor_core_layout import _pack_tc_floatx, _pack_tc_fp6
 from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32
 from torchao.quantization import (
     quantize_,
diff --git a/test/dtypes/test_uint4.py b/test/dtypes/test_uint4.py
index 98fb523d33..432ffebbd2 100644
--- a/test/dtypes/test_uint4.py
+++ b/test/dtypes/test_uint4.py
@@ -1,5 +1,5 @@
 import torch
-from torchao.dtypes.uint4 import (
+from torchao.dtypes.uintx.uint4_layout import (
     UInt4Tensor,
     PerChannelSymmetricWeightUInt4Tensor,
 )
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index ed6100e099..a41fd83408 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -1,45 +1,40 @@
-from .nf4tensor import NF4Tensor, to_nf4
+from . import affine_quantized_tensor_ops
+
 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
-from .uintx import UInt4Tensor
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
-    to_affine_quantized_intx,
-    to_affine_quantized_intx_static,
-    # experimental, will be merged into floatx in the future
-    to_affine_quantized_fpx,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
-)
-
-from . import affine_quantized_tensor_ops
-from .utils import (
-    Layout,
-    MarlinSparseLayout,
-    PlainLayout,
+    # experimental, will be merged into floatx in the future
+    to_affine_quantized_fpx,
+    to_affine_quantized_intx,
+    to_affine_quantized_intx_static,
 )
 from .floatx import (
-    Float8Layout,
     Float8AQTTensorImpl,
+    Float8Layout,
 )
+from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
-    UintxTensor,
-    UintxLayout,
-    UintxAQTTensorImpl,
-    to_uintx,
-    _DTYPE_TO_BIT_WIDTH,
     _BIT_WIDTH_TO_DTYPE,
-    UInt4Tensor,
-    SemiSparseLayout,
-    TensorCoreTiledLayout,
+    _DTYPE_TO_BIT_WIDTH,
+    BlockSparseLayout,
     MarlinSparseLayout,
     PlainAQTTensorImpl,
-    BlockSparseLayout,
+    SemiSparseLayout,
+    TensorCoreTiledLayout,
+    UInt4Tensor,
+    UintxAQTTensorImpl,
+    UintxLayout,
+    UintxTensor,
+    to_uintx,
+)
+from .utils import (
+    Layout,
+    PlainLayout,
 )
-from .nf4tensor import NF4Tensor, to_nf4
 
 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
-from .uint4 import UInt4Tensor
 
 __all__ = [
     "NF4Tensor",
     "to_nf4",
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 54c0c5a9c7..408cb83dcf 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -3,7 +3,12 @@
 from typing import Optional, Tuple, Union
 
 import torch
-from torchao.dtypes.utils import Layout, PlainLayout
+
+from torchao.dtypes.utils import (
+    AQTTensorImpl,
+    Layout,
+    PlainLayout,
+)
 from torchao.quantization.quant_primitives import (
     FP8_TYPES,
     MappingType,
@@ -20,9 +25,6 @@
     TORCH_VERSION_AT_LEAST_2_5,
     TorchAOBaseTensor,
 )
-from torchao.dtypes.utils import (
-    AQTTensorImpl,
-)
 
 logger = logging.getLogger(__name__)
 aten = torch.ops.aten
diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py
index ea62a77065..514b909306 100644
--- a/torchao/dtypes/affine_quantized_tensor_ops.py
+++ b/torchao/dtypes/affine_quantized_tensor_ops.py
@@ -2,11 +2,12 @@
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
+
 from torchao.dtypes.affine_quantized_tensor import (
     AffineQuantizedTensor,
     dequantize_affine,
 )
-from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
+
 # from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
 from torchao.dtypes.floatx.float8_layout import (
     _linear_fp8_act_fp8_weight_check,
@@ -26,6 +27,13 @@
     _linear_fp_act_int4_weight_sparse_marlin_check,
     _linear_fp_act_int4_weight_sparse_marlin_impl,
 )
+from torchao.dtypes.uintx.plain_layout import (
+    PlainAQTTensorImpl,
+    _linear_fp_act_int8_weight_check,
+    _linear_fp_act_int8_weight_impl,
+    _linear_int8_act_int8_weight_check,
+    _linear_int8_act_int8_weight_impl,
+)
 from torchao.dtypes.uintx.semi_sparse_layout import (
     _linear_int8_act_int8_weight_semi_structured_sparse_check,
     _linear_int8_act_int8_weight_semi_structured_sparse_impl,
@@ -34,12 +42,6 @@
     _linear_bf16_act_uint4_weight_check,
     _linear_bf16_act_uint4_weight_impl,
 )
-from torchao.dtypes.uintx.plain_layout import (
-    _linear_int8_act_int8_weight_check,
-    _linear_int8_act_int8_weight_impl,
-    _linear_fp_act_int8_weight_check,
-    _linear_fp_act_int8_weight_impl,
-)
 from torchao.utils import (
     fill_defaults,
 )
diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py
index ddfa9e3669..6e22186d7f 100644
--- a/torchao/dtypes/floatx/__init__.py
+++ b/torchao/dtypes/floatx/__init__.py
@@ -1,11 +1,11 @@
+from .float8_layout import Float8AQTTensorImpl, Float8Layout
 from .floatx_tensor_core_layout import (
-    FloatxTensorCoreLayout,
+    _SPLIT_K_MAP,
     FloatxTensorCoreAQTTensorImpl,
-    to_scaled_tc_floatx,
+    FloatxTensorCoreLayout,
     from_scaled_tc_floatx,
-    _SPLIT_K_MAP,
+    to_scaled_tc_floatx,
 )
-from .float8_layout import Float8AQTTensorImpl, Float8Layout
 
 __all__ = [
     "FloatxTensorCoreLayout",
diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py
index bf3f96dca3..dd995fb157 100644
--- a/torchao/dtypes/floatx/float8_layout.py
+++ b/torchao/dtypes/floatx/float8_layout.py
@@ -1,22 +1,24 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
 import torch
-from torchao.utils import _is_float8_type, fill_defaults
-from torchao.dtypes.utils import Layout, AQTTensorImpl, get_out_shape
+from torch.utils._python_dispatch import (
+    is_traceable_wrapper_subclass,
+    return_and_correct_aliasing,
+)
+
 from torchao.dtypes.affine_quantized_tensor import (
     AffineQuantizedTensor,
     register_layout,
 )
-from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape
 from torchao.float8.inference import (
-    preprocess_data,
     Float8MMConfig,
-    addmm_float8_unwrapped_inference,
     _is_rowwise_scaled,
+    addmm_float8_unwrapped_inference,
+    preprocess_data,
 )
-from torch.utils._python_dispatch import (
-    return_and_correct_aliasing,
-    is_traceable_wrapper_subclass,
-)
+from torchao.utils import _is_float8_type, fill_defaults
 
 aten = torch.ops.aten
diff --git a/torchao/dtypes/floatx/floatx_tensor_core_layout.py b/torchao/dtypes/floatx/floatx_tensor_core_layout.py
index 63cb5cb57a..0f67e9826e 100644
--- a/torchao/dtypes/floatx/floatx_tensor_core_layout.py
+++ b/torchao/dtypes/floatx/floatx_tensor_core_layout.py
@@ -5,24 +5,23 @@
 import torch
 from torch import Tensor
 from torch.utils._python_dispatch import (
-    return_and_correct_aliasing,
     is_traceable_wrapper_subclass,
+    return_and_correct_aliasing,
 )
-from torchao.prototype.custom_fp_utils import (
-    _f32_to_floatx_unpacked,
-    _floatx_unpacked_to_f32,
-    _n_ones,
+
+from torchao.dtypes.affine_quantized_tensor import (
+    AffineQuantizedTensor,
+    register_layout,
 )
 from torchao.dtypes.utils import (
-    Layout,
     AQTTensorImpl,
+    Layout,
 )
-from torchao.dtypes.affine_quantized_tensor import (
-    AffineQuantizedTensor,
-    register_layout,
+from torchao.prototype.custom_fp_utils import (
+    _f32_to_floatx_unpacked,
+    _floatx_unpacked_to_f32,
+    _n_ones,
 )
-from dataclasses import dataclass
-
 aten = torch.ops.aten
 
 _ONES_TABLE = [_n_ones(i) for i in range(8)]
@@ -449,8 +448,6 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) ->
 
 
 # quantization api integrations
-
-
 @dataclass(frozen=True)
 class FloatxTensorCoreLayout(Layout):
     """Layout type for FloatxTensorCoreAQTTensorImpl"""
@@ -635,7 +632,6 @@ def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias):
 
 
 def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias):
-    from torchao.dtypes.floatx import _SPLIT_K_MAP
     from torchao.ops import quant_llm_linear
 
     act = input_tensor
diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py
index c44803f6d2..1d0d22c0d4 100644
--- a/torchao/dtypes/uintx/__init__.py
+++ b/torchao/dtypes/uintx/__init__.py
@@ -1 +1,41 @@
-from .uintx import UintxTensor, UintxLayout, UintxAQTTensorImpl, to_uintx, _DTYPE_TO_BIT_WIDTH
+from .block_sparse_layout import (
+    BlockSparseLayout,
+)
+from .marlin_sparse_layout import (
+    MarlinSparseLayout,
+)
+from .plain_layout import (
+    PlainAQTTensorImpl,
+)
+from .semi_sparse_layout import (
+    SemiSparseLayout,
+)
+from .tensor_core_tiled_layout import (
+    TensorCoreTiledLayout,
+)
+from .uint4_layout import (
+    UInt4Tensor,
+)
+from .uintx_layout import (
+    _BIT_WIDTH_TO_DTYPE,
+    _DTYPE_TO_BIT_WIDTH,
+    UintxAQTTensorImpl,
+    UintxLayout,
+    UintxTensor,
+    to_uintx,
+)
+
+__all__ = [
+    "UintxTensor",
+    "UintxLayout",
+    "UintxAQTTensorImpl",
+    "to_uintx",
+    "_DTYPE_TO_BIT_WIDTH",
+    "_BIT_WIDTH_TO_DTYPE",
+    "UInt4Tensor",
+    "PlainAQTTensorImpl",
+    "BlockSparseLayout",
+    "MarlinSparseLayout",
+    "SemiSparseLayout",
+    "TensorCoreTiledLayout",
+]
diff --git a/torchao/dtypes/uintx/block_sparse_layout.py b/torchao/dtypes/uintx/block_sparse_layout.py
index 8355149cf1..0670986b13 100644
--- a/torchao/dtypes/uintx/block_sparse_layout.py
+++ b/torchao/dtypes/uintx/block_sparse_layout.py
@@ -1,21 +1,24 @@
-from dataclasses import dataclass
 import logging
+from dataclasses import dataclass
 from typing import Optional, Tuple
 
 import torch
 from torch.utils._python_dispatch import (
     return_and_correct_aliasing,
 )
-from torchao.dtypes.utils import (
-    Layout,
-    PlainLayout,
-)
+
 from torchao.dtypes.affine_quantized_tensor import (
     AffineQuantizedTensor,
     register_layout,
 )
-from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
-from torchao.dtypes.uintx.plain_layout import _aqt_is_int8_reduced_range
+from torchao.dtypes.uintx.plain_layout import (
+    PlainAQTTensorImpl,
+    _aqt_is_int8_reduced_range,
+)
+from torchao.dtypes.utils import (
+    Layout,
+    PlainLayout,
+)
 
 logger = logging.getLogger(__name__)
diff --git a/torchao/dtypes/uintx/marlin_sparse_layout.py b/torchao/dtypes/uintx/marlin_sparse_layout.py
index cac2c70f5c..e37623182a 100644
--- a/torchao/dtypes/uintx/marlin_sparse_layout.py
+++ b/torchao/dtypes/uintx/marlin_sparse_layout.py
@@ -1,14 +1,16 @@
 from dataclasses import dataclass
-from torchao.dtypes.utils import Layout, AQTTensorImpl
+
+import torch
+from torch.utils._python_dispatch import (
+    return_and_correct_aliasing,
+)
+
 from torchao.dtypes.affine_quantized_tensor import (
     AffineQuantizedTensor,
     register_layout,
 )
-import torch
 from torchao.dtypes.uintx.tensor_core_tiled_layout import _aqt_is_tensor_core_tile_uint4
-from torch.utils._python_dispatch import (
-    return_and_correct_aliasing,
-)
+from torchao.dtypes.utils import AQTTensorImpl, Layout
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
 )
@@ -28,8 +30,8 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor,
 
 
 def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
-    from torchao.sparsity.marlin import marlin_24_workspace
     from torchao.ops import marlin_24_gemm
+    from torchao.sparsity.marlin import marlin_24_workspace
 
     assert isinstance(weight_tensor, AffineQuantizedTensor)
@@ -216,8 +218,8 @@ def from_plain(
         _layout: Layout,
     ):
         from torchao.sparsity.marlin import (
-            pack_to_marlin_24,
             const,
+            pack_to_marlin_24,
         )  # avoid circular import
 
         assert isinstance(_layout, MarlinSparseLayout)
diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py
index 9ce448edaa..cad0aafcc4 100644
--- a/torchao/dtypes/uintx/plain_layout.py
+++ b/torchao/dtypes/uintx/plain_layout.py
@@ -1,19 +1,21 @@
+from typing import Optional, Tuple
+
 import torch
-from torchao.dtypes.utils import PlainLayout, AQTTensorImpl, Layout
-from torchao.dtypes.affine_quantized_tensor import (
-    AffineQuantizedTensor,
-    register_layout,
-)
-from torchao.utils import fill_defaults
 from torch.utils._python_dispatch import (
     is_traceable_wrapper_subclass,
     return_and_correct_aliasing,
 )
+
+from torchao.dtypes.affine_quantized_tensor import (
+    AffineQuantizedTensor,
+    register_layout,
+)
+from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
 from torchao.quantization.quant_primitives import (
-    int_scaled_matmul,
     ZeroPointDomain,
+    int_scaled_matmul,
 )
-from typing import Optional, Tuple
+from torchao.utils import fill_defaults
 
 aten = torch.ops.aten
diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py
index 1ac66d4fb2..e2c94a7a38 100644
--- a/torchao/dtypes/uintx/semi_sparse_layout.py
+++ b/torchao/dtypes/uintx/semi_sparse_layout.py
@@ -1,16 +1,20 @@
 from dataclasses import dataclass
-from torchao.dtypes.utils import Layout, PlainLayout
+from typing import Optional
+
+import torch
+from torch.utils._python_dispatch import (
+    return_and_correct_aliasing,
+)
+
 from torchao.dtypes.affine_quantized_tensor import (
     AffineQuantizedTensor,
     register_layout,
 )
-import torch
-from typing import Optional
-from torchao.dtypes.uintx.plain_layout import _aqt_is_int8_reduced_range
-from torch.utils._python_dispatch import (
-    return_and_correct_aliasing,
+from torchao.dtypes.uintx.plain_layout import (
+    PlainAQTTensorImpl,
+    _aqt_is_int8_reduced_range,
 )
-from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
+from torchao.dtypes.utils import Layout, PlainLayout
 
 aten = torch.ops.aten
diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py
index f6dfb9a4d2..ced3fc8922 100644
--- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py
+++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py
@@ -1,17 +1,19 @@
-import torch
-from torchao.utils import find_multiple, TORCH_VERSION_AT_LEAST_2_5, fill_defaults
-from torchao.dtypes.utils import Layout, AQTTensorImpl, is_device
-from torchao.dtypes.affine_quantized_tensor import (
-    AffineQuantizedTensor,
-    register_layout,
-)
 from dataclasses import dataclass
 from typing import Optional, Tuple
+
+import torch
 from torch.utils._python_dispatch import (
-    return_and_correct_aliasing,
     is_traceable_wrapper_subclass,
+    return_and_correct_aliasing,
+)
+
+from torchao.dtypes.affine_quantized_tensor import (
+    AffineQuantizedTensor,
+    register_layout,
 )
+from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device
 from torchao.quantization.quant_primitives import ZeroPointDomain, _get_reduction_params
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, fill_defaults, find_multiple
 
 aten = torch.ops.aten
diff --git a/torchao/dtypes/uintx/uint4.py b/torchao/dtypes/uintx/uint4_layout.py
similarity index 100%
rename from torchao/dtypes/uintx/uint4.py
rename to torchao/dtypes/uintx/uint4_layout.py
diff --git a/torchao/dtypes/uintx/uintx_layout.py b/torchao/dtypes/uintx/uintx_layout.py
index 11bf2f88c9..29c2ae93fe 100644
--- a/torchao/dtypes/uintx/uintx_layout.py
+++ b/torchao/dtypes/uintx/uintx_layout.py
@@ -3,14 +3,15 @@
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from .bitpacking import pack, unpack
+
 from torchao.dtypes.affine_quantized_tensor import register_layout
 from torchao.dtypes.uintx.plain_layout import PlainAQTTensorImpl
 from torchao.dtypes.utils import (
     Layout,
 )
-from torchao.utils import TorchAOBaseTensor
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_3
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TorchAOBaseTensor
+
+from .bitpacking import pack, unpack
 
 aten = torch.ops.aten
diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py
index 38976176d8..774071f856 100644
--- a/torchao/dtypes/utils.py
+++ b/torchao/dtypes/utils.py
@@ -1,8 +1,10 @@
 from dataclasses import dataclass
-from torchao.utils import TorchAOBaseTensor
-import torch
 from typing import Tuple, Union
 
+import torch
+
+from torchao.utils import TorchAOBaseTensor
+
 """
 Base class for different layout, following the same design of PyTorch layout
 https://pytorch.org/docs/stable/tensor_attributes.html#torch-layout, used to represent different
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 476cc229f3..270268933b 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -36,7 +36,7 @@
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
 )
-from torchao.dtypes.uintx.uintx import UintxLayout
+from torchao.dtypes.uintx.uintx_layout import UintxLayout
 from torchao.float8.inference import Float8MMConfig
 from torchao.quantization.linear_activation_weight_observer import (
     LinearActivationWeightObservedTensor,

From 9c531e96afeed549ec04537b737a10aea91cf46d Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Wed, 13 Nov 2024 15:57:30 -0800
Subject: [PATCH 09/11] Rebase and lint fixes

---
 test/hqq/test_hqq_affine.py                   |  5 -----
 torchao/dtypes/affine_quantized_tensor.py     |  2 ++
 torchao/dtypes/affine_quantized_tensor_ops.py | 10 +++++-----
 torchao/dtypes/uintx/plain_layout.py          |  4 +++-
 4 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py
index 7eda0ab5de..39204d97f0 100644
--- a/test/hqq/test_hqq_affine.py
+++ b/test/hqq/test_hqq_affine.py
@@ -1,12 +1,7 @@
 import unittest
 import torch
 from torchao.dtypes.affine_quantized_tensor import (
-    to_affine_quantized_intx,
     ZeroPointDomain,
-    PlainAQTTensorImpl,
-    PlainLayout,
-    TensorCoreTiledAQTTensorImpl,
-    TensorCoreTiledLayout,
     MappingType,
 )
 
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index d923df18de..4bbb87ecee 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -140,6 +140,8 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
             self.zero_point_domain,
             output_dtype=output_dtype,
         )
+        from torchao.dtypes.uintx import TensorCoreTiledLayout
+
         if isinstance(self._layout, TensorCoreTiledLayout):
             # need to return to original shape if tensor was padded
             # in preprocessing
diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py
index 514b909306..fe7644d922 100644
--- a/torchao/dtypes/affine_quantized_tensor_ops.py
+++ b/torchao/dtypes/affine_quantized_tensor_ops.py
@@ -101,11 +101,11 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
 AffineQuantizedTensor._quantized_linear_op = _quantized_linear_op
 
 
-# # following are a list of (dispatch_condition, implementation) functions that takes the following args:
-# # input_tensor: dimension is (M1, M2, ..., in_features)
-# # weight_tensor: dimension is (out_features, in_features)
-# # bias: dimension is (out_features,)
-# # so that these can be shared by F.linear, aten.mm, aten.addmm dispatches
+# _register_aqt_quantized_linear_dispatches function has a list of (dispatch_condition, implementation) functions, defined in their dtype layout classes, that takes the following args:
+# input_tensor: dimension is (M1, M2, ..., in_features)
+# weight_tensor: dimension is (out_features, in_features)
+# bias: dimension is (out_features,)
+# so that these can be shared by F.linear, aten.mm, aten.addmm dispatches
 def _register_aqt_quantized_linear_dispatches():
     for dispatch_condition, impl in [
         (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
diff --git a/torchao/dtypes/uintx/plain_layout.py b/torchao/dtypes/uintx/plain_layout.py
index cad0aafcc4..ed171634cd 100644
--- a/torchao/dtypes/uintx/plain_layout.py
+++ b/torchao/dtypes/uintx/plain_layout.py
@@ -11,9 +11,11 @@
     register_layout,
 )
 from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
+from torchao.kernel import (
+    int_scaled_matmul,
+)
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
-    int_scaled_matmul,
 )
 from torchao.utils import fill_defaults

From ab8969cfbe7850b894ee3a2304a6065c21a26b71 Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Thu, 14 Nov 2024 14:11:32 -0800
Subject: [PATCH 10/11] Fixed some imports

---
 test/hqq/test_hqq_affine.py                    |  2 +-
 torchao/dtypes/affine_quantized_tensor.py      | 12 ++++++++++++
 torchao/dtypes/affine_quantized_tensor_ops.py  |  4 +---
 .../_linear_8bit_act_xbit_weight_layout.py     |  4 ++--
 torchao/prototype/hqq/example.py               | 11 +++++------
 torchao/prototype/sparsity/superblock/utils.py |  4 ++--
 6 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py
index 39204d97f0..a710075183 100644
--- a/test/hqq/test_hqq_affine.py
+++ b/test/hqq/test_hqq_affine.py
@@ -1,6 +1,6 @@
 import unittest
 import torch
-from torchao.dtypes.affine_quantized_tensor import (
+from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
     MappingType,
 )
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 6d39aaf4fe..93d2766d1e 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -31,6 +31,18 @@
 logger = logging.getLogger(__name__)
 aten = torch.ops.aten
 
+__all__ = [
+    "AffineQuantizedTensor",
+    "MarlinQQQTensor",
+    "register_layout",
+    "to_affine_quantized_intx",
+    "to_affine_quantized_floatx",
+    "to_affine_quantized_intx_static",
+    "to_affine_quantized_floatx_static",
+    "to_affine_quantized_fpx",
+    "to_marlinqqq_quantized_intx",
+]
+
 
 ##############################
 # Tensor Subclass Definition #
diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py
index c4c1e0ca37..bd7ff7d333 100644
--- a/torchao/dtypes/affine_quantized_tensor_ops.py
+++ b/torchao/dtypes/affine_quantized_tensor_ops.py
@@ -5,10 +5,7 @@
 
 from torchao.dtypes.affine_quantized_tensor import (
     AffineQuantizedTensor,
-    dequantize_affine,
 )
-
-# from torchao.dtypes.affine_quantized_tensor import AffineQuantizedTensor
 from torchao.dtypes.floatx.float8_layout import (
     _linear_fp8_act_fp8_weight_check,
     _linear_fp8_act_fp8_weight_impl,
@@ -46,6 +43,7 @@
     _linear_bf16_act_uint4_weight_check,
     _linear_bf16_act_uint4_weight_impl,
 )
+from torchao.quantization.quant_primitives import dequantize_affine
 from torchao.utils import (
     fill_defaults,
 )
diff --git a/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py b/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py
index 9b9b53d5aa..0e6c73343f 100644
--- a/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py
+++ b/torchao/experimental/_linear_8bit_act_xbit_weight_layout.py
@@ -12,10 +12,10 @@
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from torchao.dtypes.affine_quantized_tensor import (
-    AQTTensorImpl,
-    register_aqt_quantized_linear_dispatch,
     register_layout,
 )
+from torchao.dtypes.utils import AQTTensorImpl
+from torchao.dtypes.affine_quantized_tensor_ops import register_aqt_quantized_linear_dispatch
 from torchao.dtypes.utils import Layout
 from torchao.quantization.quant_primitives import (
     MappingType,
diff --git a/torchao/prototype/hqq/example.py b/torchao/prototype/hqq/example.py
index 07d5dea205..0b562b05e9 100644
--- a/torchao/prototype/hqq/example.py
+++ b/torchao/prototype/hqq/example.py
@@ -2,13 +2,12 @@
 from torchao.prototype.hqq.core import HQQQuantizer
 from torchao.dtypes.affine_quantized_tensor import (
     to_affine_quantized_intx,
-    ZeroPointDomain,
-    PlainAQTTensorImpl,
-    PlainLayout,
-    TensorCoreTiledAQTTensorImpl,
-    TensorCoreTiledLayout,
-    MappingType,
 )
+from torchao.quantization.quant_primitives import (
+    ZeroPointDomain,
+    MappingType,
+)
+from torchao.dtypes import TensorCoreTiledLayout, PlainLayout
 
 #Parameters
 device, compute_dtype = "cuda:0", torch.bfloat16
diff --git a/torchao/prototype/sparsity/superblock/utils.py b/torchao/prototype/sparsity/superblock/utils.py
index 9ed38e50d3..e2b546db24 100644
--- a/torchao/prototype/sparsity/superblock/utils.py
+++ b/torchao/prototype/sparsity/superblock/utils.py
@@ -387,7 +387,7 @@ def accelerate_with_sparsity(model, args):
     if args.sparsity == "bsr":
         apply_sparsity(model)
         if args.quantization:
-            from torchao.dtypes.affine_quantized_tensor import BlockSparseLayout
+            from torchao.dtypes import BlockSparseLayout
 
             quantize_(
                 model,
@@ -401,7 +401,7 @@ def accelerate_with_sparsity(model, args):
         sparsify_(model, block_sparse_weight(blocksize=args.bsr), superblock_only)
     elif args.sparsity == "semi_structured":
         if args.quantization:
-            from torchao.dtypes.affine_quantized_tensor import SemiSparseLayout
+            from torchao.dtypes import SemiSparseLayout
 
             quantize_(
                 model,

From d992432438ece50c0099653f8424c9d5e7a0555e Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Thu, 14 Nov 2024 14:58:53 -0800
Subject: [PATCH 11/11] Minor fixes

---
 test/hqq/test_hqq_affine.py      | 2 +-
 torchao/dtypes/__init__.py       | 4 ----
 torchao/prototype/hqq/example.py | 2 +-
 3 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py
index a710075183..2f231fbb31 100644
--- a/test/hqq/test_hqq_affine.py
+++ b/test/hqq/test_hqq_affine.py
@@ -1,6 +1,6 @@
 import unittest
 import torch
-from torchao.quantization.quant_primitives import (
+from torchao.quantization import (
     ZeroPointDomain,
     MappingType,
 )
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index be1708be98..d1fbacdcb4 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -1,6 +1,4 @@
 from . import affine_quantized_tensor_ops
-
-# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
     MarlinQQQTensor,
@@ -29,8 +27,6 @@
     PlainLayout,
 )
 
-# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
-
 __all__ = [
     "NF4Tensor",
     "to_nf4",
diff --git a/torchao/prototype/hqq/example.py b/torchao/prototype/hqq/example.py
index 0b562b05e9..eb12b2b45e 100644
--- a/torchao/prototype/hqq/example.py
+++ b/torchao/prototype/hqq/example.py
@@ -3,7 +3,7 @@
 from torchao.dtypes.affine_quantized_tensor import (
     to_affine_quantized_intx,
 )
-from torchao.quantization.quant_primitives import (
+from torchao.quantization import (
     ZeroPointDomain,
     MappingType,
 )