Clean up linear_int8_dynamic_activation_intx_weight_subclass
Summary:
Cleans up the layout and quantization API:

```
int8_dynamic_activation_intx_weight(
    group_size: int = 128,
    bit_width: int = 4,
    has_weight_zeros: bool = False,
    weight_mapping_type=MappingType.ASYMMETRIC,
    act_mapping_type=MappingType.ASYMMETRIC,
    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
)
```

int8_dynamic_activation_intx_weight is now very similar to int8_dynamic_activation_int4_weight. By passing bit_width=4, has_weight_zeros=False, and layout=PlainLayout(), it should be numerically identical (but slower).

The fallback option is removed; callers that need a portable fallback should pass layout=PlainLayout() instead.
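
For illustration, a minimal usage sketch of the cleaned-up API (the toy model and the `quantize_`/`PlainLayout` import paths are assumptions, not part of this diff):

```python
# Sketch: quantize a small fp32 model with the cleaned-up API.
import torch
from torchao.quantization import quantize_
from torchao.dtypes import PlainLayout  # assumed import path
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight

model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.float32)

# With bit_width=4, has_weight_zeros=False, and PlainLayout(), this should be
# numerically identical to int8_dynamic_activation_int4_weight (but slower than
# the packed layout).
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        group_size=128,
        bit_width=4,
        has_weight_zeros=False,
        layout=PlainLayout(),
    ),
)
```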

Reviewed By: jerryzh168

Differential Revision: D67821939
metascroy authored and facebook-github-bot committed Jan 11, 2025
1 parent ad61822 commit 8752c4c
Showing 8 changed files with 436 additions and 482 deletions.
20 changes: 3 additions & 17 deletions torchao/_models/llama/generate.py
```diff
@@ -543,32 +543,18 @@ def ffn_or_attn_only(mod, fqn):
 from torchao.experimental.quant_api import (
 int8_dynamic_activation_intx_weight,
 )
-
-assert (
-precision == torch.float32
-), "int8_dynamic_activation_intx_weight requires fp32 precision"
-
-try:
-torch.ops.torchao._pack_8bit_act_4bit_weight
-except:
-print(
-"Unable to load experimental torchao kernels. Performance will be slow."
-)
-print(
-"To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU"
-)
+assert precision == torch.float32, "int8_dynamic_activation_intx_weight requires using precision=torch.float32"
 
 # Quantize model
 _quant_args = quantization.split("-")
-nbit = int(_quant_args[1])
-assert nbit >= 1 and nbit <= 8, "nbits must be 1 to 8"
+bit_width = int(_quant_args[1])
 group_size = int(_quant_args[2])
 has_weight_zeros = bool(_quant_args[3])
 quantize_(
 model,
 int8_dynamic_activation_intx_weight(
+bit_width=bit_width,
 group_size=group_size,
-nbit=nbit,
 has_weight_zeros=has_weight_zeros,
 ),
 )
```
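
With the automatic fallback removed from generate.py, a caller can reproduce the old behavior by probing for the packed kernels and choosing the layout explicitly. A rough sketch (the import path for PackedLinearInt8DynamicActivationIntxWeightLayout is an assumption):

```python
# Sketch: pick PlainLayout() when the experimental ARM kernels are unavailable.
import torch
from torchao.dtypes import PlainLayout  # assumed import path

try:
    # Same probe the removed generate.py code used to detect the packed kernels.
    torch.ops.torchao._pack_8bit_act_4bit_weight
    # Assumed import path for the packed layout class named in the new API.
    from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
        PackedLinearInt8DynamicActivationIntxWeightLayout,
    )
    layout = PackedLinearInt8DynamicActivationIntxWeightLayout()
except Exception:
    print("Experimental torchao kernels not available; falling back to PlainLayout() (slower).")
    layout = PlainLayout()
```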
18 changes: 11 additions & 7 deletions torchao/dtypes/uintx/plain_layout.py
```diff
@@ -38,7 +38,7 @@ def __new__(
 cls,
 int_data: torch.Tensor,
 scale: torch.Tensor,
-zero_point: torch.Tensor,
+zero_point: Optional[torch.Tensor],
 _layout: Layout,
 ):
 kwargs = {}
@@ -55,7 +55,7 @@ def __init__(
 self,
 int_data: torch.Tensor,
 scale: torch.Tensor,
-zero_point: torch.Tensor,
+zero_point: Optional[torch.Tensor],
 _layout: Layout,
 ):
 self.int_data = int_data
@@ -64,6 +64,8 @@ def __init__(
 self._layout = _layout
 
 def __tensor_flatten__(self):
+if self.zero_point is None:
+return ["int_data", "scale"], [self._layout]
 return ["int_data", "scale", "zero_point"], [self._layout]
 
 @classmethod
@@ -73,7 +75,7 @@ def __tensor_unflatten__(
 int_data, scale, zero_point = (
 tensor_data_dict["int_data"],
 tensor_data_dict["scale"],
-tensor_data_dict["zero_point"],
+tensor_data_dict.get("zero_point", None),
 )
 (_layout,) = tensor_attributes
 return cls(int_data, scale, zero_point, _layout)
@@ -83,15 +85,17 @@ def to(self, *args, **kwargs):
 return self.__class__(
 self.int_data.to(kwargs["device"]),
 self.scale.to(kwargs["device"]),
-self.zero_point.to(kwargs["device"]),
+self.zero_point.to(kwargs["device"])
+if self.zero_point is not None
+else None,
 self._layout,
 )
 
 def _apply_fn_to_data(self, fn):
 return self.__class__(
 fn(self.int_data),
 fn(self.scale),
-fn(self.zero_point),
+fn(self.zero_point) if self.zero_point is not None else None,
 self._layout,
 )
 
@@ -134,7 +138,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 return PlainAQTTensorImpl(
 aten.slice.Tensor(self.int_data, dim, start, end, step),
 self.scale.view(-1),
-self.zero_point.view(-1),
+self.zero_point.view(-1) if self.zero_point is not None else None,
 self._layout,
 )
 else:
@@ -148,7 +152,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 
 __torch_function__ = torch._C._disabled_torch_function_impl
 
-def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
 return self.int_data, self.scale, self.zero_point
 
 def get_layout(self) -> Layout:
```
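
Once zero_point may be None, __tensor_flatten__ and __tensor_unflatten__ have to agree on whether it exists. A standalone sketch of that contract (not the actual PlainAQTTensorImpl, just the pattern the diff above implements):

```python
# Sketch: flatten/unflatten round trip with an optional zero_point.
from typing import Optional
import torch

def flatten(int_data, scale, zero_point: Optional[torch.Tensor], layout):
    # Only list tensors that actually exist.
    if zero_point is None:
        return ["int_data", "scale"], [layout]
    return ["int_data", "scale", "zero_point"], [layout]

def unflatten(tensor_data_dict, tensor_attributes):
    int_data = tensor_data_dict["int_data"]
    scale = tensor_data_dict["scale"]
    zero_point = tensor_data_dict.get("zero_point", None)  # stays None when absent
    (layout,) = tensor_attributes
    return int_data, scale, zero_point, layout

names, attrs = flatten(torch.zeros(4, dtype=torch.int8), torch.ones(4), None, "plain")
assert "zero_point" not in names
```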
6 changes: 3 additions & 3 deletions torchao/dtypes/utils.py
```diff
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 
@@ -87,7 +87,7 @@ class AQTTensorImpl(TorchAOBaseTensor):
 the underlying implementation of a AQT based on layout
 """
 
-def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
 """Get the plain (unpacked) Tensor for the tensor impl
 Returns data, scale and zero_point
@@ -103,7 +103,7 @@ def from_plain(
 cls,
 data: torch.Tensor,
 scale: torch.Tensor,
-zero_point: torch.Tensor,
+zero_point: Optional[torch.Tensor],
 _layout: Layout,
 ):
 """Construct a TensorImpl from data, scale, zero_point and the _layout"""
```
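
A hypothetical minimal impl that honors the updated signatures, where zero_point may be None (illustrative only; real impls subclass AQTTensorImpl and register a layout):

```python
# Sketch: a toy tensor impl following the Optional[zero_point] contract.
from typing import Optional, Tuple
import torch

class ToyTensorImpl:
    def __init__(self, data: torch.Tensor, scale: torch.Tensor,
                 zero_point: Optional[torch.Tensor]):
        self.data, self.scale, self.zero_point = data, scale, zero_point

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        # zero_point may be None, e.g. for symmetric schemes that store no zeros.
        return self.data, self.scale, self.zero_point

    @classmethod
    def from_plain(cls, data: torch.Tensor, scale: torch.Tensor,
                   zero_point: Optional[torch.Tensor], _layout=None):
        return cls(data, scale, zero_point)
```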
(Diffs for the remaining 5 changed files are not shown.)
