Skip to content

Commit

Permalink
Add convert path for quantize_ QAT API
Browse files Browse the repository at this point in the history
Summary: #1415 added a quantize_
QAT API for the prepare path. This commit adds the remaining
convert path for users to actually perform end-to-end QAT using
the quantize_ API. The new flow will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)

quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

ghstack-source-id: 219f16040a01376d4c363a2164324554cd493656
Pull Request resolved: #1540
  • Loading branch information
andrewor14 committed Jan 10, 2025
1 parent b5b739b commit 3324655
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 2 deletions.
65 changes: 65 additions & 0 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from torchao.quantization.qat.api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
from_intx_quantization_aware_training,
intx_quantization_aware_training,
)
from torchao.quantization.qat.embedding import (
Expand All @@ -42,6 +43,9 @@
_GenericFakeQuantize,
_get_qmin_qmax,
)
from torchao.quantization.quant_api import (
int8_dynamic_activation_int4_weight,
)
from torchao.quantization.quant_primitives import (
MappingType,
TorchAODType,
Expand Down Expand Up @@ -1262,6 +1266,67 @@ def test_quantize_api_errors(self):
lambda m, _: isinstance(m, torch.nn.ReLU),
)

@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_quantize_api_convert_path(self):
"""
Test that the following:
quantize_(model, intx_quantization_aware_training(...))
quantize_(model, from_intx_quantization_aware_training(...))
quantize_(model, int8_dynamic_activation_int4_weight())
can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert.
"""
from torchao.quantization.qat import (
Int8DynActInt4WeightQATQuantizer,
)

group_size = 16
torch.manual_seed(self.SEED)
m = M()
baseline_model = copy.deepcopy(m)

# Baseline prepare
baseline_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
baseline_model = baseline_quantizer.prepare(baseline_model)

# quantize_ prepare
activation_config = FakeQuantizeConfig(
torch.int8,
"per_token",
is_symmetric=False,
)
weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
quantize_(
m,
intx_quantization_aware_training(activation_config, weight_config),
)

# Compare prepared values
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
out = m(*x)
baseline_out = baseline_model(*x2)
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)

# Baseline convert
baseline_model = baseline_quantizer.convert(baseline_model)

# quantize_ convert
quantize_(m, from_intx_quantization_aware_training())
quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))

# Compare converted values
torch.manual_seed(self.SEED)
x = m.example_inputs()
x2 = copy.deepcopy(x)
out = m(*x)
baseline_out = baseline_model(*x2)
torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions torchao/quantization/qat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
from_intx_quantization_aware_training,
intx_quantization_aware_training,
)
from .embedding import (
Expand All @@ -18,4 +19,5 @@
"Int4WeightOnlyEmbeddingQATQuantizer",
"Int8DynActInt4WeightQATQuantizer",
"intx_quantization_aware_training",
"from_intx_quantization_aware_training",
]
40 changes: 38 additions & 2 deletions torchao/quantization/qat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, List, Optional, Union
from typing import Any, Callable, List, Optional, Union

import torch

Expand Down Expand Up @@ -242,7 +242,7 @@ def __setattr__(self, name: str, value: Any):
def intx_quantization_aware_training(
activation_config: Optional[FakeQuantizeConfig] = None,
weight_config: Optional[FakeQuantizeConfig] = None,
) -> torch.nn.Module:
) -> Callable:
"""
Return a function that applies fake quantization to a `torch.nn.Module`.
to be used with :func:`~torchao.quantization.quant_api.quantize_`.
Expand Down Expand Up @@ -295,6 +295,42 @@ def _insert_fake_quantize(mod: torch.nn.Module):
return _insert_fake_quantize


def from_intx_quantization_aware_training() -> Callable:
"""
Return a function that converts a model with fake quantized modules,
such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
back to model with the original, corresponding modules without
fake quantization. This should be used with
:func:`~torchao.quantization.quant_api.quantize_`.
Example usage::
from torchao.quantization import quantize_
quantize_(
model_with_fake_quantized_linears,
from_intx_quantization_aware_training(),
)
"""

def _remove_fake_quantize(mod: torch.nn.Module):
"""
If the given module is a fake quantized module, return the original
corresponding version of the module without fake quantization.
"""
from .embedding import FakeQuantizedEmbedding
from .linear import FakeQuantizedLinear

if isinstance(mod, FakeQuantizedLinear):
return mod.to_linear()
elif isinstance(mod, FakeQuantizedEmbedding):
return mod.to_embedding()
else:
return mod

return _remove_fake_quantize


class ComposableQATQuantizer(TwoStepQuantizer):
"""
Composable quantizer that users can use to apply multiple QAT quantizers easily.
Expand Down
18 changes: 18 additions & 0 deletions torchao/quantization/qat/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.sparse,
)

def to_embedding(self) -> torch.nn.Embedding:
new_embedding = torch.nn.Embedding(
self.num_embeddings,
self.embedding_dim,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
device=self.weight.device,
)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
# copy the weights, and doing so will result in an error
if self.weight.device != torch.device("meta"):
new_embedding.weight = self.weight
return new_embedding

@classmethod
def from_embedding(
cls,
Expand Down
14 changes: 14 additions & 0 deletions torchao/quantization/qat/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
w = self.weight
return F.linear(x, w)

def to_linear(self) -> torch.nn.Linear:
new_linear = torch.nn.Linear(
self.in_features,
self.out_features,
self.bias,
device=self.weight.device
)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
# copy the weights, and doing so will result in an error
if self.weight.device != torch.device("meta"):
new_linear.weight = self.weight
return new_linear

@classmethod
def from_linear(
cls,
Expand Down

0 comments on commit 3324655

Please sign in to comment.