From a0070271e3fdce384b4c6cd07cf88f91d2cd99e5 Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 30 Apr 2024 18:11:22 -0700 Subject: [PATCH] decorator for args check Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_nf4.py | 27 ++++++++++------- torchao/dtypes/nf4tensor.py | 59 +++++++++++++++++++++++++++++-------- 2 files changed, 64 insertions(+), 22 deletions(-) diff --git a/test/dtypes/test_nf4.py b/test/dtypes/test_nf4.py index b7e9f08af7..500ea38880 100644 --- a/test/dtypes/test_nf4.py +++ b/test/dtypes/test_nf4.py @@ -11,7 +11,12 @@ parametrize, run_tests, ) -from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4 +from torchao.dtypes.nf4tensor import ( + linear_nf4, + NF4Tensor, + to_nf4, + INNER_TENSOR_NAMES_FOR_FSDP, +) import torch.nn.functional as F import io from collections import OrderedDict @@ -270,7 +275,7 @@ def test_torch_chunk_invalid_3d(self, input_size: Union[Tuple[int], int]): def test_tensor_new_zeros_valid(self, input_size: Union[Tuple[int], int]): nf4_tensor = to_nf4(torch.randn(input_size)) nf4_tensor_zeros = nf4_tensor.new_zeros(input_size) - for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]: + for attr in INNER_TENSOR_NAMES_FOR_FSDP: inner_tensor = getattr(nf4_tensor_zeros, attr) self.assertEqual(torch.count_nonzero(inner_tensor), 0) expected_size = input_size if not isinstance(input_size, int) else (input_size, ) @@ -305,7 +310,7 @@ def test_tensor_slice_valid(self, input_size: Union[Tuple[int], int]): def test_tensor_slice_1d_invalid(self): nf4_tensor = to_nf4(torch.randn(512 * 512)) - with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with step"): + with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with customized step"): nf4_tensor[..., ::2] with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"): nf4_tensor[1:] @@ -327,17 +332,19 @@ def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]): viewed_tensor = nf4_tensor.view(-1) self.assertEqual(viewed_tensor.dim(), 1) self.assertEqual(viewed_tensor.numel(), math.prod(input_size)) - for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]: + for attr in INNER_TENSOR_NAMES_FOR_FSDP: inner_tensor = getattr(viewed_tensor, attr) self.assertEqual(inner_tensor.size(0), inner_tensor.numel()) @parametrize("input_size", [(512 * 512,), (512, 512)]) def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]): nf4_tensor = to_nf4(torch.randn(input_size)) - with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"): - nf4_tensor.view(input_size) - with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"): - nf4_tensor.view(input_size) + if len(input_size) == 1: + with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"): + nf4_tensor.view(input_size) + if len(input_size) == 2: + with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with len\\(size\\)"): + nf4_tensor.view(input_size) @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)]) def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]): @@ -346,7 +353,7 @@ def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]): self.assertEqual(nf4_tensor_strided.size(), nf4_tensor.size()) self.assertEqual(nf4_tensor_strided.stride(), nf4_tensor.stride()) self.assertEqual(nf4_tensor_strided.storage_offset(), nf4_tensor.storage_offset()) - for 
attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
+        for attr in INNER_TENSOR_NAMES_FOR_FSDP:
             inner_tensor_orig = getattr(nf4_tensor, attr)
             inner_tensor_strided = getattr(nf4_tensor_strided, attr)
             self.assertEqual(inner_tensor_strided.size(), inner_tensor_orig.size())
@@ -406,7 +413,7 @@ def test_to_cpu(self):
         nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda'))
         nf4_tensor = nf4_tensor.cpu()
         self.assertEqual(nf4_tensor.device.type, "cpu")
-        for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
+        for attr in INNER_TENSOR_NAMES_FOR_FSDP:
             inner_tensor = getattr(nf4_tensor, attr)
             self.assertEqual(inner_tensor.device.type, "cpu")

diff --git a/torchao/dtypes/nf4tensor.py b/torchao/dtypes/nf4tensor.py
index 06c9858003..32b634c7e4 100644
--- a/torchao/dtypes/nf4tensor.py
+++ b/torchao/dtypes/nf4tensor.py
@@ -3,6 +3,7 @@
-from typing import Dict, Tuple
+from typing import Any, Dict, Tuple
 import math
 import sys
+from enum import Enum, auto

 import torch
 import torch.nn.functional as F
@@ -82,6 +83,43 @@ def call_from_inner_tensors(nf4tensor: "NF4Tensor", method_name: str, args, kwar
         attr_to_tensor[attr] = func(*args, **kwargs)
     return attr_to_tensor

+class CompareOp(Enum):
+    EQ = auto()
+    LT = auto()
+
+def expect_num_of_args(op: CompareOp, num: int, msg: str):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(aten_op, args, kwargs=None):
+            if op == CompareOp.LT and not (len(args) < num):
+                raise NotImplementedError(msg)
+            return func(aten_op, args, kwargs)
+        return wrapper
+    return decorator
+
+def expect_arg_value_at_k(k: int, op: CompareOp, value: Any, msg: str):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(aten_op, args, kwargs=None):
+            if op == CompareOp.EQ and not (args[k] == value):
+                raise NotImplementedError(msg + str(args[k]))
+            return func(aten_op, args, kwargs)
+        return wrapper
+    return decorator
+
+def expect_args_len_at_k(k: int, op: CompareOp, value: Any, msg: str):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(aten_op, args, kwargs=None):
+            if op == CompareOp.LT and not (len(args[k]) < value):
+                raise NotImplementedError(msg + str(len(args[k])))
+            elif op == CompareOp.EQ and not (len(args[k]) == value):
+                raise NotImplementedError(msg + str(len(args[k])))
+            return func(aten_op, args, kwargs)
+        return wrapper
+    return decorator
+
+
 @implements([torch.ops.aten.detach])
 def noop_detach(func, *args, **kwargs):
     return args[0][0]
@@ -146,11 +184,12 @@ def nf4_split(aten_op, args, kwargs=None):
         aten.new_zeros.default,
     ]
 )
+@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.new_zeros(NF4Tensor) with len(size)=")
 def nf4_new_zeros(aten_op, args, kwargs=None):
     nf4tensor = args[0]
     new_size = tuple(args[1])
     new_size_dim = len(new_size)
-    if (not new_size_dim in [1, 2]) or nf4tensor.numel() % math.prod(new_size) != 0:
+    if nf4tensor.numel() % math.prod(new_size) != 0:
         raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}")
     ratio = nf4tensor.numel() // math.prod(new_size)

@@ -180,14 +219,11 @@ def nf4_new_zeros(aten_op, args, kwargs=None):
         aten.slice.Tensor,
     ]
 )
+@expect_num_of_args(CompareOp.LT, 5, "aten.slice(NF4Tensor) with customized step")
+@expect_arg_value_at_k(1, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with dim=")
+@expect_arg_value_at_k(2, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with start=")
 def nf4_slice(aten_op, args, kwargs=None):
     nf4tensor = args[0]
-    if len(args) == 5:
-        raise NotImplementedError(f"aten.slice(NF4Tensor) with step={args[4]}")
-    if not args[1] == 0:
-        raise
NotImplementedError(f"aten.slice(NF4Tensor) with dim={args[1]}") - if not args[2] == 0: - raise NotImplementedError(f"aten.slice(NF4Tensor) with start={args[2]}") # for tensor 512 x 512, tensor[:, :512] dispatch to # aten.slice(dim = 0, end=sys.maxsize) if not args[3] in [nf4tensor.size(0), sys.maxsize]: @@ -199,11 +235,12 @@ def nf4_slice(aten_op, args, kwargs=None): aten.view.default, ] ) +@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=") def nf4_view(aten_op, args, kwargs=None): nf4tensor = args[0] size = args[1] - if not (len(size) == 1 and size[0] == -1): - raise NotImplementedError(f"aten.view(NF4Tensor) with size {size}") + if size[0] != -1: + raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}") updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs) updated_attrs.update({ "size": [nf4tensor.numel()], @@ -216,13 +253,12 @@ def nf4_view(aten_op, args, kwargs=None): aten.as_strided.default, ] ) +@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.as_strided(NF4Tensor) only support dim <= 2 but got dim=") def nf4_as_strided(aten_op, args, kwargs=None): nf4tensor = args[0] size = args[1] stride = tuple(args[2]) storage_offset = args[3] - if not len(size) <= 2: - raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support dim <= 2 but got dim={len(size)}") if math.prod(size) != nf4tensor.numel(): raise NotImplementedError(f"aten.as_strided(NF4Tensor) different numel={nf4tensor.numel()} and size={size}") if stride != make_contiguous_strides_for(size): @@ -289,7 +325,6 @@ def mm_default(func, *args, **kwargs): ] ) def copy_(func, *args, **kwargs): - assert len(args[0]) == 2 and len(kwargs) == 0, "only support aten.copy_.default with 2 args" original: NF4Tensor = args[0][0] copy_in: torch.Tensor = args[0][1]
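
Note (illustrative appendix, not part of the diff above): the short, self-contained sketch below shows how the argument-check decorator pattern introduced in this patch composes with an aten-op handler. It uses only the Python standard library; `toy_view` is a hypothetical stand-in for the real handlers such as nf4_view/nf4_slice in torchao/dtypes/nf4tensor.py.

import functools
from enum import Enum, auto
from typing import Any


class CompareOp(Enum):
    EQ = auto()
    LT = auto()


def expect_args_len_at_k(k: int, op: CompareOp, value: Any, msg: str):
    """Reject the call unless len(args[k]) compares to `value` as requested."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(aten_op, args, kwargs=None):
            if op == CompareOp.LT and not (len(args[k]) < value):
                raise NotImplementedError(msg + str(len(args[k])))
            elif op == CompareOp.EQ and not (len(args[k]) == value):
                raise NotImplementedError(msg + str(len(args[k])))
            return func(aten_op, args, kwargs)
        return wrapper
    return decorator


# Hypothetical handler: only a flattening view, i.e. size == [-1], is supported.
@expect_args_len_at_k(1, CompareOp.EQ, 1, "toy_view with len(size)=")
def toy_view(aten_op, args, kwargs=None):
    tensor, size = args
    if size[0] != -1:
        raise NotImplementedError(f"toy_view with size={size}")
    return f"{aten_op}: flattened {tensor}"


print(toy_view("aten.view", ("t", [-1])))    # passes the decorator and the body check
try:
    toy_view("aten.view", ("t", [4, 4, 4]))  # len(size) == 3 -> rejected by the decorator
except NotImplementedError as e:
    print("rejected:", e)

Stacking several of these decorators under @implements, as nf4_slice does in the patch, runs each check in order (step, then dim, then start) before the handler body executes, which keeps the op handlers free of repeated guard clauses.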