diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index 8b465e71..4c66abf8 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -6,11 +6,12 @@
 """
 A wrapper around a `torch.nn.Linear` module which does fp8 compute.
 """
-import torch
 from typing import Optional
 
+import torch
+
 from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
-from float8_experimental.float8_utils import IS_AMD, tensor_to_scale, FP8Dtypes
+from float8_experimental.float8_utils import FP8Dtypes, tensor_to_scale
 
 
 @torch._dynamo.allow_in_graph
@@ -21,12 +22,7 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(
-        ctx,
-        tensor,
-        emulate: bool,
-        fp8_dtype_bw: torch.dtype
-    ):
+    def forward(ctx, tensor, emulate: bool, fp8_dtype_bw: torch.dtype):
         ctx.emulate = emulate
         ctx.fp8_dtype_bw = fp8_dtype_bw
         return tensor
@@ -34,7 +30,9 @@ def forward(
     @staticmethod
     def backward(ctx, gradY):
         gradY_scale = tensor_to_scale(gradY, ctx.fp8_dtype_bw)
-        fp8_tensor = to_fp8_no_autograd(gradY, gradY_scale, ctx.fp8_dtype_bw, ctx.emulate)
+        fp8_tensor = to_fp8_no_autograd(
+            gradY, gradY_scale, ctx.fp8_dtype_bw, ctx.emulate
+        )
         return fp8_tensor, None, None
 
 
@@ -63,7 +61,9 @@ class Float8DynamicLinear(torch.nn.Linear):
     conversion to fp8 of the input and weight tensors.
     """
 
-    def __init__(self, use_activation_hooks: bool, fp8_dtype: FP8Dtypes, **super_kwargs):
+    def __init__(
+        self, use_activation_hooks: bool, fp8_dtype: FP8Dtypes, **super_kwargs
+    ):
         """
         Args:
             use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
@@ -120,7 +120,11 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:
 
     @classmethod
     def from_float(
-        cls, mod, emulate: bool = False, use_activation_hooks: bool = False, fp8_dtypes: Optional[FP8Dtypes] = None
+        cls,
+        mod,
+        emulate: bool = False,
+        use_activation_hooks: bool = False,
+        fp8_dtypes: Optional[FP8Dtypes] = None,
     ) -> "Float8DynamicLinear":
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 73577016..5fc1feab 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -14,7 +14,7 @@
 
 import dataclasses
 
-from typing import Optional, Literal
+from typing import Literal, Optional
 
 import float8_experimental.config as config
 
@@ -26,8 +26,8 @@
     amax_history_to_scale,
     E4M3_MAX_POS,
     E5M2_MAX_POS,
+    FP8Dtypes,
     tensor_to_amax,
-    FP8Dtypes
 )
 
 
@@ -316,7 +316,13 @@ def forward(self, input):
         return y
 
     @classmethod
-    def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False, fp8_dtypes: Optional[FP8Dtypes] = None):
+    def from_float(
+        cls,
+        mod,
+        emulate: bool = False,
+        use_activation_hooks: bool = False,
+        fp8_dtypes: Optional[FP8Dtypes] = None,
+    ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
 
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 3d676cb7..1c62b624 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -230,7 +230,7 @@ def to_float8(
         float8_dtype: torch.dtype,
         amax_buffer: Optional[torch.Tensor] = None,
         emulate: bool = False,
-    )-> "Float8Tensor":
+    ) -> "Float8Tensor":
         """Converts a higher precision tensor to float8 in a differentiable way.
 
         Args:
diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py
index b2a7d07b..a8e0fea7 100644
--- a/float8_experimental/float8_utils.py
+++ b/float8_experimental/float8_utils.py
@@ -4,8 +4,8 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Literal
 from dataclasses import dataclass
+from typing import Literal
 
 import torch
 import torch.distributed as dist
@@ -30,10 +30,12 @@
 
 @dataclass(frozen=True)
 class FP8Dtypes:
-    """ Defines the fp8 dtypes to be used in forward and backwrad computations"""
+    """Defines the fp8 dtypes to be used in forward and backwrad computations"""
+
     fp8_dtype_fw: torch.dtype = torch.float8_e4m3fn
     fp8_dtype_bw: torch.dtype = torch.float8_e5m2
 
+
 @torch.no_grad()
 def amax_to_scale(
     amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
diff --git a/test/test_base.py b/test/test_base.py
index fba57ccb..6e1741a2 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -8,6 +8,7 @@
 import unittest
 import warnings
 from typing import Optional
+
 import pytest
 
 import torch
@@ -24,17 +25,16 @@
 from float8_experimental.float8_python_api import mm_float8
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import (
-    E5M2_FNUZ_MAX_POS,
     amax_to_scale,
     compute_error,
-    E4M3_MAX_POS,
     E4M3_FNUZ_MAX_POS,
-    E5M2_MAX_POS,
+    E4M3_MAX_POS,
+    E5M2_FNUZ_MAX_POS,
+    E5M2_MAX_POS,
     FP16_MAX_POS,
-    tensor_to_scale,
-    IS_AMD,
     FP8Dtypes,
+    IS_AMD,
+    tensor_to_scale,
 )
 
 random.seed(0)
@@ -65,9 +65,10 @@ def _test_linear_impl(
         emulate: bool,
         use_activation_hooks: bool,
         fp8_dtypes: Optional[FP8Dtypes] = None,
-
     ):
-        m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks, fp8_dtypes)
+        m_fp8 = get_float8_linear(
+            linear_type, m_ref, emulate, use_activation_hooks, fp8_dtypes
+        )
         for _ in range(2):
             if linear_requires_sync(linear_type):
                 sync_float8_amax_and_scale_history(m_fp8)
@@ -95,7 +96,12 @@ def _test_linear_impl(
         ]
         for buffer_name in amax_buffer_names:
            buffer_value = getattr(m_fp8, buffer_name)
-            for init_val in (E4M3_MAX_POS, E5M2_MAX_POS, E4M3_FNUZ_MAX_POS, E5M2_FNUZ_MAX_POS):
+            for init_val in (
+                E4M3_MAX_POS,
+                E5M2_MAX_POS,
+                E4M3_FNUZ_MAX_POS,
+                E5M2_FNUZ_MAX_POS,
+            ):
                 assert torch.ne(
                     buffer_value, torch.tensor(init_val)
                 ), f"{buffer_name} not filled, current value {buffer_value}"
@@ -147,10 +153,16 @@ def test_linear_nobias(
                 f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
             )
             pytest.skip()
-        fp8_dtypes = FP8Dtypes() if not IS_AMD else FP8Dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
+        fp8_dtypes = (
+            FP8Dtypes()
+            if not IS_AMD
+            else FP8Dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
+        )
         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks, fp8_dtypes)
+        self._test_linear_impl(
+            x, m_ref, linear_type, emulate, use_activation_hooks, fp8_dtypes
+        )
 
     @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
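For context, a minimal usage sketch of the FP8Dtypes plumbing touched above, mirroring what the updated test_linear_nobias does: select the fnuz dtype pair when IS_AMD is true and pass the dataclass through from_float. This is illustrative only and not part of the patch; it assumes the module paths and signatures shown in the diff.

# Illustrative sketch -- assumes the float8_experimental layout and the
# from_float(mod, emulate, use_activation_hooks, fp8_dtypes) signature above.
import torch
import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_utils import FP8Dtypes, IS_AMD

# Default e4m3fn/e5m2 pair on NVIDIA, the fnuz variants on AMD,
# exactly as the updated test selects them.
fp8_dtypes = (
    FP8Dtypes()
    if not IS_AMD
    else FP8Dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
)

m_ref = nn.Linear(16, 32, bias=False, device="cuda")
m_fp8 = Float8DynamicLinear.from_float(
    m_ref, emulate=False, use_activation_hooks=False, fp8_dtypes=fp8_dtypes
)

x = torch.randn(4, 16, device="cuda")
y = m_fp8(x)  # forward casts with fp8_dtype_fw; backward casts gradY with fp8_dtype_bw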