From 4e4f4df091ce50d1a97a34f156f4b667f894aac4 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 24 Jan 2025 13:43:51 -0500 Subject: [PATCH] Add quick start guide for first time users (#1611) Documentation in torchao has been pretty low-level and geared towards developers so far. This commit adds a basic quick start guide for first time users to get familiar with our main quantization flow. --- .gitignore | 2 +- docs/source/contributor_guide.rst | 2 +- docs/source/getting-started.rst | 4 - docs/source/index.rst | 17 ++-- docs/source/overview.rst | 4 - docs/source/quantization.rst | 6 +- docs/source/quick_start.rst | 136 ++++++++++++++++++++++++++++++ docs/source/sparsity.rst | 6 +- scripts/quick_start.py | 61 ++++++++++++++ 9 files changed, 213 insertions(+), 25 deletions(-) delete mode 100644 docs/source/getting-started.rst delete mode 100644 docs/source/overview.rst create mode 100644 docs/source/quick_start.rst create mode 100644 scripts/quick_start.py diff --git a/.gitignore b/.gitignore index 5fa7064cbe..726d2976f6 100644 --- a/.gitignore +++ b/.gitignore @@ -262,7 +262,7 @@ docs/dev docs/build docs/source/tutorials/* docs/source/gen_modules/* -docs/source/sg_execution_times +docs/source/sg_execution_times.rst # LevelDB files *.sst diff --git a/docs/source/contributor_guide.rst b/docs/source/contributor_guide.rst index a69c410e6c..e76b9420d0 100644 --- a/docs/source/contributor_guide.rst +++ b/docs/source/contributor_guide.rst @@ -1,4 +1,4 @@ -torchao Contributor Guide +Contributor Guide ------------------------- .. toctree:: diff --git a/docs/source/getting-started.rst b/docs/source/getting-started.rst deleted file mode 100644 index 70ac60b4a0..0000000000 --- a/docs/source/getting-started.rst +++ /dev/null @@ -1,4 +0,0 @@ -Getting Started -=============== - -TBA diff --git a/docs/source/index.rst b/docs/source/index.rst index 3bbcd203fd..04a53ce454 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,26 +1,25 @@ Welcome to the torchao Documentation -======================================= +==================================== -`torchao `__ is a library for custom data types & optimizations. Quantize and sparsify weights, gradients, optimizers & activations for inference and training using native PyTorch. Please checkout torchao `README `__ for an overall introduction to the library and recent highlight and updates. The documentation here will focus on: - -1. Getting Started -2. Developer Notes -3. API Reference -4. Tutorials +`torchao `__ is a library for custom data types and optimizations. +Quantize and sparsify weights, gradients, optimizers, and activations for inference and training +using native PyTorch. Please checkout torchao `README `__ +for an overall introduction to the library and recent highlight and updates. .. toctree:: :glob: :maxdepth: 1 :caption: Getting Started - getting-started - sparsity + quick_start .. toctree:: :glob: :maxdepth: 1 :caption: Developer Notes + quantization + sparsity contributor_guide .. toctree:: diff --git a/docs/source/overview.rst b/docs/source/overview.rst deleted file mode 100644 index 4c6d532067..0000000000 --- a/docs/source/overview.rst +++ /dev/null @@ -1,4 +0,0 @@ -Overview -======== - -TBA diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index d96a3afc18..b5e34780b7 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -1,4 +1,4 @@ -Quantization -============ +Quantization Overview +--------------------- -TBA +Coming soon! diff --git a/docs/source/quick_start.rst b/docs/source/quick_start.rst new file mode 100644 index 0000000000..fea8bb912d --- /dev/null +++ b/docs/source/quick_start.rst @@ -0,0 +1,136 @@ +Quick Start Guide +----------------- + +In this quick start guide, we will explore how to perform basic quantization using torchao. +First, install the latest stable torchao release:: + + pip install torchao + +If you prefer to use the nightly release, you can install torchao using the following +command instead:: + + pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 + +torchao is compatible with the latest 3 major versions of PyTorch, which you will also +need to install (`detailed instructions `__):: + + pip install torch + + +First Quantization Example +========================== + +The main entry point for quantization in torchao is the `quantize_ `__ API. +This function mutates your model inplace to insert the custom quantization logic based +on what the user configures. All code in this guide can be found in this `example script `__. +First, let's set up our toy model: + +.. code:: py + + import copy + import torch + + class ToyLinearModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") + + # Optional: compile model for faster inference and generation + model = torch.compile(model, mode="max-autotune", fullgraph=True) + model_bf16 = copy.deepcopy(model) + +Now we call our main quantization API to quantize the linear weights +in the model to int4 inplace. More specifically, this applies uint4 +weight-only asymmetric per-group quantization, leveraging the +`tinygemm int4mm CUDA kernel `__ +for efficient mixed dtype matrix multiplication: + +.. code:: py + + # torch 2.4+ only + from torchao.quantization import int4_weight_only, quantize_ + quantize_(model, int4_weight_only(group_size=32)) + +The quantized model is now ready to use! Note that the quantization +logic is inserted through tensor subclasses, so there is no change +to the overall model structure; only the weights tensors are updated, +but `nn.Linear` modules stay as `nn.Linear` modules: + +.. code:: py + + >>> model.linear1 + Linear(in_features=1024, out_features=1024, weight=AffineQuantizedTensor(shape=torch.Size([1024, 1024]), block_size=(1, 32), device=cuda:0, _layout=TensorCoreTiledLayout(inner_k_tiles=8), tensor_impl_dtype=torch.int32, quant_min=0, quant_max=15)) + + >>> model.linear2 + Linear(in_features=1024, out_features=1024, weight=AffineQuantizedTensor(shape=torch.Size([1024, 1024]), block_size=(1, 32), device=cuda:0, _layout=TensorCoreTiledLayout(inner_k_tiles=8), tensor_impl_dtype=torch.int32, quant_min=0, quant_max=15)) + +First, verify that the int4 quantized model is roughly a quarter of +the size of the original bfloat16 model: + +.. code:: py + + >>> import os + >>> torch.save(model, "/tmp/int4_model.pt") + >>> torch.save(model_bf16, "/tmp/bfloat16_model.pt") + >>> int4_model_size_mb = os.path.getsize("/tmp/int4_model.pt") / 1024 / 1024 + >>> bfloat16_model_size_mb = os.path.getsize("/tmp/bfloat16_model.pt") / 1024 / 1024 + + >>> print("int4 model size: %.2f MB" % int4_model_size_mb) + int4 model size: 1.25 MB + + >>> print("bfloat16 model size: %.2f MB" % bfloat16_model_size_mb) + bfloat16 model size: 4.00 MB + +Next, we demonstrate that not only is the quantized model smaller, +it is also much faster! + +.. code:: py + + from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + benchmark_model, + unwrap_tensor_subclass, + ) + + # Temporary workaround for tensor subclass + torch.compile + # Only needed for torch version < 2.5 + if not TORCH_VERSION_AT_LEAST_2_5: + unwrap_tensor_subclass(model) + + num_runs = 100 + torch._dynamo.reset() + example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),) + bf16_time = benchmark_model(model_bf16, num_runs, example_inputs) + int4_time = benchmark_model(model, num_runs, example_inputs) + + print("bf16 mean time: %0.3f ms" % bf16_time) + print("int4 mean time: %0.3f ms" % int4_time) + print("speedup: %0.1fx" % (bf16_time / int4_time)) + +On a single A100 GPU with 80GB memory, this prints:: + + bf16 mean time: 30.393 ms + int4 mean time: 4.410 ms + speedup: 6.9x + + +Next Steps +========== + +In this quick start guide, we learned how to quantize a simple model with +torchao. To learn more about the different workflows supported in torchao, +see our main `README `__. +For a more detailed overview of quantization in torchao, visit +`this page `__. + +Finally, if you would like to contribute to torchao, don't forget to check +out our `contributor guide `__ and our list of +`good first issues `__ on Github! diff --git a/docs/source/sparsity.rst b/docs/source/sparsity.rst index 0bde173b6d..d9986a3227 100644 --- a/docs/source/sparsity.rst +++ b/docs/source/sparsity.rst @@ -1,5 +1,5 @@ -Sparsity --------- +Sparsity Overview +----------------- Sparsity is the technique of removing parameters from a neural network in order to reduce its memory overhead or latency. By carefully choosing how the elements are pruned, one can achieve significant reduction in memory overhead and latency, while paying a reasonably low or no price in terms of model quality (accuracy / f1). @@ -38,7 +38,7 @@ Given a target sparsity pattern, pruning/sparsifying a model can then be thought * **Accuracy** - How can I find a set of sparse weights which satisfy my target sparsity pattern that minimize the accuracy degradation of my model? -* **Perforance** - How can I accelerate my sparse weights for inference and reduce memory overhead? +* **Performance** - How can I accelerate my sparse weights for inference and reduce memory overhead? Our workflow is designed to consist of two parts that answer each question independently: diff --git a/scripts/quick_start.py b/scripts/quick_start.py new file mode 100644 index 0000000000..f2e195fd7e --- /dev/null +++ b/scripts/quick_start.py @@ -0,0 +1,61 @@ +import copy + +import torch + +from torchao.quantization import int4_weight_only, quantize_ +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + benchmark_model, + unwrap_tensor_subclass, +) + +# ================ +# | Set up model | +# ================ + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, m: int, n: int, k: int): + super().__init__() + self.linear1 = torch.nn.Linear(m, n, bias=False) + self.linear2 = torch.nn.Linear(n, k, bias=False) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") + +# Optional: compile model for faster inference and generation +model = torch.compile(model, mode="max-autotune", fullgraph=True) +model_bf16 = copy.deepcopy(model) + + +# ======================== +# | torchao quantization | +# ======================== + +# torch 2.4+ only +quantize_(model, int4_weight_only(group_size=32)) + + +# ============= +# | Benchmark | +# ============= + +# Temporary workaround for tensor subclass + torch.compile +# Only needed for torch version < 2.5 +if not TORCH_VERSION_AT_LEAST_2_5: + unwrap_tensor_subclass(model) + +num_runs = 100 +torch._dynamo.reset() +example_inputs = (torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda"),) +bf16_time = benchmark_model(model_bf16, num_runs, example_inputs) +int4_time = benchmark_model(model, num_runs, example_inputs) + +print("bf16 mean time: %0.3f ms" % bf16_time) +print("int4 mean time: %0.3f ms" % int4_time) +print("speedup: %0.1fx" % (bf16_time / int4_time))