From 01ad96aebec4377643a4fccd3a5c30c00b3e7744 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 14 Nov 2024 12:09:24 -0800 Subject: [PATCH] Fix an autoquant bug in flatten/unflatten Summary: att Test Plan: python test/integration/test_integration.py -k test_autoquantizable_flatten_unflatten Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 12 ++++++++++++ torchao/quantization/autoquant.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 92d2dcd5c2..3279489543 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -777,6 +777,18 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype): AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype ) + def test_autoquantizable_flatten_unflatten(self): + from torchao.quantization import DEFAULT_AUTOQUANT_CLASS_LIST + weight = torch.randn(16, 32) + qtensor_class_list = DEFAULT_AUTOQUANT_CLASS_LIST + aqw = AutoQuantizableLinearWeight.from_float(weight, qtensor_class_list) + tensor_data_name_dict, tensor_attributes = aqw.__tensor_flatten__() + tensor_data_dict = {name: getattr(aqw, name) for name in tensor_data_name_dict} + outer_size = aqw.size() + outer_stride = aqw.stride() + reconstructed = type(aqw).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride) + + @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") @unittest.skipIf(not is_H100, "Need H100 to run") diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index ee6bf98852..87cb5e2655 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -217,7 +217,7 @@ def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None ): weight = tensor_data_dict["weight"] - qtensor_class_list, mode, dtype, shape = tensor_attributes[0] + qtensor_class_list, mode, dtype, shape = tensor_attributes return cls( weight, qtensor_class_list,