From 60f6ca910be0e570894e1eb94ecacc7231dc38f5 Mon Sep 17 00:00:00 2001 From: Antoni Vros Date: Fri, 17 May 2024 10:41:41 -0400 Subject: [PATCH 1/8] few fixes for FMS --- float8_experimental/__init__.py | 4 +- float8_experimental/float8_linear.py | 114 +++++++++++++++++++++++++++ float8_experimental/float8_ops.py | 39 +++++++++ float8_experimental/float8_tensor.py | 2 +- 4 files changed, 156 insertions(+), 3 deletions(-) diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py index 72c09052..d30e71d6 100644 --- a/float8_experimental/__init__.py +++ b/float8_experimental/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. # Lets define a few top level things here -from float8_experimental.float8_linear import Float8Linear +from float8_experimental.float8_linear import Float8Linear, Float8DASWLinear, Float8SWLinear from float8_experimental.float8_tensor import Float8Tensor -__all__ = ["Float8Tensor", "Float8Linear"] +__all__ = ["Float8Tensor", "Float8Linear", "Float8DASWLinear", "Float8SWLinear"] diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 5120f36d..eb18edf6 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -340,3 +340,117 @@ def from_float(cls, mod, emulate: bool = False): # I think its okay to send all params and buffers to device new_mod.to(mod.weight.device) return new_mod + +class Float8DASWLinear(Float8LinearMixin, torch.nn.Linear): + """ + A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks + scales in way friendly to delayed scaling. + """ + + def forward(self, x): + self.float8_pre_forward(x) + + x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized) + + # convert weight tensor to fp8 ahead of use + if not hasattr(self, 'w_fp8_t'): + self.w_fp8_t = self.cast_w_to_float8(self.weight, self.is_amax_initialized).t() + # Release fp16 memory + del self.weight + + y = torch.matmul(x_fp8, self.w_fp8_t) # matmul expects both inputs to be Float8Tensor + + # Cast gradY to float8_e5m2 during backward + y = self.cast_y_to_float8_in_bw(y, self.emulate) # Never backward for our use case + + if self.bias is not None: + y = y + self.bias.to(y.dtype) + + self.float8_post_forward() + return y + + @classmethod + def from_float(cls, mod, emulate: bool = False): + """ + Create an nn.Linear with fp8 compute from a regular nn.Linear + + Args: + mod (torch.nn.Linear): nn.Linear to convert + emulate (bool): whether to emulate fp8 matmul logic in float32 + """ + # TODO Follow up! 
This is a great idea but we need the mixin base to create real + # Tensors and the Linear base to create empty params + # with torch.device("meta"): + new_mod = cls(mod.in_features, mod.out_features, bias=False) + new_mod.weight = mod.weight + new_mod.bias = mod.bias + new_mod.emulate = emulate + # I think its okay to send all params and buffers to device + new_mod.to(mod.weight.device) + return new_mod + + +# Mauricio's Implementation +class Float8SWLinear(torch.nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(Float8SWLinear, self).__init__(in_features=out_features, out_features=in_features, bias=bias) + self.w_f8 = None + self.w_inv_s = None + self.biasfp16 = None + self.finfo = torch.finfo(torch.float8_e4m3fn) + + @classmethod + def from_float(cls, mod, emulate: bool = False): + """ + Create an nn.Linear with fp8 compute from a regular nn.Linear + + Args: + mod (torch.nn.Linear): nn.Linear to convert + emulate (bool): whether to emulate fp8 matmul logic in float32 + """ + # TODO Follow up! This is a great idea but we need the mixin base to create real + # Tensors and the Linear base to create empty params + # with torch.device("meta"): + new_mod = cls(mod.in_features, mod.out_features, bias=False) + #new_mod.weight = mod.weight + #new_mod.bias = mod.bias + new_mod.emulate = emulate + # I think its okay to send all params and buffers to device + new_mod.to(mod.weight.device) + w_f8, w_inv_s = new_mod.to_float8(mod.weight) + new_mod.w_f8 = w_f8.t() + new_mod.w_inv_s = w_inv_s + # Release fp16 memory + del new_mod.weight + + if mod.bias is not None: + new_mod.biasfp16 = mod.bias.to(torch.float16) + mod.weight = None + mod.bias = None + return new_mod + + def to_float8(self, x): + dtype = torch.float8_e4m3fn + #finfo = torch.finfo(torch.float8_e4m3fn) + # Calculate the scale as dtype max divided by absmax + scale = self.finfo.max / x.abs().max().clamp(min=1e-12) + # scale and clamp the tensor to bring it to + # the representative range of float8 data type + # (as default cast is unsaturated) + x_scl_sat = (x * scale).clamp(min=self.finfo.min, max=self.finfo.max) + # Return both float8 data and the inverse scale (as float), + # as both required as inputs to torch._scaled_mm + return x_scl_sat.to(dtype), scale.float().reciprocal() + + + def forward(self, x): + # create test inputs + #x_f8, x_inv_s = self.to_float8(x.to(torch.float16)) + x_f8 = x.to(torch.float8_e4m3fn) + # perform the float8 matmul + ishape= list(x_f8.shape) + x_f8_mat = x_f8.view(-1,ishape[-1]) + y, _ = torch._scaled_mm(x_f8_mat, self.w_f8, out_dtype=torch.bfloat16, + scale_b=self.w_inv_s, bias=self.biasfp16, use_fast_accum=False)#, scale_a=x_inv_s) + y = y.view(ishape[0],ishape[1],-1) + return y \ No newline at end of file diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index e22ccf3b..8b64ba04 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -41,6 +41,10 @@ def decorator(func): aten.as_strided.default, aten.clone.default, aten.detach.default, + aten.slice.Tensor, + aten.transpose.int, + aten.fill_.Scalar, + aten.copy_.default, ] ) def float8_desugar_op(aten_op, args, kwargs=None): @@ -254,3 +258,38 @@ def wait_tensor_fp8(aten_op, args, kwargs=None): return Float8Tensor( fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config ) + +@implements([aten.index_put_.default]) +def index_put_fp8(aten_op, args, kwargs=None): + fp8_self = args[0] + fp8_values = args[2] + assert isinstance(fp8_self, Float8Tensor) + 
assert isinstance(fp8_values, Float8Tensor) + + fp8_data = fp8_self._data + fp8_values_data = fp8_values._data + fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs) + return Float8Tensor( + fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config + ) + +@implements([aten.copy_.default]) +def copy_fp8(aten_op, args, kwargs=None): + self = args[0] + src = args[1] + assert isinstance(self, Float8Tensor) or isinstance(src, Float8Tensor) + + self_data = self + if isinstance(self, Float8Tensor): + self_data = self._data + + src_data = src + if isinstance(src, Float8Tensor): + src_data = src._data + + fp8_out = aten_op(self_data, src_data, *args[2:], **kwargs) + if isinstance(self, Float8Tensor): + return Float8Tensor( + fp8_out, self._scale, self._orig_dtype, self._mm_config + ) + return fp8_out diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 2535b69c..dd97d0cc 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -157,7 +157,7 @@ class FromFloat8ConstrFunc(torch.autograd.Function): @staticmethod def forward(ctx, tensor): - return tensor._data.to(tensor._orig_dtype) / tensor._scale + return (tensor._data.to(tensor._orig_dtype) / tensor._scale).to(tensor._orig_dtype) @staticmethod def backward(ctx, g): From 97cc11eef703685d7c0eccc0da67273e0136861d Mon Sep 17 00:00:00 2001 From: Antoni Vros Date: Thu, 30 May 2024 15:55:02 -0400 Subject: [PATCH 2/8] Disable check --- float8_experimental/float8_ops.py | 6 +++--- float8_experimental/float8_tensor.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 8b64ba04..3c415882 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -84,9 +84,9 @@ def float8_cat(aten_op, args, kwargs=None): assert ( chunk._orig_dtype == orig_dtype ), "Expecting all chunks to be of the same dtype" - assert ( - chunk._scale is scale - ), "Expecting all chunks to have thee same scale as a result of a split" + # assert ( + # chunk._scale is scale + # ), "Expecting all chunks to have thee same scale as a result of a split" assert ( chunk._mm_config is mm_config ), "Expecting all chunks to have thee same mm config as a result of a split" diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index dd97d0cc..b9d0b588 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -157,6 +157,7 @@ class FromFloat8ConstrFunc(torch.autograd.Function): @staticmethod def forward(ctx, tensor): + # return torch.ops.fp8_fast.copy_and_scale(tensor._data, tensor._orig_dtype, tensor._scale) return (tensor._data.to(tensor._orig_dtype) / tensor._scale).to(tensor._orig_dtype) @staticmethod From 6f402bb8d7dda0fce756d8313a814260c3b2adc1 Mon Sep 17 00:00:00 2001 From: Antoni Vros Date: Thu, 30 May 2024 15:57:27 -0400 Subject: [PATCH 3/8] fix --- float8_experimental/float8_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 3c415882..8b9e1b51 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -87,9 +87,9 @@ def float8_cat(aten_op, args, kwargs=None): # assert ( # chunk._scale is scale # ), "Expecting all chunks to have thee same scale as a result of a split" - assert ( - chunk._mm_config is mm_config - ), "Expecting all chunks to have thee same mm 
config as a result of a split" + # assert ( + # chunk._mm_config is mm_config + # ), "Expecting all chunks to have thee same mm config as a result of a split" assert ( chunk._data.dtype == fp8_dtype ), "Expecting all chunks to be of the same dtype as a result of a split" From 2065e368d172c51bef30d14617be30a983633809 Mon Sep 17 00:00:00 2001 From: Antoni Vros Date: Fri, 14 Jun 2024 19:40:06 -0400 Subject: [PATCH 4/8] remove Jamie PR --- float8_experimental/__init__.py | 4 +- float8_experimental/float8_linear.py | 114 --------------------------- float8_experimental/float8_ops.py | 14 ++-- float8_experimental/float8_tensor.py | 3 +- 4 files changed, 10 insertions(+), 125 deletions(-) diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py index d30e71d6..72c09052 100644 --- a/float8_experimental/__init__.py +++ b/float8_experimental/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. # Lets define a few top level things here -from float8_experimental.float8_linear import Float8Linear, Float8DASWLinear, Float8SWLinear +from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_tensor import Float8Tensor -__all__ = ["Float8Tensor", "Float8Linear", "Float8DASWLinear", "Float8SWLinear"] +__all__ = ["Float8Tensor", "Float8Linear"] diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index eb18edf6..5120f36d 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -340,117 +340,3 @@ def from_float(cls, mod, emulate: bool = False): # I think its okay to send all params and buffers to device new_mod.to(mod.weight.device) return new_mod - -class Float8DASWLinear(Float8LinearMixin, torch.nn.Linear): - """ - A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks - scales in way friendly to delayed scaling. - """ - - def forward(self, x): - self.float8_pre_forward(x) - - x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized) - - # convert weight tensor to fp8 ahead of use - if not hasattr(self, 'w_fp8_t'): - self.w_fp8_t = self.cast_w_to_float8(self.weight, self.is_amax_initialized).t() - # Release fp16 memory - del self.weight - - y = torch.matmul(x_fp8, self.w_fp8_t) # matmul expects both inputs to be Float8Tensor - - # Cast gradY to float8_e5m2 during backward - y = self.cast_y_to_float8_in_bw(y, self.emulate) # Never backward for our use case - - if self.bias is not None: - y = y + self.bias.to(y.dtype) - - self.float8_post_forward() - return y - - @classmethod - def from_float(cls, mod, emulate: bool = False): - """ - Create an nn.Linear with fp8 compute from a regular nn.Linear - - Args: - mod (torch.nn.Linear): nn.Linear to convert - emulate (bool): whether to emulate fp8 matmul logic in float32 - """ - # TODO Follow up! 
This is a great idea but we need the mixin base to create real - # Tensors and the Linear base to create empty params - # with torch.device("meta"): - new_mod = cls(mod.in_features, mod.out_features, bias=False) - new_mod.weight = mod.weight - new_mod.bias = mod.bias - new_mod.emulate = emulate - # I think its okay to send all params and buffers to device - new_mod.to(mod.weight.device) - return new_mod - - -# Mauricio's Implementation -class Float8SWLinear(torch.nn.Linear): - def __init__(self, in_features, out_features, bias=True): - super(Float8SWLinear, self).__init__(in_features=out_features, out_features=in_features, bias=bias) - self.w_f8 = None - self.w_inv_s = None - self.biasfp16 = None - self.finfo = torch.finfo(torch.float8_e4m3fn) - - @classmethod - def from_float(cls, mod, emulate: bool = False): - """ - Create an nn.Linear with fp8 compute from a regular nn.Linear - - Args: - mod (torch.nn.Linear): nn.Linear to convert - emulate (bool): whether to emulate fp8 matmul logic in float32 - """ - # TODO Follow up! This is a great idea but we need the mixin base to create real - # Tensors and the Linear base to create empty params - # with torch.device("meta"): - new_mod = cls(mod.in_features, mod.out_features, bias=False) - #new_mod.weight = mod.weight - #new_mod.bias = mod.bias - new_mod.emulate = emulate - # I think its okay to send all params and buffers to device - new_mod.to(mod.weight.device) - w_f8, w_inv_s = new_mod.to_float8(mod.weight) - new_mod.w_f8 = w_f8.t() - new_mod.w_inv_s = w_inv_s - # Release fp16 memory - del new_mod.weight - - if mod.bias is not None: - new_mod.biasfp16 = mod.bias.to(torch.float16) - mod.weight = None - mod.bias = None - return new_mod - - def to_float8(self, x): - dtype = torch.float8_e4m3fn - #finfo = torch.finfo(torch.float8_e4m3fn) - # Calculate the scale as dtype max divided by absmax - scale = self.finfo.max / x.abs().max().clamp(min=1e-12) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - x_scl_sat = (x * scale).clamp(min=self.finfo.min, max=self.finfo.max) - # Return both float8 data and the inverse scale (as float), - # as both required as inputs to torch._scaled_mm - return x_scl_sat.to(dtype), scale.float().reciprocal() - - - def forward(self, x): - # create test inputs - #x_f8, x_inv_s = self.to_float8(x.to(torch.float16)) - x_f8 = x.to(torch.float8_e4m3fn) - # perform the float8 matmul - ishape= list(x_f8.shape) - x_f8_mat = x_f8.view(-1,ishape[-1]) - y, _ = torch._scaled_mm(x_f8_mat, self.w_f8, out_dtype=torch.bfloat16, - scale_b=self.w_inv_s, bias=self.biasfp16, use_fast_accum=False)#, scale_a=x_inv_s) - y = y.view(ishape[0],ishape[1],-1) - return y \ No newline at end of file diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 8b9e1b51..6aee8fa4 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -84,12 +84,12 @@ def float8_cat(aten_op, args, kwargs=None): assert ( chunk._orig_dtype == orig_dtype ), "Expecting all chunks to be of the same dtype" - # assert ( - # chunk._scale is scale - # ), "Expecting all chunks to have thee same scale as a result of a split" - # assert ( - # chunk._mm_config is mm_config - # ), "Expecting all chunks to have thee same mm config as a result of a split" + assert ( + chunk._scale is scale + ), "Expecting all chunks to have thee same scale as a result of a split" + assert ( + chunk._mm_config is mm_config + ), "Expecting all chunks to have 
thee same mm config as a result of a split" assert ( chunk._data.dtype == fp8_dtype ), "Expecting all chunks to be of the same dtype as a result of a split" @@ -282,7 +282,7 @@ def copy_fp8(aten_op, args, kwargs=None): self_data = self if isinstance(self, Float8Tensor): self_data = self._data - + src_data = src if isinstance(src, Float8Tensor): src_data = src._data diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index b9d0b588..13c273e4 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -157,8 +157,7 @@ class FromFloat8ConstrFunc(torch.autograd.Function): @staticmethod def forward(ctx, tensor): - # return torch.ops.fp8_fast.copy_and_scale(tensor._data, tensor._orig_dtype, tensor._scale) - return (tensor._data.to(tensor._orig_dtype) / tensor._scale).to(tensor._orig_dtype) + return (tensor._data.to(tensor._orig_dtype) / tensor._scale.to(tensor._org_dtype)) @staticmethod def backward(ctx, g): From eec175a72d30b999bef148c1e7cf43fc5da0d7c3 Mon Sep 17 00:00:00 2001 From: Antoni Vros Date: Thu, 20 Jun 2024 11:53:21 -0400 Subject: [PATCH 5/8] Add test for index_put --- float8_experimental/float8_ops.py | 9 ++++++--- test/test_base.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index f0c4e3b2..a0a36c72 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -257,12 +257,16 @@ def wait_tensor_fp8(aten_op, args, kwargs=None): fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config ) + @implements([aten.index_put_.default]) def index_put_fp8(aten_op, args, kwargs=None): fp8_self = args[0] fp8_values = args[2] assert isinstance(fp8_self, Float8Tensor) assert isinstance(fp8_values, Float8Tensor) + assert fp8_self._scale == fp8_values._scale + assert fp8_self.dtype == fp8_values.dtype + assert fp8_self._orig_dtype == fp8_values._orig_dtype fp8_data = fp8_self._data fp8_values_data = fp8_values._data @@ -271,6 +275,7 @@ def index_put_fp8(aten_op, args, kwargs=None): fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config ) + @implements([aten.copy_.default]) def copy_fp8(aten_op, args, kwargs=None): self = args[0] @@ -287,7 +292,5 @@ def copy_fp8(aten_op, args, kwargs=None): fp8_out = aten_op(self_data, src_data, *args[2:], **kwargs) if isinstance(self, Float8Tensor): - return Float8Tensor( - fp8_out, self._scale, self._orig_dtype, self._mm_config - ) + return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config) return fp8_out diff --git a/test/test_base.py b/test/test_base.py index 6e7a34cc..7ad42f25 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -81,6 +81,24 @@ def test_split_cat(self): catted = torch.cat(splits, dim=0) assert bitwise_identical(fp8_a, catted) + def test_index_put(self): + a = torch.rand(16, dtype=torch.bfloat16) + scale_a = tensor_to_scale(a, torch.float8_e4m3fn) + fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn) + + index = torch.randint(0, 15, (16,), dtype=torch.long) + + b = torch.rand(16, 16, dtype=torch.bfloat16) + scale_b = tensor_to_scale(b, torch.float8_e4m3fn) + fp8_b = Float8Tensor.to_float8(b, scale_a, torch.float8_e4m3fn) + fp8_b_bad = Float8Tensor.to_float8(b, scale_b, torch.float8_e4m3fn) + + with self.assertRaises(AssertionError): + b[index] = fp8_a + fp8_b[index] = a + fp8_b_bad[index] = fp8_a + fp8_b[index] = fp8_a + class TestFloat8Linear: def _test_linear_impl( From 
35a5dfcbe2600569f6dcfd0cddd593727fca2caa Mon Sep 17 00:00:00 2001 From: Antoni Vros Date: Thu, 20 Jun 2024 13:11:14 -0400 Subject: [PATCH 6/8] revert change to scale dtype as it seems to work now --- float8_experimental/float8_ops.py | 1 - float8_experimental/float8_tensor.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index a0a36c72..2bf80cfc 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -44,7 +44,6 @@ def decorator(func): aten.slice.Tensor, aten.transpose.int, aten.fill_.Scalar, - aten.copy_.default, ] ) def float8_desugar_op(aten_op, args, kwargs=None): diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 13c273e4..2535b69c 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -157,7 +157,7 @@ class FromFloat8ConstrFunc(torch.autograd.Function): @staticmethod def forward(ctx, tensor): - return (tensor._data.to(tensor._orig_dtype) / tensor._scale.to(tensor._org_dtype)) + return tensor._data.to(tensor._orig_dtype) / tensor._scale @staticmethod def backward(ctx, g): From 30e63498c04c96be3665bf25b86183df24275922 Mon Sep 17 00:00:00 2001 From: Antoni Vros Date: Tue, 25 Jun 2024 10:56:41 -0400 Subject: [PATCH 7/8] Add tests and improve copy_ semantics --- float8_experimental/float8_ops.py | 38 +++++++++++++++++++++---------- test/test_base.py | 20 ++++++++++++++++ 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 2bf80cfc..853d5315 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -277,19 +277,33 @@ def index_put_fp8(aten_op, args, kwargs=None): @implements([aten.copy_.default]) def copy_fp8(aten_op, args, kwargs=None): + # For a copy op with Float8Tensors involved, only the following combinations are allowed: + # 1. self is a high precision (hp) tensor, src is a Float8Tensor: + # in this case src is upcasted and unscaled to go into the hp tensor + # 2. 
self and src are Float8Tensors: + # the copy is only allowed if all the Float8Tensor properties are equal (a la torch.cat) + # Every other combination is banned as the semantics are not well defined + self = args[0] src = args[1] - assert isinstance(self, Float8Tensor) or isinstance(src, Float8Tensor) - - self_data = self - if isinstance(self, Float8Tensor): - self_data = self._data - src_data = src - if isinstance(src, Float8Tensor): - src_data = src._data - - fp8_out = aten_op(self_data, src_data, *args[2:], **kwargs) - if isinstance(self, Float8Tensor): + if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + src_hp = src.to_original_precision() + return aten_op(self, src_hp, *args[2:], **kwargs) + elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor): + assert ( + self._orig_dtype == src._orig_dtype + ), "Expecting both Float8Tensors to be of the same dtype" + assert ( + self._scale is src._scale + ), "Expecting both Float8Tensors to have thee same scale" + assert ( + self._mm_config is src._mm_config + ), "Expecting both Float8Tensors to have thee same mm config" + assert ( + self._data.dtype == src._data.dtype + ), "Expecting both Float8Tensors to be of the same dtypet" + fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs) return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config) - return fp8_out + else: + raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor") diff --git a/test/test_base.py b/test/test_base.py index 7ad42f25..654801e1 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -99,6 +99,26 @@ def test_index_put(self): fp8_b_bad[index] = fp8_a fp8_b[index] = fp8_a + def test_copy_(self): + a = torch.rand(16, dtype=torch.bfloat16) + scale_a = tensor_to_scale(a, torch.float8_e4m3fn) + fp8_a = Float8Tensor.to_float8(a, scale_a, torch.float8_e4m3fn) + + b = torch.empty(16, dtype=torch.bfloat16) + b.copy_(fp8_a) # Should work + torch.testing.assert_close(b, fp8_a.to_original_precision()) + with self.assertRaises(RuntimeError): + fp8_a.copy_(b) # Should fail + + fp8_b = Float8Tensor( + torch.empty(16, dtype=torch.float8_e4m3fn), + scale_a, + torch.bfloat16, + fp8_a._mm_config, + ) + fp8_b.copy_(fp8_a) + torch.testing.assert_close(fp8_a._data, fp8_b._data) + class TestFloat8Linear: def _test_linear_impl( From 0469357180f38cfe1a081dde063b0b2ec1b1b6d6 Mon Sep 17 00:00:00 2001 From: Antoni Vros Date: Tue, 25 Jun 2024 16:15:07 -0400 Subject: [PATCH 8/8] Update test for copy --- float8_experimental/float8_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py index 853d5315..770fb3b5 100644 --- a/float8_experimental/float8_ops.py +++ b/float8_experimental/float8_ops.py @@ -295,10 +295,10 @@ def copy_fp8(aten_op, args, kwargs=None): self._orig_dtype == src._orig_dtype ), "Expecting both Float8Tensors to be of the same dtype" assert ( - self._scale is src._scale + self._scale == src._scale ), "Expecting both Float8Tensors to have thee same scale" assert ( - self._mm_config is src._mm_config + self._mm_config == src._mm_config ), "Expecting both Float8Tensors to have thee same mm config" assert ( self._data.dtype == src._data.dtype
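
Note: the saturated-cast scaling used by Float8SWLinear.to_float8 (added in patch 1, removed again in patch 4) and undone by FromFloat8ConstrFunc can be exercised in isolation. The sketch below is a minimal, self-contained illustration of that quantize/dequantize round trip, not code from this series: it assumes PyTorch >= 2.1 (for torch.float8_e4m3fn), runs on CPU, and deliberately avoids torch._scaled_mm, whose signature has changed across releases. The helper names quantize_to_float8 and dequantize are illustrative only.

import torch

def quantize_to_float8(x: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Scale so that the largest |x| maps onto the fp8 dtype's max representable value.
    amax = x.abs().max().to(torch.float32)
    scale = finfo.max / amax.clamp(min=1e-12)
    # Saturate before casting, because the default cast to float8 does not saturate.
    x_scaled = (x.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
    # Return the fp8 payload plus the inverse scale needed to dequantize
    # (the same data/scale pair a scaled matmul would consume).
    return x_scaled.to(torch.float8_e4m3fn), scale.reciprocal()

def dequantize(x_f8: torch.Tensor, inv_scale: torch.Tensor, orig_dtype=torch.bfloat16):
    # Upcast to the original dtype, then undo the scale (here via the precomputed reciprocal).
    return x_f8.to(orig_dtype) * inv_scale.to(orig_dtype)

w = torch.randn(256, 256, dtype=torch.bfloat16)
w_f8, w_inv_s = quantize_to_float8(w)
err = (dequantize(w_f8, w_inv_s) - w).abs().max()
print(w_f8.dtype, w_inv_s.item(), err.item())  # small, nonzero quantization error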