Add module swap -> tensor subclass migration tutorial
Adds a migration tutorial from module swap to tensor subclass for
expressing basic quantization. This is a simplified version of
the existing subclass tutorials in torchao, removing layers of
indirection like Layout and TensorImpl for ease of understanding.
This commit also removes overlapping content from the existing
contributor guide.

Work was done with @bdhirsh.
andrewor14 committed Jan 24, 2025
1 parent 4e4f4df commit 39f2ece
Showing 8 changed files with 861 additions and 215 deletions.
216 changes: 1 addition & 215 deletions docs/source/contributor_guide.rst
@@ -125,7 +125,7 @@ On the top of the stack will be the final quantization algorithms and quantizati

For demonstration purposes, let's say that after the previous step we have ``AffineQuantizedTensor`` and a ``to_affine_quantized`` factory function defined. For simplicity, let's say ``to_affine_quantized`` takes a high precision floating point Tensor and a target_dtype (e.g. torch.int8) and converts it to an ``AffineQuantizedTensor`` with the corresponding dtype.

- Note: the following are all for explaining the concepts; a more detailed introduction to the utils and examples we provide can be found in the ``Tensor Subclass Developer Guide`` section.
+ Note: the following are all for explaining the concepts; a more detailed introduction to the utils and examples we provide can be found in the `Writing Your Own Tensor Subclass <subclass_basic.html>`__ tutorial.
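
As a concrete sketch of the factory in action (``to_affine_quantized`` and ``AffineQuantizedTensor`` here are the demonstration names from above, not a fixed torchao API)::

    import torch
    from torch import nn

    linear = nn.Linear(128, 256)

    # hypothetical factory from the demonstration above: converts the high
    # precision weight to an AffineQuantizedTensor with the target dtype
    quantized_weight = to_affine_quantized(linear.weight, target_dtype=torch.int8)

    # the result is still a torch.Tensor, so it can be swapped back into the
    # module as a parameter and the module's forward works unchanged
    linear.weight = nn.Parameter(quantized_weight, requires_grad=False)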

Weight Only Quantization
########################
@@ -257,220 +257,6 @@ During Save/Load

Since ``AffineQuantizedTensor`` weight is still a ``torch.Tensor``, save/load works the same way as the original high precision floating point model. See the `serialization doc <https://pytorch.org/ao/stable/serialization.html>`__ for more details.
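
For example, a minimal sketch (the file name is illustrative; passing ``assign=True`` assigns the loaded quantized tensor subclass instances directly instead of copying them into the existing float parameters)::

    import torch

    # save: quantized weights are ordinary entries in the state_dict,
    # assuming `model` is a model whose weights were quantized as above
    torch.save(model.state_dict(), "quantized_model.pt")

    # load: assign the quantized tensors directly onto the model
    state_dict = torch.load("quantized_model.pt")
    model.load_state_dict(state_dict, assign=True)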

Tensor Subclass Developer Guide
===============================

The previous sections covered a high level overview and how everything is connected together. This section focuses on tensor subclasses, the main extension point we rely on to flexibly support inference, training, and fine-tuning with low precision tensors, and to compose with torch.compile, autograd, and distributed primitives in these scenarios.

Prerequisites
~~~~~~~~~~~~~
Some externally available resources for tensor subclasses:

* `tensor subclass doc <https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor>`__
* `Edward's podcast about tensor subclasses <https://podcasts.apple.com/us/podcast/tensor-subclasses-and-pt2/id1566080008?i=1000646728968>`__
* `Tensor subclass zoo <https://github.com/albanD/subclass_zoo>`__

Why Tensor Subclass?
~~~~~~~~~~~~~~~~~~~~
There are multiple ways to implement quantization techniques or new dtypes. The main motivations for recommending the tensor subclass based approach are:

1. It's natural for quantization to be modeled as a dtype conversion, so implementing it with tensor subclasses means we are not introducing new concepts, only reusing ones like dtype and layout that already exist in PyTorch core.
2. Since tensor subclasses intercept computation at the torch function or aten op level, we can quantize the model as long as the same function/operator is used. This allows models that use variants of native modules (e.g. a slightly modified version of ``nn.Linear``) to remain compatible with quantization; see the sketch after this list.
3. Tensor subclasses are also the approach adopted by other techniques like sparsity and distributed, so implementing quantization or dtype conversion with tensor subclasses makes it easier to compose with those techniques.
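
For example, here is a minimal sketch of point (2): a custom module that is not ``nn.Linear`` but still calls ``torch.nn.functional.linear``, so swapping its weight for a quantized tensor subclass (defined in the sections below) would still intercept the computation::

    import torch
    import torch.nn.functional as F

    class MySlightlyModifiedLinear(torch.nn.Module):
        def __init__(self, in_features: int, out_features: int):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # if `self.weight` is replaced with a quantized tensor subclass,
            # this F.linear call is still intercepted via `__torch_function__`,
            # even though the module is not an nn.Linear
            return F.linear(x, self.weight) * 2.0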

Example Code for a new DType
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Please feel free to start with the `tutorial <https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/my_dtype_tensor_subclass.py>`__ for an end-to-end working example that combines everything discussed here, and come back to this doc for clarifications and documentation.

Basic Structure
~~~~~~~~~~~~~~~
A tensor subclass needs to define a few basic methods: ``__new__``, ``__init__``, ``__tensor_flatten__``, ``__tensor_unflatten__``,
and also the dispatch callbacks ``__torch_function__`` (for torch functions) and ``__torch_dispatch__`` (for aten ops).

Here is an example of the basic structure::

    from typing import Optional

    import torch

    from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
    # check out the docs in https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L437
    from torchao.utils import TorchAOBaseTensor

    class MyDTypeLayout(TorchAOBaseTensor):
        # see the tutorial code for details
        pass

    class MyDtypeTensor(TorchAOBaseTensor):
        """We need to define `__new__` for constructing a new tensor subclass instance and `__init__` for initializing
        the instance. There is no requirement on what the argument list should look like here; the only requirement is
        that `__new__` must return a Tensor instance with a `torch.Tensor._make_wrapper_subclass(cls, shape, ...)` call.
        """

        @staticmethod
        def __new__(
            cls,
            tensor_impl: MyDTypeLayout,
            shape: torch.Size,
            dtype: Optional[torch.dtype] = None,
        ):
            kwargs = {"dtype": dtype}  # collect tensor metadata kwargs (device, strides, etc.) here
            return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

        def __init__(
            self,
            tensor_impl: MyDTypeLayout,
            shape: torch.Size,
            dtype: Optional[torch.dtype] = None,
        ):
            self.tensor_impl = tensor_impl

        # `__tensor_flatten__` and `__tensor_unflatten__` are used to desugar the tensor into native Tensors/attributes
        # and reconstruct the tensor subclass instance from the desugared tensor and attributes; these are required to
        # define a tensor subclass for torch.compile support
        def __tensor_flatten__(self):
            return ["tensor_impl"], [self.shape]

        # see https://github.com/pytorch/pytorch/blob/3bc2004f9123a32f381ef64202252d59109507f3/torch/utils/_python_dispatch.py#L289
        # for documentation of outer_size and outer_stride
        @classmethod
        def __tensor_unflatten__(
            cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
        ):
            tensor_impl = tensor_data_dict["tensor_impl"]
            shape, = tensor_attributes
            return cls(
                tensor_impl,
                shape if outer_size is None else outer_size,
            )

        # classmethod that converts from a floating point Tensor (fp32/fp16/bf16) to the current dtype
        @classmethod
        def from_float(
            cls,
            input_float: torch.Tensor,
        ):
            mapping_type = MappingType.SYMMETRIC
            block_size = input_float.shape
            dtype = torch.int16
            scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
            int_data = (input_float / scale).to(torch.int8)
            tensor_impl = MyDTypeLayout.from_plain(int_data, scale)
            return cls(tensor_impl, input_float.shape)

        # [Optional] see the docs for `Layout/Packing` under the `Quantized Tensors` section to understand
        # what layout_type is
        @property
        def _layout(self) -> LayoutType:
            return self.tensor_impl._layout

There are two entry points through which we can modify the behavior of a PyTorch op: ``__torch_function__`` and ``__torch_dispatch__``:

* ``__torch_function__``: called whenever a torch level function is invoked on the Tensor object, for example: ``torch.nn.functional.linear``, ``tensor.detach``, ``tensor.reshape``, ``tensor.t``, etc.
* ``__torch_dispatch__``: called in the C++ dispatcher when an aten operator is invoked on the Tensor object, for example: ``aten.mm``, ``aten.addmm``, ``aten.detach.default``, ``aten.t.default``, etc.

You can check out https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L361-L389 to understand what ``__torch_function__`` and ``__torch_dispatch__`` are doing, but with ``TorchAOBaseTensor`` you can use some helper functions directly (see the next section).

Operator Support
~~~~~~~~~~~~~~~~
There are two types of operator support: torch functions and aten ops. For torch functions (e.g. ``torch.nn.functional.linear``), we'll need to override the ``__torch_function__`` callback in the tensor subclass; for aten ops (e.g. ``torch.ops.aten.mm``), we'll need to override the ``__torch_dispatch__`` callback.

For a new dtype, we'd like people to define the following decorator. If your dtype class inherits from ``torchao.utils.TorchAOBaseTensor``, you can do::

    implements = my_dtype_tensor_cls.implements

And we can implement the operator dispatch with the following::

    import torch
    from torch.utils._python_dispatch import return_and_correct_aliasing

    aten = torch.ops.aten

    # Example of torch_function dispatch for torch.nn.functional.linear
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        if isinstance(input_tensor, MyDtypeTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, MyDtypeTensor):
            weight_tensor = weight_tensor.dequantize()
        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


    @implements(torch.nn.functional.linear)
    def _(*args, **kwargs):
        input_tensor, weight_tensor, bias = (
            args[0],
            args[1],
            args[2] if len(args) > 2 else None,
        )
        # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
        # is not picked up by any of the dispatch paths in `_quantized_linear_op`; this allows us to
        # make the branches easier to understand in `_quantized_linear_op`
        try:
            return _quantized_linear_op(input_tensor, weight_tensor, bias)
        except NotImplementedError:
            if isinstance(input_tensor, MyDtypeTensor):
                input_tensor = input_tensor.dequantize()
            if isinstance(weight_tensor, MyDtypeTensor):
                weight_tensor = weight_tensor.dequantize()
            return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

    # Example of aten op dispatch for aten.detach.default
    @implements(aten.detach.default)
    def _(func, *args, **kwargs):
        # `return_and_correct_aliasing` should be used by wrapper tensor `__torch_dispatch__` subclasses that
        # would like to work with torch.compile. It ensures that the subclass properly implements the aliasing
        # behavior of every op, which is needed for correctness in AOTAutograd.

        # `_apply_fn_to_data` just applies the function to the tensor data in `args[0]`; `args[0]` is a tensor
        # subclass of `my_dtype`
        return return_and_correct_aliasing(
            func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
        )

Which ops do we need to override? This depends on the model we are trying to quantize. Commonly overridden ops are (a sketch of covering the remaining ones follows the list):

* ``__torch_function__``: ``torch.nn.functional.linear``
* ``__torch_dispatch__``: ``torch.ops.aten.addmm.default``, ``torch.ops.aten.mm.default``, ``torch.ops.aten.detach.default``, ``torch.ops.aten.t.default``
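
As a sketch of how the remaining commonly used aten ops listed above might be covered with the same ``implements`` pattern (the bodies here are illustrative, not the actual torchao implementations)::

    # transpose: apply torch.t to the underlying tensor data
    @implements(aten.t.default)
    def _(func, *args, **kwargs):
        return return_and_correct_aliasing(
            func, args, kwargs, args[0]._apply_fn_to_data(torch.t)
        )

    # matrix multiply: dequantize and fall back to the high precision op
    @implements(aten.mm.default)
    def _(func, *args, **kwargs):
        input_tensor, weight_tensor = args[0], args[1]
        if isinstance(input_tensor, MyDtypeTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, MyDtypeTensor):
            weight_tensor = weight_tensor.dequantize()
        return func(input_tensor, weight_tensor)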

You can also find the ops that need to be overridden in ``__torch_function__`` or ``__torch_dispatch__`` with the following code. Start with the model you want to optimize, override just the important ops like linear first, and gradually expand the coverage until the test runs and you get the expected optimized generated code (see the Optimized Operators section for more details)::

    import torch

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x) + x

    m = M()
    example_inputs = (torch.randn(10, 10),)

    from torch.overrides import TorchFunctionMode

    class TorchFunctionLoggingMode(TorchFunctionMode):
        def __torch_function__(self, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            print(f"TORCH_FUNC={str(func)}")
            return func(*args, **kwargs)

    with TorchFunctionLoggingMode():
        m(*example_inputs)

    ## Example output
    # TORCH_FUNC=<built-in function linear>
    # TORCH_FUNC=<method 'add' of 'torch._C.TensorBase' objects>


    from torch.utils._python_dispatch import TorchDispatchMode

    class TorchDispatchLoggingMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            print(f"ATEN_FUNC={str(func)}")
            return func(*args, **kwargs)

    with TorchDispatchLoggingMode():
        m(*example_inputs)

    ## Example output
    # ATEN_FUNC=aten.t.default
    # ATEN_FUNC=aten.addmm.default
    # ATEN_FUNC=aten.add.Tensor

    # for a more polished logging of torch_dispatch (aten) ops, see:
    # https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py

Alternatively, you can run a test example (e.g. use your quantized model with tensor parallelism, FSDP, etc.), discover the missing ops, and add them until the test passes.

We are still working on a table that lists, for each feature, the operators that need to be supported.

Adding Efficient Kernels
~~~~~~~~~~~~~~~~~~~~~~~~

2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -37,3 +37,5 @@ for an overall introduction to the library and recent highlight and updates.
:caption: Tutorials

serialization
subclass_basic
subclass_advanced
37 changes: 37 additions & 0 deletions docs/source/sg_execution_times.rst
@@ -0,0 +1,37 @@

:orphan:

.. _sphx_glr_sg_execution_times:


Computation times
=================
**00:00.000** total execution time for 1 file **from all galleries**:

.. container::

  .. raw:: html

    <style scoped>
    <link href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/5.3.0/css/bootstrap.min.css" rel="stylesheet" />
    <link href="https://cdn.datatables.net/1.13.6/css/dataTables.bootstrap5.min.css" rel="stylesheet" />
    </style>
    <script src="https://code.jquery.com/jquery-3.7.0.js"></script>
    <script src="https://cdn.datatables.net/1.13.6/js/jquery.dataTables.min.js"></script>
    <script src="https://cdn.datatables.net/1.13.6/js/dataTables.bootstrap5.min.js"></script>
    <script type="text/javascript" class="init">
    $(document).ready( function () {
        $('table.sg-datatable').DataTable({order: [[1, 'desc']]});
    } );
    </script>

  .. list-table::
   :header-rows: 1
   :class: table table-striped sg-datatable

   * - Example
     - Time
     - Mem (MB)
   * - :ref:`sphx_glr_tutorials_template_tutorial.py` (``tutorials_source/template_tutorial.py``)
     - 00:00.000
     - 0.0
4 changes: 4 additions & 0 deletions docs/source/subclass_advanced.rst
@@ -0,0 +1,4 @@
Writing Your Own Quantized Tensor (advanced)
--------------------------------------------

Coming soon!