decorator for args check
weifengpy committed May 1, 2024
1 parent e36ab6c commit a007027
Showing 2 changed files with 64 additions and 22 deletions.
27 changes: 17 additions & 10 deletions test/dtypes/test_nf4.py
@@ -11,7 +11,12 @@
parametrize,
run_tests,
)
from torchao.dtypes.nf4tensor import linear_nf4, NF4Tensor, to_nf4
from torchao.dtypes.nf4tensor import (
linear_nf4,
NF4Tensor,
to_nf4,
INNER_TENSOR_NAMES_FOR_FSDP,
)
import torch.nn.functional as F
import io
from collections import OrderedDict
@@ -270,7 +275,7 @@ def test_torch_chunk_invalid_3d(self, input_size: Union[Tuple[int], int]):
def test_tensor_new_zeros_valid(self, input_size: Union[Tuple[int], int]):
nf4_tensor = to_nf4(torch.randn(input_size))
nf4_tensor_zeros = nf4_tensor.new_zeros(input_size)
for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
for attr in INNER_TENSOR_NAMES_FOR_FSDP:
inner_tensor = getattr(nf4_tensor_zeros, attr)
self.assertEqual(torch.count_nonzero(inner_tensor), 0)
expected_size = input_size if not isinstance(input_size, int) else (input_size, )
@@ -305,7 +310,7 @@ def test_tensor_slice_valid(self, input_size: Union[Tuple[int], int]):

def test_tensor_slice_1d_invalid(self):
nf4_tensor = to_nf4(torch.randn(512 * 512))
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with step"):
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with customized step"):
nf4_tensor[..., ::2]
with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"):
nf4_tensor[1:]
@@ -327,17 +332,19 @@ def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]):
viewed_tensor = nf4_tensor.view(-1)
self.assertEqual(viewed_tensor.dim(), 1)
self.assertEqual(viewed_tensor.numel(), math.prod(input_size))
for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
for attr in INNER_TENSOR_NAMES_FOR_FSDP:
inner_tensor = getattr(viewed_tensor, attr)
self.assertEqual(inner_tensor.size(0), inner_tensor.numel())

@parametrize("input_size", [(512 * 512,), (512, 512)])
def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]):
nf4_tensor = to_nf4(torch.randn(input_size))
with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"):
nf4_tensor.view(input_size)
with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"):
nf4_tensor.view(input_size)
if len(input_size) == 1:
with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"):
nf4_tensor.view(input_size)
if len(input_size) == 2:
with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with len\\(size\\)"):
nf4_tensor.view(input_size)

@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]):
@@ -346,7 +353,7 @@ def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]):
self.assertEqual(nf4_tensor_strided.size(), nf4_tensor.size())
self.assertEqual(nf4_tensor_strided.stride(), nf4_tensor.stride())
self.assertEqual(nf4_tensor_strided.storage_offset(), nf4_tensor.storage_offset())
for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
for attr in INNER_TENSOR_NAMES_FOR_FSDP:
inner_tensor_orig = getattr(nf4_tensor, attr)
inner_tensor_strided = getattr(nf4_tensor_strided, attr)
self.assertEqual(inner_tensor_strided.size(), inner_tensor_orig.size())
@@ -406,7 +413,7 @@ def test_to_cpu(self):
nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda'))
nf4_tensor = nf4_tensor.cpu()
self.assertEqual(nf4_tensor.device.type, "cpu")
for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
for attr in INNER_TENSOR_NAMES_FOR_FSDP:
inner_tensor = getattr(nf4_tensor, attr)
self.assertEqual(inner_tensor.device.type, "cpu")

59 changes: 47 additions & 12 deletions torchao/dtypes/nf4tensor.py
@@ -3,6 +3,7 @@
from typing import Dict, Tuple
import math
import sys
from enum import Enum, auto

import torch
import torch.nn.functional as F
@@ -82,6 +83,43 @@ def call_from_inner_tensors(nf4tensor: "NF4Tensor", method_name: str, args, kwar
attr_to_tensor[attr] = func(*args, **kwargs)
return attr_to_tensor

class CompareOp(Enum):
EQ = auto()
LT = auto()

def expect_num_of_args(op: CompareOp, num: int, msg: str):
def decorator(func):
@functools.wraps(func)
def wrapper(aten_op, args, kwargs=None):
if op == CompareOp.LT and not (len(args) < num):
raise NotImplementedError(msg)
return func(aten_op, args, kwargs)
return wrapper
return decorator

def expect_arg_value_at_k(k: int, op: CompareOp, value: Any, msg: str):
def decorator(func):
@functools.wraps(func)
def wrapper(aten_op, args, kwargs=None):
if op == CompareOp.EQ and not (args[k] == value):
raise NotImplementedError(msg + str(args[k]))
return func(aten_op, args, kwargs)
return wrapper
return decorator

def expect_args_len_at_k(k: int, op: CompareOp, value: Any, msg: str):
def decorator(func):
@functools.wraps(func)
def wrapper(aten_op, args, kwargs=None):
if op == CompareOp.LT and not (len(args[k]) < value):
raise NotImplementedError(msg + str(len(args[k])))
elif op == CompareOp.EQ and not (len(args[k]) == value):
raise NotImplementedError(msg + str(len(args[k])))
return func(aten_op, args, kwargs)
return wrapper
return decorator


@implements([torch.ops.aten.detach])
def noop_detach(func, *args, **kwargs):
return args[0][0]
@@ -146,11 +184,12 @@ def nf4_split(aten_op, args, kwargs=None):
aten.new_zeros.default,
]
)
@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.view(NF4Tensor) with len(size)=")
def nf4_new_zeros(aten_op, args, kwargs=None):
nf4tensor = args[0]
new_size = tuple(args[1])
new_size_dim = len(new_size)
if (not new_size_dim in [1, 2]) or nf4tensor.numel() % math.prod(new_size) != 0:
if nf4tensor.numel() % math.prod(new_size) != 0:
raise NotImplementedError(f"aten.new_zeros(NF4Tensor) with new size {new_size}")
ratio = nf4tensor.numel() // math.prod(new_size)

@@ -180,14 +219,11 @@ def nf4_new_zeros(aten_op, args, kwargs=None):
aten.slice.Tensor,
]
)
@expect_num_of_args(CompareOp.LT, 5, "aten.slice(NF4Tensor) with customized step")
@expect_arg_value_at_k(1, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with dim=")
@expect_arg_value_at_k(2, CompareOp.EQ, 0, "aten.slice(NF4Tensor) with start=")
def nf4_slice(aten_op, args, kwargs=None):
nf4tensor = args[0]
if len(args) == 5:
raise NotImplementedError(f"aten.slice(NF4Tensor) with step={args[4]}")
if not args[1] == 0:
raise NotImplementedError(f"aten.slice(NF4Tensor) with dim={args[1]}")
if not args[2] == 0:
raise NotImplementedError(f"aten.slice(NF4Tensor) with start={args[2]}")
# for tensor 512 x 512, tensor[:, :512] dispatch to
# aten.slice(dim = 0, end=sys.maxsize)
if not args[3] in [nf4tensor.size(0), sys.maxsize]:
@@ -199,11 +235,12 @@ def nf4_slice(aten_op, args, kwargs=None):
aten.view.default,
]
)
@expect_args_len_at_k(1, CompareOp.EQ, 1, "aten.view(NF4Tensor) with len(size)=")
def nf4_view(aten_op, args, kwargs=None):
nf4tensor = args[0]
size = args[1]
if not (len(size) == 1 and size[0] == -1):
raise NotImplementedError(f"aten.view(NF4Tensor) with size {size}")
if size[0] != -1:
raise NotImplementedError(f"aten.view(NF4Tensor) with size={size}")
updated_attrs = apply_to_inner_tensors(nf4tensor, aten_op, args[1:], kwargs)
updated_attrs.update({
"size": [nf4tensor.numel()],
@@ -216,13 +253,12 @@ def nf4_view(aten_op, args, kwargs=None):
aten.as_strided.default,
]
)
@expect_args_len_at_k(1, CompareOp.LT, 3, "aten.as_strided(NF4Tensor) only support dim <= 2 but got dim=")
def nf4_as_strided(aten_op, args, kwargs=None):
nf4tensor = args[0]
size = args[1]
stride = tuple(args[2])
storage_offset = args[3]
if not len(size) <= 2:
raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support dim <= 2 but got dim={len(size)}")
if math.prod(size) != nf4tensor.numel():
raise NotImplementedError(f"aten.as_strided(NF4Tensor) different numel={nf4tensor.numel()} and size={size}")
if stride != make_contiguous_strides_for(size):
@@ -289,7 +325,6 @@ def mm_default(func, *args, **kwargs):
]
)
def copy_(func, *args, **kwargs):
assert len(args[0]) == 2 and len(kwargs) == 0, "only support aten.copy_.default with 2 args"
original: NF4Tensor = args[0][0]
copy_in: torch.Tensor = args[0][1]

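For reference, below is a minimal standalone sketch of the argument-check decorator pattern this commit introduces. The toy handler and its names (toy_view, the error message text) are illustrative only, not the torchao implementation:

import functools
from enum import Enum, auto

class CompareOp(Enum):
    EQ = auto()
    LT = auto()

def expect_args_len_at_k(k, op, value, msg):
    # Check len(args[k]) against `value` before calling the wrapped aten-op handler;
    # on failure, raise NotImplementedError with the offending length appended to `msg`.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(aten_op, args, kwargs=None):
            if op == CompareOp.LT and not len(args[k]) < value:
                raise NotImplementedError(msg + str(len(args[k])))
            if op == CompareOp.EQ and not len(args[k]) == value:
                raise NotImplementedError(msg + str(len(args[k])))
            return func(aten_op, args, kwargs)
        return wrapper
    return decorator

# Hypothetical handler mirroring nf4_view's "size must be a single dimension" constraint.
@expect_args_len_at_k(1, CompareOp.EQ, 1, "toy_view(Tensor) with len(size)=")
def toy_view(aten_op, args, kwargs=None):
    return "ok"

print(toy_view(None, (object(), [-1])))   # passes the check, prints "ok"
# toy_view(None, (object(), [512, 512]))  # raises NotImplementedError: toy_view(Tensor) with len(size)=2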
