torch.chunk and cpu offloading ops
Summary:
Add torch.chunk support and CPU offloading ops for NF4Tensor (aten.is_pinned, aten._pin_memory, Tensor.to(device), Tensor.cpu), with tests covering chunking, slicing, view, as_strided, pinning, and device moves.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
weifengpy committed Apr 27, 2024
1 parent 613bf67 commit 3933bfa
Showing 2 changed files with 224 additions and 43 deletions.
117 changes: 96 additions & 21 deletions test/dtypes/test_nf4.py
@@ -1,6 +1,7 @@
import logging
import unittest
from packaging import version
+import math

import torch
from torch import nn
@@ -251,10 +252,17 @@ def test_torch_chunk_valid(self, input_size: Union[Tuple[int], int]):
        for chunk in chunks:
            self.assertEqual(chunk.size(0), expected_size0)

-    @parametrize("input_size", [511 * 512, (511 * 512,), (511, 512), (512, 512, 512)])
-    def test_torch_chunk_invalid(self, input_size: Union[Tuple[int], int]):
+    @parametrize("input_size", [511 * 512, (511 * 512,), (511, 512)])
+    def test_torch_chunk_invalid_divide(self, input_size: Union[Tuple[int], int]):
        num_chunks = 2
-        with self.assertRaises(AssertionError):
+        with self.assertRaisesRegex(AssertionError, "Number of scalers must be divisible by scaler block size"):
            nf4_tensor = to_nf4(torch.randn(input_size))
            torch.chunk(nf4_tensor, num_chunks)

+    @parametrize("input_size", [(512, 512, 512)])
+    def test_torch_chunk_invalid_3d(self, input_size: Union[Tuple[int], int]):
+        num_chunks = 2
+        with self.assertRaisesRegex(AssertionError, "expect input tensor dim <= 2"):
+            nf4_tensor = to_nf4(torch.randn(input_size))
+            torch.chunk(nf4_tensor, num_chunks)
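
For context, a minimal sketch of what these chunk tests exercise, assuming to_nf4's default block sizes (block_size=64, scaler_block_size=256 — an assumption, not spelled out in this diff):

import torch
from torchao.dtypes.nf4tensor import to_nf4

# 512*512 elements -> 4096 quantization blocks -> 4096 scalers; splitting into
# two chunks leaves 2048 scalers per chunk, still divisible by the scaler
# block size of 256, so torch.chunk succeeds along dim 0.
nf4_tensor = to_nf4(torch.randn(512, 512))
first_half, second_half = torch.chunk(nf4_tensor, 2)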

@@ -277,55 +285,122 @@ def test_tensor_new_zeros_invalid(self, input_size: Union[Tuple[int], int]):
        else:
            new_size = (input_size[0] + 1, input_size[1])
        nf4_tensor = to_nf4(torch.randn(input_size))
-        with self.assertRaisesRegex(NotImplementedError, "aten.new_zeros\(NF4Tensor\) with new size"):
+        with self.assertRaisesRegex(NotImplementedError, "aten.new_zeros\\(NF4Tensor\\) with new size"):
            nf4_tensor_zeros = nf4_tensor.new_zeros(new_size)

    @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
    def test_tensor_slice_valid(self, input_size: Union[Tuple[int], int]):
        nf4_tensor = to_nf4(torch.randn(input_size))
+        orig_attrs, _ = nf4_tensor.__tensor_flatten__()
+        orig_sizes = dict([(attr, getattr(nf4_tensor, attr).size()) for attr in orig_attrs])
        end_idx = input_size if isinstance(input_size, int) else input_size[0]
        sliced_tensor = nf4_tensor[:end_idx]
        self.assertEqual(nf4_tensor.size(), sliced_tensor.size())
        attrs, _ = sliced_tensor.__tensor_flatten__()
        for attr in attrs:
            orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr()
-            self.assertEqual(getattr(sliced_tensor, attr).untyped_storage().data_ptr(), orig_storage)
+            sliced_tensor_inner = getattr(sliced_tensor, attr)
+            self.assertEqual(sliced_tensor_inner.untyped_storage().data_ptr(), orig_storage)
+            self.assertEqual(sliced_tensor_inner.size(), orig_sizes[attr])

    def test_tensor_slice_1d_invalid(self):
        nf4_tensor = to_nf4(torch.randn(512 * 512))
-        with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with step"):
+        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with step"):
            nf4_tensor[..., ::2]
-        with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with start"):
+        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"):
            nf4_tensor[1:]
-        with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with end "):
+        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end "):
            nf4_tensor[:2]

    def test_tensor_slice_2d_invalid(self):
        nf4_tensor = to_nf4(torch.randn((512, 512)))
-        with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with dim"):
+        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with dim"):
            nf4_tensor[:, :511]
-        with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with start"):
+        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with start"):
            nf4_tensor[1:]
-        with self.assertRaisesRegex(NotImplementedError, "aten.slice\(NF4Tensor\) with end"):
+        with self.assertRaisesRegex(NotImplementedError, "aten.slice\\(NF4Tensor\\) with end"):
            nf4_tensor[:2]
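
In short, only a no-op slice over the full first dimension is supported; it returns a view whose inner buffers alias the original storage rather than copying it. A hedged sketch of that contract:

import torch
from torchao.dtypes.nf4tensor import to_nf4

nf4_tensor = to_nf4(torch.randn(512, 512))
full_slice = nf4_tensor[:512]  # whole-tensor slice: allowed, same size
# The inner quantized buffer aliases the original storage.
assert (full_slice.quantized_data.untyped_storage().data_ptr()
        == nf4_tensor.quantized_data.untyped_storage().data_ptr())
# nf4_tensor[1:] or nf4_tensor[:, :511] raise NotImplementedError, as tested above.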

    @parametrize("input_size", [(512 * 512,), (512, 512)])
    def test_tensor_view_valid(self, input_size: Union[Tuple[int], int]):
        nf4_tensor = to_nf4(torch.randn(input_size))
        viewed_tensor = nf4_tensor.view(-1)
-        self.asssertEqual(viewed_tensor.dim(), 1)
-        self.asssertEqual(viewed_tensor.numel(), math.prod(input_size))
-        attrs, _ = sliced_tensor.__tensor_flatten__()
-        for attr in attrs:
-            orig_storage = getattr(nf4_tensor, attr).untyped_storage().data_ptr()
-            inner_tensor = getattr(sliced_tensor, attr)
-            self.asssertEqual(inner_tensor.dim(), 1)
-            self.assertEqual(inner_tensor.untyped_storage().data_ptr(), orig_storage)
+        self.assertEqual(viewed_tensor.dim(), 1)
+        self.assertEqual(viewed_tensor.numel(), math.prod(input_size))
+        for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
+            inner_tensor = getattr(viewed_tensor, attr)
+            self.assertEqual(inner_tensor.size(0), inner_tensor.numel())

+    @parametrize("input_size", [(512 * 512,), (512, 512)])
+    def test_tensor_view_invalid(self, input_size: Union[Tuple[int], int]):
+        nf4_tensor = to_nf4(torch.randn(input_size))
+        with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"):
+            nf4_tensor.view(input_size)
+        with self.assertRaisesRegex(NotImplementedError, "aten.view\\(NF4Tensor\\) with size"):
+            nf4_tensor.view(input_size)
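
The only supported reshape is the flatten-to-1-D view; a small sketch, assuming the inner attribute names used above:

import torch
from torchao.dtypes.nf4tensor import to_nf4

nf4_tensor = to_nf4(torch.randn(512, 512))
flat = nf4_tensor.view(-1)  # supported: view(-1) only
assert flat.dim() == 1 and flat.numel() == 512 * 512
# Any other target size (e.g. nf4_tensor.view(512, 512)) raises NotImplementedError.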

+    @parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
+    def test_tensor_as_strided_valid(self, input_size: Union[Tuple[int], int]):
+        nf4_tensor = to_nf4(torch.randn(input_size))
+        nf4_tensor_strided = torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), nf4_tensor.storage_offset())
+        self.assertEqual(nf4_tensor_strided.size(), nf4_tensor.size())
+        self.assertEqual(nf4_tensor_strided.stride(), nf4_tensor.stride())
+        self.assertEqual(nf4_tensor_strided.storage_offset(), nf4_tensor.storage_offset())
+        for attr in ["quantized_scalers", "quantization_factor", "quantized_data"]:
+            inner_tensor_orig = getattr(nf4_tensor, attr)
+            inner_tensor_strided = getattr(nf4_tensor_strided, attr)
+            self.assertEqual(inner_tensor_strided.size(), inner_tensor_orig.size())
+            self.assertEqual(inner_tensor_strided.stride(), inner_tensor_orig.stride())
+            self.assertEqual(inner_tensor_strided.storage_offset(), inner_tensor_orig.storage_offset())

-    # def test_tensor_as_strided(self):
-    #     pass
+    @parametrize("input_size", [(512 * 512,), (512, 512)])
+    def test_tensor_as_strided_invalid(self, input_size: Union[Tuple[int], int]):
+        nf4_tensor = to_nf4(torch.randn(input_size))
+        if len(input_size) == 1:
+            size = (input_size[0] - 1, )
+        else:
+            size = (input_size[0] - 1, input_size[1])
+        with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) different numel"):
+            torch.as_strided(nf4_tensor, size, nf4_tensor.stride(), nf4_tensor.storage_offset())
+        with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support original storage offset"):
+            torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), 1)
+
+        if len(input_size) == 2:
+            with self.assertRaisesRegex(NotImplementedError, "aten.as_strided\\(NF4Tensor\\) only support continuous stride"):
+                stride = (nf4_tensor.stride()[1], nf4_tensor.stride()[0])
+                torch.as_strided(nf4_tensor, nf4_tensor.size(), stride, nf4_tensor.storage_offset())
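
Only an identity restride is accepted: same numel, contiguous strides, and the original storage offset. A hedged sketch:

import torch
from torchao.dtypes.nf4tensor import to_nf4

nf4_tensor = to_nf4(torch.randn(512, 512))
same = torch.as_strided(nf4_tensor, nf4_tensor.size(), nf4_tensor.stride(), nf4_tensor.storage_offset())
# Transposed strides, a changed numel, or a nonzero storage offset raise NotImplementedError.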

+    def test_pin_memory(self):
+        nf4_tensor = to_nf4(torch.randn(512 * 512))
+        self.assertFalse(nf4_tensor.is_pinned())
+
+        nf4_tensor = nf4_tensor.pin_memory()
+        self.assertTrue(nf4_tensor.is_pinned())
+
+        nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda'))
+        self.assertFalse(nf4_tensor.is_pinned())
+
+    def test_to_cuda(self):
+        nf4_tensor = to_nf4(torch.randn(512 * 512))
+        self.assertEqual(nf4_tensor.device.type, "cpu")
+        nf4_tensor = nf4_tensor.to("cuda", non_blocking=True)
+        self.assertEqual(nf4_tensor.device.type, "cuda")
+
+        nf4_tensor = to_nf4(torch.randn(512 * 512))
+        self.assertEqual(nf4_tensor.device.type, "cpu")
+        nf4_tensor = nf4_tensor.to("cuda")
+        self.assertEqual(nf4_tensor.device.type, "cuda")
+
+        nf4_tensor = to_nf4(torch.randn(512 * 512))
+        self.assertEqual(nf4_tensor.device.type, "cpu")
+        nf4_tensor = nf4_tensor.to("cuda", torch.bfloat16)
+        self.assertEqual(nf4_tensor.device.type, "cuda")
+        self.assertEqual(nf4_tensor.dtype, torch.bfloat16)
+
+    def test_to_cpu(self):
+        nf4_tensor = to_nf4(torch.randn(512 * 512, device='cuda'))
+        nf4_tensor.cpu()

instantiate_parametrized_tests(TestNF4Linear)
instantiate_parametrized_tests(TestFSDPOps)
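
Taken together, these ops support a basic CPU-offload roundtrip for NF4 weights. A hedged sketch (assumes a CUDA device; mirrors the tests above):

import torch
from torchao.dtypes.nf4tensor import to_nf4

nf4_tensor = to_nf4(torch.randn(512 * 512))             # quantize on CPU
nf4_tensor = nf4_tensor.pin_memory()                    # pin host memory for async H2D copies
nf4_tensor = nf4_tensor.to("cuda", non_blocking=True)   # offload in: still NF4, just on GPU
nf4_tensor = nf4_tensor.cpu()                           # offload out: quantized bits move back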
150 changes: 128 additions & 22 deletions torchao/dtypes/nf4tensor.py
@@ -8,6 +8,7 @@
import torch.nn.functional as F
from torch import Tensor
from torch.distributed.device_mesh import DeviceMesh
+from torch._prims_common import make_contiguous_strides_for


aten = torch.ops.aten
@@ -215,7 +216,9 @@ def nf4_slice(aten_op, args, kwargs=None):
    ]
)
def nf4_view(aten_op, args, kwargs=None):
-    assert len(args) == 2 and len(args[1]) == 1 and args[1][0] == -1
+    size = args[1]
+    if not (len(size) == 1 and size[0] == -1):
+        raise NotImplementedError(f"aten.view(NF4Tensor) with size {size}")
    quantized_scalers = aten_op(args[0].quantized_scalers, *(args[1:]), **kwargs)
    quantization_factor = aten_op(args[0].quantization_factor, *(args[1:]), **kwargs)
    quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **kwargs)
@@ -245,14 +248,22 @@ def nf4_view(aten_op, args, kwargs=None):
    ]
)
def nf4_as_strided(aten_op, args, kwargs=None):
-    assert len(args[1]) == 2 and math.prod(args[1]) == args[0].numel(), "only support same numel"
-    assert args[2] == [args[1][1], 1], f"only support stride {[args[1][1], 1]}"
-    assert args[0].storage_offset() == args[3], f"only support same storage offset"
+    size = args[1]
+    stride = tuple(args[2])
+    storage_offset = args[3]
+    if not len(size) <= 2:
+        raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support dim <= 2 but got dim={len(size)}")
+    if math.prod(size) != args[0].numel():
+        raise NotImplementedError(f"aten.as_strided(NF4Tensor) different numel={args[0].numel()} and size={size}")
+    if stride != make_contiguous_strides_for(size):
+        raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support continuous stride={make_contiguous_strides_for(size)} but got stride={stride}")
+    if args[0].storage_offset() != storage_offset:
+        raise NotImplementedError(f"aten.as_strided(NF4Tensor) only support original storage offset {args[0].storage_offset()} but got {storage_offset}")
    return NF4Tensor(
        SubclassTensorArgs(
-            torch.Size(args[1]),
-            tuple(args[2]),
-            args[0].storage_offset(),
+            torch.Size(size),
+            stride,
+            storage_offset,
            args[0].dtype,
            args[0].device,
            args[0].requires_grad,
@@ -330,25 +341,23 @@ def nf4_copy_(aten_op, args, kwargs=None):
        quantized_data = aten_op(original.quantized_data, copy_in.quantized_data, **kwargs)
        scaler_mean = aten_op(original.scaler_mean, copy_in.scaler_mean, **kwargs)
        nf4 = aten_op(original.nf4, copy_in.nf4, **kwargs)
-        tensor_meta = SubclassTensorArgs(
+        if (
            original.size(),
            original.stride(),
            original.storage_offset(),
            original.dtype,
            original.device,
            original.requires_grad
-        )
-        return NF4Tensor(
-            tensor_meta,
-            copy_in.block_size,
-            copy_in.n_blocks,
-            copy_in.scaler_block_size,
-            quantized_scalers,
-            quantization_factor,
-            scaler_mean,
-            quantized_data,
-            nf4,
-        )
+        ) != (
+            copy_in.size(),
+            copy_in.stride(),
+            copy_in.storage_offset(),
+            copy_in.dtype,
+            copy_in.device,
+            copy_in.requires_grad
+        ):
+            raise NotImplementedError(f"aten.copy_(NF4Tensor) with different metadata")
+        return original

    # Convert Non NF4Tensor into NF4 for copy in
    if not isinstance(copy_in, NF4Tensor):
@@ -365,6 +374,48 @@ def nf4_copy_(aten_op, args, kwargs=None):
        return original.copy_(same_meta_nf4)


+@implements(
+    [
+        aten.is_pinned.default,
+    ]
+)
+def nf4_is_pinned(aten_op, args, kwargs=None):
+    return (
+        aten_op(args[0].quantized_scalers, *(args[1:]), **kwargs) and
+        aten_op(args[0].quantization_factor, *(args[1:]), **kwargs) and
+        aten_op(args[0].quantized_data, *(args[1:]), **kwargs)
+    )


@implements(
[
aten._pin_memory.default,
]
)
def nf4_pin_memory(aten_op, args, kwargs=None):
quantized_scalers = aten_op(args[0].quantized_scalers, *(args[1:]), **kwargs)
quantization_factor = aten_op(args[0].quantization_factor, *(args[1:]), **kwargs)
quantized_data = aten_op(args[0].quantized_data, *(args[1:]), **kwargs)
return NF4Tensor(
SubclassTensorArgs(
args[0].size(),
args[0].stride(),
args[0].storage_offset(),
args[0].dtype,
args[0].device,
args[0].requires_grad,
),
args[0].block_size,
args[0].n_blocks,
args[0].scaler_block_size,
quantized_scalers,
quantization_factor,
args[0].scaler_mean,
quantized_data,
args[0].nf4,
)
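
Note that nf4_pin_memory routes only the three large per-tensor buffers (quantized_scalers, quantization_factor, quantized_data) through the pin op; scaler_mean and the nf4 code book are carried over as-is, and nf4_is_pinned reports pinned only when all three routed buffers are. A hedged usage sketch:

import torch
from torchao.dtypes.nf4tensor import to_nf4

nf4_tensor = to_nf4(torch.randn(512 * 512))
pinned = nf4_tensor.pin_memory()
assert pinned.is_pinned()  # true once all three routed buffers are pinned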


@dataclass
class SubclassTensorArgs:
original_shape: torch.Size
@@ -469,7 +520,7 @@ def from_tensor(
        block_size: int,
        scaler_block_size: int,
    ):
-        assert inpt_tensor.dim() <= 2
+        assert inpt_tensor.dim() <= 2, f"expect input tensor dim <= 2 but got dim = {inpt_tensor.dim()}"
        assert (
            inpt_tensor.numel() % block_size == 0
        ), f"Input tensor must be divisible by block size, got {inpt_tensor.numel()} and {block_size}"
@@ -921,9 +972,64 @@ def function_to_dtype(*args, **kwargs):
    if isinstance(args[0], NF4Tensor) and isinstance(args[1], torch.dtype):
        # Tensor.to(dtype, non_blocking, copy, memory_format)
        return args[0].get_original_weight().to(*args[1:], **kwargs)
+    elif isinstance(args[0], NF4Tensor) and (
+        isinstance(args[1], torch.device) or (
+            isinstance(args[1], str) and (
+                args[1] == "cpu" or args[1].startswith("cuda")
+            )
+        )
+    ) and len(args) == 2:
+        # Tensor.to(device, non_blocking)
+        device = args[1]
+        quantized_scalers = args[0].quantized_scalers.to(*(args[1:]), **kwargs)
+        quantization_factor = args[0].quantization_factor.to(*(args[1:]), **kwargs)
+        quantized_data = args[0].quantized_data.to(*(args[1:]), **kwargs)
+        return NF4Tensor(
+            SubclassTensorArgs(
+                args[0].size(),
+                args[0].stride(),
+                args[0].storage_offset(),
+                args[0].dtype,
+                device,
+                args[0].requires_grad,
+            ),
+            args[0].block_size,
+            args[0].n_blocks,
+            args[0].scaler_block_size,
+            quantized_scalers,
+            quantization_factor,
+            args[0].scaler_mean,
+            quantized_data,
+            args[0].nf4,
+        )
+    else:
+        # Tensor.to(device, dtype, non_blocking, copy, memory_format)
+        # Tensor.to(other, non_blocking, copy)
+        raise NotImplementedError(
+            f"NF4Tensor.to({args[1:]}, {kwargs}) is not supported, passing to dispatch"
+        )
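
In other words, NF4Tensor.to() has two fast paths: .to(dtype) dequantizes via get_original_weight(), while .to(device) moves the quantized buffers without dequantizing. A hedged sketch of the difference:

import torch
from torchao.dtypes.nf4tensor import to_nf4

nf4_tensor = to_nf4(torch.randn(512, 512))
dequantized = nf4_tensor.to(torch.bfloat16)  # plain bfloat16 tensor, NF4 unpacked
moved = nf4_tensor.to("cuda")                # still an NF4Tensor, bits unchanged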


+@implements_torch_function(torch.Tensor.cpu)
+def function_cpu(*args, **kwargs):
+    quantized_scalers = args[0].quantized_scalers.cpu(*(args[1:]), **kwargs)
+    quantization_factor = args[0].quantization_factor.cpu(*(args[1:]), **kwargs)
+    quantized_data = args[0].quantized_data.cpu(*(args[1:]), **kwargs)
+    return NF4Tensor(
+        SubclassTensorArgs(
+            args[0].size(),
+            args[0].stride(),
+            args[0].storage_offset(),
+            args[0].dtype,
+            'cpu',
+            args[0].requires_grad,
+        ),
+        args[0].block_size,
+        args[0].n_blocks,
+        args[0].scaler_block_size,
+        quantized_scalers,
+        quantization_factor,
+        args[0].scaler_mean,
+        quantized_data,
+        args[0].nf4,
+    )
