Skip to content

Commit

Permalink
include kwargs
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
weifengpy committed Apr 19, 2024
1 parent b2c3c02 commit 9be2de3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
4 changes: 2 additions & 2 deletions test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ def test_smoketest_linear(self, dtype: torch.dtype):
def test_smoketest_linear_compile(self, dtype: torch.dtype):
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0) and dtype == torch.bfloat16:
self.skipTest("test requires SM capability of at least (8, 0).")
if version.parse(torch.__version__) <= version.parse("2.2.2"):
self.skipTest("test requires 2.3.0+ for tracing NF4Tensor")
if version.parse(torch.__version__) < version.parse("2.3.0"):
self.skipTest("test requires 2.3.0 and above for tracing NF4Tensor")
a = torch.randn(32, 32, dtype=dtype, device='cuda')
a_nf4 = torchao.dtypes.to_nf4(a, 16, 2)
inp = torch.randn(2, 32, 32, dtype=a.dtype, device=a.device)
Expand Down
10 changes: 9 additions & 1 deletion torchao/dtypes/nf4tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,4 +610,12 @@ def decorator(func):

@implements_torch_function(torch.Tensor.to)
def function_to_dtype(*args, **kwargs):
return args[0].get_original_weight().to(args[1])
if isinstance(args[0], NF4Tensor) and isinstance(args[1], torch.dtype):
# Tensor.to(dtype, non_blocking, copy, memory_format)
return args[0].get_original_weight().to(*args[1:], **kwargs)
else:
# Tensor.to(device, dtype, non_blocking, copy, memory_format)
# Tensor.to(other, non_blocking, copy)
raise NotImplementedError(
f"NF4Tensor.to({args[1:]}, {kwargs}) is not supported"
)

0 comments on commit 9be2de3

Please sign in to comment.