diff --git a/docs/source/contributor_guide.rst b/docs/source/contributor_guide.rst index a69c410e6c..0b57279c8c 100644 --- a/docs/source/contributor_guide.rst +++ b/docs/source/contributor_guide.rst @@ -125,7 +125,7 @@ On the top of the stack will be the final quantization algorithms and quantizati For demonstration purposes, let's say after previous step we have ``AffineQuantizedTensor`` and ``to_affine_quantized`` factory function defined. For simplicity, let's say ``to_affine_quantized`` takes a high precision floating point Tensor and a target_dtype (e.g. torch.int8) and converts it to an ``AffineQuantizedTensor`` with corresponding dtype. -Note: below are all for explaining the concepts, more detailed introduction for utils and examples we provide can be found in ``Tensor Subclass Developer Guide`` section. +Note: below are all for explaining the concepts, more detailed introduction for utils and examples we provide can be found in the `Writing Your Own Tensor Subclass`__ tutorial. Weight Only Quantization ######################## @@ -257,220 +257,6 @@ During Save/Load Since ``AffineQuantizedTensor`` weight is still a ``torch.Tensor``, save/load works the same way as the original high precision floating point model. See the `serialization doc `__ for more details. -Tensor Subclass Developer Guide -=============================== - -We have covered high level overview and how everything is connected together in the previous section, this section will focus on Tensor Subclasses, which is the main extension point we rely on to provide flexibility of supporting inference, training and fine tuning with low precision Tensors and composability with torch.compile, autograd, distributed primitives in these scenarios. - -Prerequisites -~~~~~~~~~~~~~ -Some externally available resources for tensor subclasses: - -* `tensor subclass doc `__ -* `Edward's podcast about tensor subclasses `__ -* `Tensor subclass zoo `__ - -Why Tensor Subclass? -~~~~~~~~~~~~~~~~~~~~ -There are multiple ways people can implement quantization techniques or new dtypes, main motivation for us to recommend the tensor subclass based approach are three things: -(1). It’s natural for quantization to be modeled as a dtype conversion, so implementing it with tensor subclass means we are not introducing new concepts but reusing existing concepts like dtype, layout that already exists in pytorch core -(2). Since tensor subclass intercepts computation at torch function or aten ops level, as long as the same function/operator is used, we will be able to quantize the model. This allows the model that’s using variants of native modules (e.g. a slightly modified version of nn.Linear) to still be compatible with quantization -(3). Tensor subclass is also the approach adopted by other techniques like sparsity and distributed, so implementing quantization or dtype conversion with tensor subclass would make it easier for it to be composable with these techniques - -Example Code for a new DType -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Please feel free to start with `tutorial `__ for a end to end working example that combines everything we talked about together and come back to the doc for clarifications and documentations. - -Basic Structure -~~~~~~~~~~~~~~~ -A tensor subclass needs to define a few basic methods: ``__new__``, ``__init__``, ``__tensor_flatten__``, ``__tensor_unflatten__`` -and also dispatch functions for torch functions ``__torch_function__`` and aten ops ``__torch_dispatch__``. 
- -Here is an example of basic structure:: - # check out docs in https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L437 - from torchao.utils import TorchAOBaseTensor - - class MyDTypeLayout(TorchAOBaseTensor): - # see tutorial code for details - pass - - class MyDtypeTensor(TorchAOBaseTensor): - """We need to define `__new__` for constructing a new tensor subclass instance and `__init__` for initialize - the instance. There is no requirement on what the argument list should look like here, only requirement is - that `__new__` must return a Tensor instance with `torch.Tensor._make_wrapper_subclass(cls, shape, ...)` call - """ - @staticmethod - def __new__( - cls, - tensor_impl: MyDTypeLayout, - shape: torch.Size, - dtype: Optional[torch.dtype] = None, - ): - ... - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - tensor_impl: MyDTypeLayout, - shape: torch.Size, ... - ): - self.tensor_impl = tensor_impl - - - """`__tensor_flatten__` and `__tensor_unflatten__` are used to desugar the tensor into native Tensors/attributes and - reconstruct the tensor subclass instance from the desugared tensor and attributes, these are required to define - a Tensor subclass for torch.compile support - """ - def __tensor_flatten__(self): - return ["tensor_impl"], [self.shape] - - """see https://github.com/pytorch/pytorch/blob/3bc2004f9123a32f381ef64202252d59109507f3/torch/utils/_python_dispatch.py#L289 for documentations for outer_size and outer_stride - """ - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - tensor_impl = tensor_data_dict["tensor_impl"] - shape, = tensor_attributes - return cls( - tensor_impl, - shape if outer_size is None else outer_size, - ) - - - """classmethod that converts from a floating point Tensor (fp32/fp16/bf16) to the current dtype - """ - @classmethod - def from_float( - cls, - input_float: torch.Tensor, - ): - mapping_type = MappingType.SYMMETRIC - block_size = input_float.shape - dtype = torch.int16 - scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype) - int_data = (input_float / scale).to(torch.int8) - tensor_impl = MyDTypeLayout.from_plain(int_data, scale) - return cls(tensor_impl, input_float.shape) - - - """[Optional] see docs for `Layout/Packing` under `Quantized Tensors` section to understand what layout_type is - """ - @property - def _layout(self) -> LayoutType: - return self.tensor_impl._layout - - """There are two entry points that we can modify the behavior of a pytorch op: torch_function and torch_dispatch: - - __torch_function__: will be called whenever a torch level function is called on the Tensor object, for example: torch.nn.functional.linear, - tensor.detach, tensor.reshape, tensor.t etc. - - __torch_dispatch__: will be called in the C++ dispatcher, when an aten operator is called on the Tensor object, for example: - aten.mm, aten.addmm, aten.detach.default, aten.t.default etc. - you can checkout https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L361-L389 to understand what `__torch_function__` and `__torch_dispatch__` are doing, but with `TorchAoBaseTensor` user can use - some helper functions directly (see next section) - -Operator Support -~~~~~~~~~~~~~~~~ -There are two types of operator support, torch function and aten ops. For torch functions (e.g. 
``torch.nn.functional.linear``), we’ll need to overwrite ``__torch_function__`` callback in the Tensor subclass, for aten ops (e.g. ``torch.ops.aten.mm``), we’ll need to overwrite ``__torch_dispatch__`` callback function. - -For a new dtype, we’d like people to define the following decorator:: - if your dtype class is inherited from `torchao.utils.TorchAoBaseTensor`, you can do: - - implements = my_dtype_tensor_cls.implements - -And we can implement the operator dispatch with the following:: - # Example for torch_function dispatch for torch.nn.functional.linear - def _quantized_linear_op(input_tensor, weight_tensor, bias): - if isinstance(input_tensor, MyDtypeTensor): - input_tensor = input_tensor.dequantize() - if isinstance(weight_tensor, MyDtypeTensor): - weight_tensor = weight_tensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - - - @implements(torch.nn.functional.linear) - def _(*args, **kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - # using try/except here so that we can have a general fallback when input_tensor/weight_tensor - # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to - # make the branches easier to understand in `_quantized_linear_op` - try: - return _quantized_linear_op(input_tensor, weight_tensor, bias) - except NotImplementedError: - if isinstance(input_tensor, MyDtypeTensor): - input_tensor = input_tensor.dequantize() - if isinstance(weight_tensor, MyDtypeTensor): - weight_tensor = weight_tensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - - # Example for aten op dispatch for aten.detach.default - @implements(aten.detach.default) - def _(func, *args, **kwargs): - # `return_and_correct_aliasing` should be used by wrapper tensor ``__torch_dispatch__`` subclasses that would like to - # work with torch.compile. It ensures that the subclass properly implements the aliasing behavior of every op, - # which is needed for correctness in AOTAutograd. - - # `_apply_fn_to_data` just applies the function to the tensor data in `args[0]`, `args[0]` is a tensor subclass - # of `my_dtype` - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - -What ops do we need to overwrite? 
This depends on the model we are trying to quantize, commonly overwritten ops are: -``__torch_function__``: ``torch.nn.functional.linear`` -``__torch_dispatch__``: ``torch.ops.aten.addmm.default``, ``torch.ops.aten.mm.default``, ``torch.ops.aten.detach.default``, ``torch.ops.aten.t.default`` - -You can also find the ops that can be overwritten in ``__torch_function__`` or ``__torch_dispatch__`` with the following code, and you can start with a model that you want to optimize, start with just overwriting the important ops like linear, and gradually expand the coverage until the test runs and you get the expected optimized generated code (see Optimized Operators section for more details):: - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = torch.nn.Linear(10, 10) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) + x - - from torch.overrides import TorchFunctionMode - class TorchFunctionLoggingMode(TorchFunctionMode): - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - print(f"TORCH_FUNC={str(func)}") - return func(*args, **kwargs) - - with TorchFunctionLoggingMode(): - m(*example_inputs) - - ## Example output - # TORCH_FUNC= - # TORCH_FUNC= - - - from torch.utils._python_dispatch import TorchDispatchMode - class TorchDispatchLoggingMode(TorchDispatchMode): - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - print(f"ATEN_FUNC={str(func)}") - return func(*args, **kwargs) - - with TorchDispatchLoggingMode(): - m(*example_inputs) - - ## Example output - # ATEN_FUNC=aten.t.default - # ATEN_FUNC=aten.addmm.default - # ATEN_FUNC=aten.add.Tensor - - # or a more polished logging for torch_dispatch (aten) ops: https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py - -Alternatively, you can run a test example (e.g. use your quantized model with tensor parallelism, FSDP etc.) and discover the missing ops and add them until the test passes. - -We are still working on a table that talks about for each feature what are the operators that need to be supported. - Adding Efficient Kernels ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/index.rst b/docs/source/index.rst index 3bbcd203fd..938d29efc6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -38,3 +38,5 @@ Welcome to the torchao Documentation :caption: Tutorials serialization + subclass_basic + subclass_advanced diff --git a/docs/source/sg_execution_times.rst b/docs/source/sg_execution_times.rst new file mode 100644 index 0000000000..dbcbc46d90 --- /dev/null +++ b/docs/source/sg_execution_times.rst @@ -0,0 +1,37 @@ + +:orphan: + +.. _sphx_glr_sg_execution_times: + + +Computation times +================= +**00:00.000** total execution time for 1 file **from all galleries**: + +.. container:: + + .. raw:: html + + + + + + + + .. list-table:: + :header-rows: 1 + :class: table table-striped sg-datatable + + * - Example + - Time + - Mem (MB) + * - :ref:`sphx_glr_tutorials_template_tutorial.py` (``tutorials_source/template_tutorial.py``) + - 00:00.000 + - 0.0 diff --git a/docs/source/subclass_advanced.rst b/docs/source/subclass_advanced.rst new file mode 100644 index 0000000000..f2df5a1cf0 --- /dev/null +++ b/docs/source/subclass_advanced.rst @@ -0,0 +1,4 @@ +Writing Your Own Quantized Tensor (advanced) +-------------------------------------------- + +Coming soon! 
diff --git a/docs/source/subclass_basic.rst b/docs/source/subclass_basic.rst new file mode 100644 index 0000000000..a3f10a2a9a --- /dev/null +++ b/docs/source/subclass_basic.rst @@ -0,0 +1,189 @@ +Writing Your Own Quantized Tensor +--------------------------------- + +Quantization in torchao is built on the foundation of tensor subclasses. +They are the main extension point for torchao to provide flexible +inference and training support using low precision computation, while +composing with important PyTorch features such as torch.compile, +autograd, and distributed primitives. + +In this tutorial, we will highlight the benefits of leveraging tensor +subclasses compared to module swaps, and walk through a simple example +of how to express quantization using this approach. + +What are Tensor Subclasses? +=========================== + +Tensor subclasses are simply classes that inherit from `torch.Tensor `__. +They allow users to interpose their custom computation logic between existing +ops in their models, such that functions in the top-level torch +namespace like torch.add will continue to work seamlessly. + +An obvious alternative to the tensor subclass approach is module swaps: +simply swap all nn.Linear modules in your model with your custom +Int8QuantizedLinear modules, for example. There are a few important +benefits of using tensor subclasses compared to this approach: + +1. **Finer-grained integration point.** Module swaps intercept + computation at the module level and so will not work for models that + rely on torch functions or variants of native modules (e.g. slightly + modified versions of nn.Linear). In contrast, since tensor subclasses + intercept computation at the function/op level, we will be able to + quantize the model as long as the same function/op is used. + +2. **Better composability.** Composing multiple features using module + swaps is clunky. For example, combining two existing + Int8QuantizedLinear and DistributedLinear modules would require users + to create another linear class that duplicates these functionalities. + Tensor subclasses bypass this problem by simply wrapping one subclass + in another. This can also offer performance benefits if the outer + tensor (e.g. `DTensor `__) + is aware that the inner tensor is quantized, and so can perform + expensive allgather operations using less network and memory + bandwidth. + +3. **Reusing PyTorch components.** It is natural to express quantization + using tensor subclasses since the quantized tensors are simply + torch.Tensors with different dtypes. The model structure does not + change (nn.Linears stay as nn.Linears), and so subsequent + optimization passes can also stay exactly the same as before. + +| +In the rest of the tutorial, we will walk through an example of how to +express quantization using both approaches. For further reading on +tensor subclasses, please refer to: + +- `Tensor subclass documentation `__ +- `Tensor subclass zoo `__ +- `Tensor subclass podcast by Edward Yang `__ + +Quantization with Module Swaps +============================== + +We begin with a simple example of how to implement int8 symmetric weight +only quantization using module swaps. We will use the following function +for quantizing float32 tensors into int8 tensors: + +.. code:: py + + from typing import Tuple + import torch + + def int8_symmetric_quantize( + fp32_tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Symmetrically quantize the torch.float32 tensor into torch.int8. 
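+        The per-row scale is chosen as max(abs(row)) / 127.5, so the
+        original values can be approximately recovered as quantized value * scale.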
+        Return a 2-tuple of (quantized value, scale).
+        """
+        quant_min = -128
+        quant_max = 127
+        min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False)
+        max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False)
+        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+        max_val_pos = torch.max(-min_val_neg, max_val_pos)
+        scale = max_val_pos / (float(quant_max - quant_min) / 2)
+        scale = scale.view(fp32_tensor.shape[0], -1)
+        out = torch.round(fp32_tensor * (1.0 / scale))
+        out = torch.clamp(out, quant_min, quant_max).to(torch.int8)
+        return out, scale
+
+Next, we will create a new QuantizedLinear module that calls this
+function to dynamically quantize the weights:
+
+.. code:: py
+
+    class QuantizedLinear(torch.nn.Linear):
+        """
+        Linear module that performs dynamic and symmetric weight-only
+        int8 quantization.
+        """
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            w_int8, scale = int8_symmetric_quantize(self.weight)
+            return torch.mm(x, w_int8.t().to(x.dtype)) * scale.t()
+
+        @classmethod
+        def from_float(cls, mod: torch.nn.Linear):
+            # nn.Linear expects a bool `bias` flag; for simplicity, this
+            # example does not handle the bias term in the quantized forward
+            new_linear = cls(mod.in_features, mod.out_features, mod.bias is not None)
+            new_linear.weight = mod.weight
+            return new_linear
+
+Then, the only thing that’s left is to swap all nn.Linear modules in the
+model with our new QuantizedLinear. Let’s use the following toy model
+for demonstration purposes:
+
+.. code:: py
+
+    import copy
+
+    class ToyModel(torch.nn.Module):
+        def __init__(self, m: int, n: int, k: int):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(m, n, bias=False)
+            self.linear2 = torch.nn.Linear(n, k, bias=False)
+
+        def forward(self, x):
+            x = self.linear1(x)
+            x = self.linear2(x)
+            return x
+
+    float_model = ToyModel(64, 128, 32).cuda()
+    quantized_model = copy.deepcopy(float_model)
+
+    # Swap torch.nn.Linear with QuantizedLinear
+    for name, child in quantized_model.named_children():
+        if type(child) == torch.nn.Linear:
+            new_linear = QuantizedLinear.from_float(child)
+            setattr(quantized_model, name, new_linear)
+
+Verify that the quantized model now uses our QuantizedLinear modules.
+The model is now ready to use!
+
+.. code:: py
+
+    >>> print(float_model)
+    ToyModel(
+      (linear1): Linear(in_features=64, out_features=128, bias=False)
+      (linear2): Linear(in_features=128, out_features=32, bias=False)
+    )
+
+    >>> print(quantized_model)
+    ToyModel(
+      (linear1): QuantizedLinear(in_features=64, out_features=128, bias=False)
+      (linear2): QuantizedLinear(in_features=128, out_features=32, bias=False)
+    )
+
+An important drawback of this simple approach is its lack of
+flexibility. Currently this only works for native PyTorch modules, but
+what if the model has slightly modified linear modules that, for
+example, shard their weights for distributed training? It also won’t
+work with models that call the functional version of linear
+(torch.nn.functional.linear) directly.
+
+Further, suppose we want to compose this feature with distributed
+training, which is often also implemented through module swaps. There
+is no clean way to do this except to create yet another module that
+combines both features. These limitations can be solved with tensor
+subclasses, which offer a more elegant way to interpose custom
+computation such as quantization in your model.
+
+Quantization with Tensor Subclasses
+===================================
+
+Coming soon.
+
+Common Pitfalls
+===============
+
+Coming soon.
+
+Next Steps
+==========
+
+In this tutorial, we demonstrated how to build a simple quantized tensor
+subclass. This is part one of two tutorials in this series. The
+`next post `__ will discuss how to add more advanced
+features to your tensor subclass, such as making it trainable, composing
+with DTensors, and adding tensor parallelism support. For a more detailed
+example of how ``AffineQuantizedTensor`` in torchao was built using tensor
+subclasses, also check out `this example `__.
diff --git a/tutorials/examples/quantized_module_swap.py b/tutorials/examples/quantized_module_swap.py
new file mode 100644
index 0000000000..8d45c05a5c
--- /dev/null
+++ b/tutorials/examples/quantized_module_swap.py
@@ -0,0 +1,67 @@
+from typing import Tuple
+import torch
+
+
+def int8_symmetric_quantize(
+    fp32_tensor: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Symmetrically quantize the torch.float32 tensor into torch.int8.
+    Return a 2-tuple of (quantized value, scale).
+    """
+    quant_min = -128
+    quant_max = 127
+    min_val = torch.amin(fp32_tensor, dim=[1], keepdim=False)
+    max_val = torch.amax(fp32_tensor, dim=[1], keepdim=False)
+    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    max_val_pos = torch.max(-min_val_neg, max_val_pos)
+    scale = max_val_pos / (float(quant_max - quant_min) / 2)
+    scale = scale.view(fp32_tensor.shape[0], -1)
+    out = torch.round(fp32_tensor * (1.0 / scale))
+    out = torch.clamp(out, quant_min, quant_max).to(torch.int8)
+    return out, scale
+
+
+class QuantizedLinear(torch.nn.Linear):
+    """
+    Linear module that performs dynamic and symmetric weight-only
+    int8 quantization.
+    """
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        w_int8, scale = int8_symmetric_quantize(self.weight)
+        return torch.mm(x, w_int8.t().to(x.dtype)) * scale.t()
+
+    @classmethod
+    def from_float(cls, mod: torch.nn.Linear):
+        # nn.Linear expects a bool `bias` flag; for simplicity, this example
+        # does not handle the bias term in the quantized forward
+        new_linear = cls(mod.in_features, mod.out_features, mod.bias is not None)
+        new_linear.weight = mod.weight
+        return new_linear
+
+
+class ToyModel(torch.nn.Module):
+    def __init__(self, m: int, n: int, k: int):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(m, n, bias=False)
+        self.linear2 = torch.nn.Linear(n, k, bias=False)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+
+if __name__ == "__main__":
+
+    # Set up toy model
+    model = ToyModel(64, 128, 32).cuda()
+    example_inputs = torch.randn((1, 64), dtype=torch.float32, device="cuda")
+
+    # Swap torch.nn.Linear with QuantizedLinear
+    for name, child in model.named_children():
+        if type(child) == torch.nn.Linear:
+            new_linear = QuantizedLinear.from_float(child)
+            setattr(model, name, new_linear)
+
+    print("quantized model: ", model)
+    print("output: ", model(example_inputs))
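+
+    # A minimal sanity check (not part of the original example): compare the
+    # quantized output against a float32 reference computed from the same
+    # weights. QuantizedLinear keeps the original fp32 weight and quantizes it
+    # on the fly, so the reference below is what the unquantized model returns.
+    with torch.no_grad():
+        reference = example_inputs
+        for child in model.children():
+            reference = torch.nn.functional.linear(reference, child.weight)
+        max_error = (model(example_inputs) - reference).abs().max().item()
+    print("max abs error vs. float32 reference: ", max_error)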