From b5a444a3ec5fcd45fe86175256d1ab862c64fcb0 Mon Sep 17 00:00:00 2001
From: Antoni Vros
Date: Wed, 26 Jun 2024 11:18:35 -0700
Subject: [PATCH] Add more compile compatibility for Float8Tensor ops (#285)

Summary:
Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/285

Reviewed By: vkuzo

Differential Revision: D59068281

Pulled By: drisspg

fbshipit-source-id: 18fa34db74cf60e85ff372ff1091c107119403a0
---
 float8_experimental/float8_ops.py | 57 +++++++++++++++++++++++++++++++
 test/test_base.py                 | 42 ++++++++++++++++++++++
 2 files changed, 99 insertions(+)

diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py
index ea2cb67..3a50cc8 100644
--- a/float8_experimental/float8_ops.py
+++ b/float8_experimental/float8_ops.py
@@ -42,6 +42,9 @@ def decorator(func):
         aten.as_strided.default,
         aten.clone.default,
         aten.detach.default,
+        aten.slice.Tensor,
+        aten.transpose.int,
+        aten.fill_.Scalar,
     ]
 )
 def float8_desugar_op(aten_op, args, kwargs=None):
@@ -263,3 +266,57 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
     return Float8Tensor(
         fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
     )
+
+
+@implements([aten.index_put_.default])
+def index_put_fp8(aten_op, args, kwargs=None):
+    # index_put_ is only defined when both self and values are Float8Tensors
+    # with matching scale and dtypes; the op is applied to the raw fp8 data.
+    fp8_self = args[0]
+    fp8_values = args[2]
+    assert isinstance(fp8_self, Float8Tensor)
+    assert isinstance(fp8_values, Float8Tensor)
+    assert fp8_self._scale == fp8_values._scale
+    assert fp8_self.dtype == fp8_values.dtype
+    assert fp8_self._orig_dtype == fp8_values._orig_dtype
+
+    fp8_data = fp8_self._data
+    fp8_values_data = fp8_values._data
+    fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
+    return Float8Tensor(
+        fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config
+    )
+
+
+@implements([aten.copy_.default])
+def copy_fp8(aten_op, args, kwargs=None):
+    # For a copy op with Float8Tensors involved, only the following combinations are allowed:
+    # 1. self is a high precision (hp) tensor, src is a Float8Tensor:
+    #    in this case src is upcasted and unscaled to go into the hp tensor
+    # 2. self and src are Float8Tensors:
+    #    the copy is only allowed if all the Float8Tensor properties are equal (a la torch.cat)
+    # Every other combination is banned as the semantics are not well defined
+
+    self = args[0]
+    src = args[1]
+
+    if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+        src_hp = src.to_original_precision()
+        return aten_op(self, src_hp, *args[2:], **kwargs)
+    elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+        assert (
+            self._orig_dtype == src._orig_dtype
+        ), "Expecting both Float8Tensors to be of the same dtype"
+        assert (
+            self._scale == src._scale
+        ), "Expecting both Float8Tensors to have the same scale"
+        assert (
+            self._mm_config == src._mm_config
+        ), "Expecting both Float8Tensors to have the same mm config"
+        assert (
+            self._data.dtype == src._data.dtype
+        ), "Expecting both Float8Tensors' raw data to be of the same dtype"
+        fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
+        return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
+    else:
+        raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")
diff --git a/test/test_base.py b/test/test_base.py
index b688ccb..742a4b1 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -83,6 +83,48 @@ def test_split_cat(self):
         catted = torch.cat(splits, dim=0)
         assert bitwise_identical(fp8_a, catted)
 
+    def test_index_put(self):
+        a = torch.rand(16, dtype=torch.bfloat16)
+        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
+        fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn)
+
+        index = torch.randint(0, 15, (16,), dtype=torch.long)
+
+        b = torch.rand(16, 16, dtype=torch.bfloat16)
+        scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
+        fp8_b = Float8Tensor.to_float8(b, scale_a, torch.float8_e4m3fn)
+        fp8_b_bad = Float8Tensor.to_float8(b, scale_b, torch.float8_e4m3fn)
+
+        # Each invalid combination gets its own assertRaises block; otherwise
+        # only the first raising statement would ever execute.
+        with self.assertRaises(AssertionError):
+            b[index] = fp8_a
+        with self.assertRaises(AssertionError):
+            fp8_b[index] = a
+        with self.assertRaises(AssertionError):
+            fp8_b_bad[index] = fp8_a
+        fp8_b[index] = fp8_a
+
+    def test_copy_(self):
+        a = torch.rand(16, dtype=torch.bfloat16)
+        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
+        fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn)
+
+        b = torch.empty(16, dtype=torch.bfloat16)
+        b.copy_(fp8_a)  # Should work
+        torch.testing.assert_close(b, fp8_a.to_original_precision())
+        with self.assertRaises(RuntimeError):
+            fp8_a.copy_(b)  # Should fail
+
+        fp8_b = Float8Tensor(
+            torch.empty(16, dtype=torch.float8_e4m3fn),
+            scale_a,
+            torch.bfloat16,
+            fp8_a._mm_config,
+        )
+        fp8_b.copy_(fp8_a)
+        torch.testing.assert_close(fp8_a._data, fp8_b._data)
+
 
 class TestFloat8Linear:
     def _test_linear_impl(
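
--
Editor's usage sketch (not part of the patch): the snippet below illustrates
the semantics the new aten.copy_, aten.index_put_ and slice/transpose/fill_
overrides are meant to provide. It assumes the float8_experimental API that
appears in the diff above (Float8Tensor.to_float8, tensor_to_scale,
to_original_precision, and the module paths at this commit); treat it as a
sketch of the patch's behavior, not as verified library documentation.

    import torch
    from float8_experimental.float8_tensor import Float8Tensor
    from float8_experimental.float8_utils import tensor_to_scale

    a = torch.rand(16, 16, dtype=torch.bfloat16)
    scale = tensor_to_scale(a, torch.float8_e4m3fn)
    fp8_a = Float8Tensor.to_float8(a, scale, torch.float8_e4m3fn)

    # Newly desugared ops act on the raw fp8 bits and rewrap the result,
    # which is what lets torch.compile trace through them.
    sliced = fp8_a[:8].transpose(0, 1)

    # fp8 -> high precision copy: src is unscaled and upcast into self.
    hp = torch.empty(16, 16, dtype=torch.bfloat16)
    hp.copy_(fp8_a)
    torch.testing.assert_close(hp, fp8_a.to_original_precision())

    # fp8 -> fp8 copy: allowed only when scale, dtypes and mm_config all match.
    fp8_b = Float8Tensor.to_float8(torch.zeros_like(a), scale, torch.float8_e4m3fn)
    fp8_b.copy_(fp8_a)

    # high precision -> fp8 copy has no defined semantics and raises RuntimeError:
    # fp8_a.copy_(hp)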