This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Add more compile compatibility for Float8Tensor ops #285

Closed
wants to merge 9 commits
55 changes: 55 additions & 0 deletions float8_experimental/float8_ops.py
@@ -41,6 +41,9 @@ def decorator(func):
        aten.as_strided.default,
        aten.clone.default,
        aten.detach.default,
        aten.slice.Tensor,
        aten.transpose.int,
        aten.fill_.Scalar,
    ]
)
def float8_desugar_op(aten_op, args, kwargs=None):
@@ -252,3 +255,55 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
    return Float8Tensor(
        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
    )


@implements([aten.index_put_.default])
def index_put_fp8(aten_op, args, kwargs=None):
    fp8_self = args[0]
    fp8_values = args[2]
    assert isinstance(fp8_self, Float8Tensor)
    assert isinstance(fp8_values, Float8Tensor)
    assert fp8_self._scale == fp8_values._scale
    assert fp8_self.dtype == fp8_values.dtype
    assert fp8_self._orig_dtype == fp8_values._orig_dtype

    fp8_data = fp8_self._data
    fp8_values_data = fp8_values._data
    fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
    return Float8Tensor(
        fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config
    )


@implements([aten.copy_.default])
def copy_fp8(aten_op, args, kwargs=None):
    # For a copy op with Float8Tensors involved, only the following combinations are allowed:
    # 1. self is a high precision (hp) tensor, src is a Float8Tensor:
    #    in this case src is upcasted and unscaled to go into the hp tensor
    # 2. self and src are Float8Tensors:
    #    the copy is only allowed if all the Float8Tensor properties are equal (a la torch.cat)
    # Every other combination is banned as the semantics are not well defined

    self = args[0]
    src = args[1]

    if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
        src_hp = src.to_original_precision()
        return aten_op(self, src_hp, *args[2:], **kwargs)
    elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
        assert (
            self._orig_dtype == src._orig_dtype
        ), "Expecting both Float8Tensors to have the same original dtype"
        assert (
            self._scale == src._scale
        ), "Expecting both Float8Tensors to have the same scale"
        assert (
            self._mm_config == src._mm_config
        ), "Expecting both Float8Tensors to have the same mm config"
        assert (
            self._data.dtype == src._data.dtype
        ), "Expecting both Float8Tensors to have the same underlying fp8 dtype"
        fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
        return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
    else:
        raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")
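
A quick usage sketch of the newly covered ops (not part of the diff): it mirrors the tests below, and it assumes the Float8Tensor / tensor_to_scale import paths this repo's tests use; the variable names are illustrative only.

import torch
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

x = torch.randn(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(x, torch.float8_e4m3fn)
fp8_x = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn)

# slice and transpose now desugar straight onto the underlying fp8 _data
fp8_rows = fp8_x[0:4]          # aten.slice.Tensor
fp8_t = fp8_x.transpose(0, 1)  # aten.transpose.int

# index_put_ requires both sides to be Float8Tensors with matching scale/dtype
fp8_vals = Float8Tensor.to_float8(x[0:4], scale, torch.float8_e4m3fn)
fp8_x[torch.tensor([0, 1, 2, 3])] = fp8_vals  # aten.index_put_.default

# copy_ from a Float8Tensor into a high precision tensor upcasts and unscales
hp = torch.empty(16, 16, dtype=torch.bfloat16)
hp.copy_(fp8_x)  # combination 1 in the copy_fp8 comment above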
38 changes: 38 additions & 0 deletions test/test_base.py
@@ -81,6 +81,44 @@ def test_split_cat(self):
        catted = torch.cat(splits, dim=0)
        assert bitwise_identical(fp8_a, catted)

    def test_index_put(self):
        a = torch.rand(16, dtype=torch.bfloat16)
        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
        fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn)

        index = torch.randint(0, 15, (16,), dtype=torch.long)

        b = torch.rand(16, 16, dtype=torch.bfloat16)
        scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
        # fp8_b reuses a's scale on purpose so the matching-scale put below is allowed
        fp8_b = Float8Tensor.to_float8(b, scale_a, torch.float8_e4m3fn)
        fp8_b_bad = Float8Tensor.to_float8(b, scale_b, torch.float8_e4m3fn)

        # mixed hp/fp8 operands and mismatched scales are rejected
        with self.assertRaises(AssertionError):
            b[index] = fp8_a
        with self.assertRaises(AssertionError):
            fp8_b[index] = a
        with self.assertRaises(AssertionError):
            fp8_b_bad[index] = fp8_a

        # matching scale and dtypes: the put succeeds entirely in fp8
        fp8_b[index] = fp8_a

    def test_copy_(self):
        a = torch.rand(16, dtype=torch.bfloat16)
        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
        fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn)

        b = torch.empty(16, dtype=torch.bfloat16)
        b.copy_(fp8_a)  # Should work: fp8 -> hp upcasts and unscales
        torch.testing.assert_close(b, fp8_a.to_original_precision())
        with self.assertRaises(RuntimeError):
            fp8_a.copy_(b)  # Should fail: hp -> fp8 is not a defined combination

        fp8_b = Float8Tensor(
            torch.empty(16, dtype=torch.float8_e4m3fn),
            scale_a,
            torch.bfloat16,
            fp8_a._mm_config,
        )
        fp8_b.copy_(fp8_a)  # fp8 -> fp8 with matching metadata
        torch.testing.assert_close(fp8_a._data, fp8_b._data)


class TestFloat8Linear:
    def _test_linear_impl(
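
On the compile-compatibility angle from the PR title, a hedged sketch of the kind of graph these handlers unblock. The function and names below are hypothetical, and whether a given graph compiles cleanly still depends on the surrounding model; this PR does not assert that.

import torch
from float8_experimental.float8_tensor import Float8Tensor

def fp8_view_ops(fp8_x: Float8Tensor) -> Float8Tensor:
    # transpose and slice desugar onto the fp8 _data, so torch.compile can keep
    # them inside the traced graph instead of breaking on the tensor subclass
    return fp8_x.transpose(0, 1)[0:8]

compiled_fp8_view_ops = torch.compile(fp8_view_ops)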