Merge remote-tracking branch 'origin/main' into add_h100_ci
jainapurva committed Jan 24, 2025
2 parents 7e2cc00 + 4ed93b9 commit 316e342
Showing 5 changed files with 118 additions and 14 deletions.
22 changes: 22 additions & 0 deletions test/quantization/test_quant_api.py
@@ -761,6 +761,28 @@ def reset_memory():
assert param.is_cuda
self.assertLess(memory_streaming, memory_baseline)

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
@common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
@common_utils.parametrize("x_dim", [2, 3])
def test_int4wo_cpu(self, dtype, x_dim):
from torchao.dtypes import Int4CPULayout

device = "cpu"
m = ToyLinearModel().eval().to(dtype).to(device)
example_inputs = m.example_inputs(dtype=dtype, device=device)
if x_dim == 3:
example_inputs = (example_inputs[0].unsqueeze(0),)

with torch.no_grad():
quantize_(m, int4_weight_only(group_size=32, layout=Int4CPULayout()))
# ensure the expected op is in the code
_, code = torch._inductor.utils.run_and_get_code(
torch.compile(m, fullgraph=True, dynamic=True),
*example_inputs,
)
assert "_weight_int4pack_mm_for_cpu" in code[0]
assert "aten.mm.default" not in code[0]


class TestMultiTensorFlow(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
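
What the new test exercises, written as plain user code: a minimal sketch, assuming PyTorch 2.6+ on CPU and the quantize_ / int4_weight_only entry points used in the test (the torchao.quantization import path is an assumption based on the library's top-level API).

import torch
from torchao.dtypes import Int4CPULayout
from torchao.quantization import int4_weight_only, quantize_

# a small float model on CPU; the test covers float32, bfloat16, and float16
model = torch.nn.Sequential(torch.nn.Linear(256, 128)).to(torch.bfloat16)
quantize_(model, int4_weight_only(group_size=32, layout=Int4CPULayout()))

x = torch.randn(4, 256, dtype=torch.bfloat16)
# the compiled graph should contain aten._weight_int4pack_mm_for_cpu and no aten.mm
y = torch.compile(model, fullgraph=True)(x)
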
8 changes: 8 additions & 0 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -28,6 +28,10 @@
_linear_fp_act_int4_weight_gemlite_check,
_linear_fp_act_int4_weight_gemlite_impl,
)
from torchao.dtypes.uintx.int4_cpu_layout import (
_linear_fp_act_uint4_weight_cpu_check,
_linear_fp_act_uint4_weight_cpu_impl,
)
from torchao.dtypes.uintx.marlin_qqq_tensor import (
_linear_int8_act_int4_weight_marlin_qqq_check,
_linear_int8_act_int4_weight_marlin_qqq_impl,
@@ -151,6 +155,10 @@ def _register_aqt_quantized_linear_dispatches():
_linear_int8_act_int4_weight_cutlass_check,
_linear_int8_act_int4_weight_cutlass_impl,
),
(
_linear_fp_act_uint4_weight_cpu_check,
_linear_fp_act_uint4_weight_cpu_impl,
),
]:
register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

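
Each entry in the table above pairs a check predicate with an implementation: when a linear op sees an AffineQuantizedTensor weight, a pair whose check returns True supplies the kernel. Below is a minimal sketch of registering one more pair inside this module, relying on its existing imports; _always_dequant_check and _always_dequant_impl are hypothetical names for illustration, not torchao APIs.

def _always_dequant_check(input_tensor, weight_tensor, bias):
    # hypothetical fallback: claim any AffineQuantizedTensor weight on CPU
    return (
        isinstance(weight_tensor, AffineQuantizedTensor)
        and weight_tensor.device.type == "cpu"
    )


def _always_dequant_impl(input_tensor, weight_tensor, bias):
    # dequantize to the activation dtype and run a plain linear; slow but always correct
    return torch.nn.functional.linear(
        input_tensor, weight_tensor.dequantize().to(input_tensor.dtype), bias
    )


register_aqt_quantized_linear_dispatch(_always_dequant_check, _always_dequant_impl)
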
86 changes: 82 additions & 4 deletions torchao/dtypes/uintx/int4_cpu_layout.py
@@ -2,10 +2,17 @@
from typing import Optional, Tuple

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
return_and_correct_aliasing,
)

from torchao.dtypes.affine_quantized_tensor import register_layout
from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
register_layout,
)
from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device
from torchao.quantization.quant_primitives import ZeroPointDomain
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
@@ -126,7 +133,7 @@ def from_plain(
zero_point = zero_point.reshape(int_data.shape[0], -1)
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
return cls(packed_weight, scale_and_zero, False, _layout)

def to(self, *args, **kwargs):
@@ -231,7 +238,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
groupsize = int(original_shape[1] / scale.shape[-2])
block_size = (1, groupsize)
device = self.device
original_dtype = torch.bfloat16
original_dtype = self.scale_and_zero.dtype
target_dtype = torch.int32
quant_min = 0
quant_max = 15
@@ -261,3 +268,74 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

def get_layout(self) -> Layout:
return self._layout


def _aqt_is_uint4(aqt):
"""Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
return (
aqt.tensor_impl.dtype == torch.uint8
and aqt.quant_min == 0
and aqt.quant_max == 15
)


def _is_float(dtype):
return dtype in (torch.float, torch.half, torch.bfloat16)


def _linear_fp_act_uint4_weight_cpu_check(input_tensor, weight_tensor, bias):
return (
TORCH_VERSION_AT_LEAST_2_6
and is_device(input_tensor.device.type, "cpu")
and is_device(weight_tensor.device.type, "cpu")
and (bias is None or is_device(bias.device.type, "cpu"))
and not is_traceable_wrapper_subclass(input_tensor)
and _is_float(input_tensor.dtype)
and isinstance(weight_tensor, AffineQuantizedTensor)
and _aqt_is_uint4(weight_tensor)
and _is_float(weight_tensor.dtype)
and len(weight_tensor.shape) == 2
and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT
and isinstance(weight_tensor._layout, Int4CPULayout)
)


def _linear_fp_act_uint4_weight_cpu_impl(input_tensor, weight_tensor, bias):
assert (
TORCH_VERSION_AT_LEAST_2_6
), f"Requires PyTorch version at least 2.6, but got: {torch.__version__}"
assert is_device(
input_tensor.device.type, "cpu"
), f"For CPU device only but got: {input_tensor.device}"
assert (
weight_tensor.block_size[0] == 1
), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
assert input_tensor.shape[-1] == weight_tensor.shape[1], (
f"need input_tensor shape: {input_tensor.shape} final"
f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
)

act_mat = input_tensor
packed_weight = weight_tensor.tensor_impl.packed_weight
scale_and_zero = weight_tensor.tensor_impl.scale_and_zero

orig_act_size = act_mat.size()
orig_dtype = act_mat.dtype

# reshape to 2D
act_mat = act_mat.reshape(-1, act_mat.shape[-1])

# groupwise int4 quantization
groupsize = weight_tensor.block_size[1]
y = torch.ops.aten._weight_int4pack_mm_for_cpu(
act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
)

# remove out_feature padding
orig_out_features = weight_tensor.shape[-2]
y = y[:, :orig_out_features]
y = y.reshape(*orig_act_size[:-1], orig_out_features)

if bias is not None:
y += bias
return y.to(orig_dtype)
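
The reshape-and-slice bookkeeping in the impl above can be checked with dummy tensors. A standalone sketch with illustrative sizes; torch.mm stands in for aten._weight_int4pack_mm_for_cpu, whose output may carry out_feature padding that the slice removes.

import torch

act = torch.randn(2, 3, 64)                        # (batch, seq, in_features)
act_2d = act.reshape(-1, act.shape[-1])            # collapse leading dims -> (6, 64)
w = torch.randn(80, 64)                            # padded out_features = 80
y = torch.mm(act_2d, w.t())                        # (6, 80)
orig_out_features = 72                             # unpadded out_features
y = y[:, :orig_out_features]                       # strip padding -> (6, 72)
y = y.reshape(*act.shape[:-1], orig_out_features)  # restore leading dims -> (2, 3, 72)
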
12 changes: 3 additions & 9 deletions torchao/dtypes/uintx/tensor_core_tiled_layout.py
@@ -15,7 +15,6 @@
from torchao.quantization.quant_primitives import ZeroPointDomain, _get_reduction_params
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
fill_defaults,
find_multiple,
)
@@ -76,14 +75,9 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):

# groupwise int4 quantization
groupsize = weight_tensor.block_size[1]
if is_device(input_tensor.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
y = torch.ops.aten._weight_int4pack_mm_for_cpu(
act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
)
else:
y = torch.ops.aten._weight_int4pack_mm(
act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
)
y = torch.ops.aten._weight_int4pack_mm(
act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
)

# remove out_feature padding
orig_out_features = weight_tensor.shape[-2]
4 changes: 3 additions & 1 deletion torchao/quantization/quant_api.py
@@ -725,7 +725,9 @@ def apply_int4_weight_only_quant(weight):
quant_max = 15
eps = 1e-6
preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)]
zero_point_dtype = torch.bfloat16
zero_point_dtype = (
weight.dtype if isinstance(layout, Int4CPULayout) else torch.bfloat16
)

nonlocal zero_point_domain
assert (
