This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

formatting
drisspg committed Mar 8, 2024
1 parent 4a27a27 commit fb3d4ce
Showing 5 changed files with 51 additions and 27 deletions.
26 changes: 15 additions & 11 deletions float8_experimental/float8_dynamic_linear.py
@@ -6,11 +6,12 @@
"""
A wrapper around a `torch.nn.Linear` module which does fp8 compute.
"""
import torch
from typing import Optional

import torch

from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
from float8_experimental.float8_utils import IS_AMD, tensor_to_scale, FP8Dtypes
from float8_experimental.float8_utils import FP8Dtypes, tensor_to_scale


@torch._dynamo.allow_in_graph
@@ -21,20 +22,17 @@ class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
"""

@staticmethod
def forward(
ctx,
tensor,
emulate: bool,
fp8_dtype_bw: torch.dtype
):
def forward(ctx, tensor, emulate: bool, fp8_dtype_bw: torch.dtype):
ctx.emulate = emulate
ctx.fp8_dtype_bw = fp8_dtype_bw
return tensor

@staticmethod
def backward(ctx, gradY):
gradY_scale = tensor_to_scale(gradY, ctx.fp8_dtype_bw)
fp8_tensor = to_fp8_no_autograd(gradY, gradY_scale, ctx.fp8_dtype_bw, ctx.emulate)
fp8_tensor = to_fp8_no_autograd(
gradY, gradY_scale, ctx.fp8_dtype_bw, ctx.emulate
)
return fp8_tensor, None, None
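
For context on the function reformatted above: its forward is a no-op, and its backward scales and casts the incoming gradient to the configured fp8 dtype. A minimal usage sketch, assuming illustrative tensor and flag names rather than the module's exact internals:

import torch
from float8_experimental.float8_dynamic_linear import NoopFwToFloat8E5M2Bw

# Illustrative values; in the library these come from the surrounding linear module.
y = torch.randn(4, 32, requires_grad=True)
emulate, fp8_dtype_bw = True, torch.float8_e5m2
y_out = NoopFwToFloat8E5M2Bw.apply(y, emulate, fp8_dtype_bw)
# forward: y_out is y, unchanged
# backward: the gradient reaching y is scaled and cast to fp8_dtype_bw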


@@ -63,7 +61,9 @@ class Float8DynamicLinear(torch.nn.Linear):
conversion to fp8 of the input and weight tensors.
"""

def __init__(self, use_activation_hooks: bool, fp8_dtype: FP8Dtypes, **super_kwargs):
def __init__(
self, use_activation_hooks: bool, fp8_dtype: FP8Dtypes, **super_kwargs
):
"""
Args:
use_activation_hooks (bool): whether to use activation hooks for casting to and from float8
@@ -120,7 +120,11 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:

@classmethod
def from_float(
cls, mod, emulate: bool = False, use_activation_hooks: bool = False, fp8_dtypes: Optional[FP8Dtypes] = None
cls,
mod,
emulate: bool = False,
use_activation_hooks: bool = False,
fp8_dtypes: Optional[FP8Dtypes] = None,
) -> "Float8DynamicLinear":
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear
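A short usage sketch for the from_float signature reformatted above, mirroring the setup in test/test_base.py further down (sizes and device come from the test, not from requirements of the API):

import torch
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_utils import FP8Dtypes

mod = torch.nn.Linear(16, 32, bias=False, device="cuda")
fp8_mod = Float8DynamicLinear.from_float(
    mod,
    emulate=True,  # the tests use emulate=True where native fp8 matmul is unavailable
    use_activation_hooks=False,
    fp8_dtypes=FP8Dtypes(),  # e4m3fn forward, e5m2 backward (the defaults)
)
y = fp8_mod(torch.randn(4, 16, device="cuda"))
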
12 changes: 9 additions & 3 deletions float8_experimental/float8_linear.py
@@ -14,7 +14,7 @@

import dataclasses

from typing import Optional, Literal
from typing import Literal, Optional

import float8_experimental.config as config

@@ -26,8 +26,8 @@
amax_history_to_scale,
E4M3_MAX_POS,
E5M2_MAX_POS,
FP8Dtypes,
tensor_to_amax,
FP8Dtypes
)


@@ -316,7 +316,13 @@ def forward(self, input):
return y

@classmethod
def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False, fp8_dtypes: Optional[FP8Dtypes] = None):
def from_float(
cls,
mod,
emulate: bool = False,
use_activation_hooks: bool = False,
fp8_dtypes: Optional[FP8Dtypes] = None,
):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear
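The imports reordered above (tensor_to_amax, amax_history_to_scale and the *_MAX_POS constants) belong to Float8Linear's delayed scaling, where scales are derived from a history of observed absolute maxima rather than from the current tensor alone. A rough sketch of that idea in plain torch, under the assumption that the scale maps the historical amax to the format's maximum; this is not the library's amax_history_to_scale implementation:

import torch

amax_history = torch.tensor([0.5, 2.0, 1.5])  # max(|x|) recorded over recent steps
amax = amax_history.max()
# Map the historical amax to the top of the e4m3fn range (clamp guards against amax == 0).
scale = torch.finfo(torch.float8_e4m3fn).max / torch.clamp(amax, min=1e-12)
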
2 changes: 1 addition & 1 deletion float8_experimental/float8_tensor.py
@@ -230,7 +230,7 @@ def to_float8(
float8_dtype: torch.dtype,
amax_buffer: Optional[torch.Tensor] = None,
emulate: bool = False,
)-> "Float8Tensor":
) -> "Float8Tensor":
"""Converts a higher precision tensor to float8 in a differentiable way.
Args:
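The docstring above describes to_float8 as a differentiable cast to float8. As background, one generic way to make a quantizing cast differentiable is a straight-through autograd function; the sketch below illustrates that pattern in plain torch and is not Float8Tensor's actual implementation (which returns a Float8Tensor subclass rather than a dequantized tensor):

import torch

class QuantDequantFp8(torch.autograd.Function):
    """Illustrative straight-through fp8 round trip (quantize, then dequantize)."""

    @staticmethod
    def forward(ctx, tensor, scale, float8_dtype):
        fp8 = (tensor * scale).to(float8_dtype)
        return fp8.to(tensor.dtype) / scale  # back to the input dtype

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: the gradient passes along unchanged.
        return grad_output, None, None

x = torch.randn(8, 8, requires_grad=True)
scale = 448.0 / x.detach().abs().max()  # 448 is the e4m3fn maximum representable value
y = QuantDequantFp8.apply(x, scale, torch.float8_e4m3fn)
y.sum().backward()  # gradients flow to x via the straight-through path
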
6 changes: 4 additions & 2 deletions float8_experimental/float8_utils.py
@@ -4,8 +4,8 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Literal
from dataclasses import dataclass
from typing import Literal

import torch
import torch.distributed as dist
@@ -30,10 +30,12 @@

@dataclass(frozen=True)
class FP8Dtypes:
""" Defines the fp8 dtypes to be used in forward and backwrad computations"""
"""Defines the fp8 dtypes to be used in forward and backwrad computations"""

fp8_dtype_fw: torch.dtype = torch.float8_e4m3fn
fp8_dtype_bw: torch.dtype = torch.float8_e5m2


@torch.no_grad()
def amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
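For reference on the frozen dataclass reformatted above: FP8Dtypes pairs a forward dtype with a backward dtype, and the test diff below swaps in the fnuz variants on AMD. A small sketch of constructing it and of deriving a per-tensor scale with tensor_to_scale, whose two-argument form appears in the backward pass earlier in this commit:

import torch
from float8_experimental.float8_utils import FP8Dtypes, tensor_to_scale

fp8_dtypes = FP8Dtypes()  # fp8_dtype_fw=torch.float8_e4m3fn, fp8_dtype_bw=torch.float8_e5m2
fp8_dtypes_amd = FP8Dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)  # as in test_base.py

x = torch.randn(16, 16)
scale = tensor_to_scale(x, fp8_dtypes.fp8_dtype_fw)  # dynamic, per-tensor scale
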
32 changes: 22 additions & 10 deletions test/test_base.py
@@ -8,6 +8,7 @@
import unittest
import warnings
from typing import Optional

import pytest

import torch
@@ -24,17 +25,16 @@
from float8_experimental.float8_python_api import mm_float8
from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import (
E5M2_FNUZ_MAX_POS,
amax_to_scale,
compute_error,
E4M3_MAX_POS,
E4M3_FNUZ_MAX_POS,
E5M2_MAX_POS,
E4M3_MAX_POS,
E5M2_FNUZ_MAX_POS,
E5M2_MAX_POS,
FP16_MAX_POS,
tensor_to_scale,
IS_AMD,
FP8Dtypes,
IS_AMD,
tensor_to_scale,
)

random.seed(0)
@@ -65,9 +65,10 @@ def _test_linear_impl(
emulate: bool,
use_activation_hooks: bool,
fp8_dtypes: Optional[FP8Dtypes] = None,

):
m_fp8 = get_float8_linear(linear_type, m_ref, emulate, use_activation_hooks, fp8_dtypes)
m_fp8 = get_float8_linear(
linear_type, m_ref, emulate, use_activation_hooks, fp8_dtypes
)
for _ in range(2):
if linear_requires_sync(linear_type):
sync_float8_amax_and_scale_history(m_fp8)
@@ -95,7 +96,12 @@ def _test_linear_impl(
]
for buffer_name in amax_buffer_names:
buffer_value = getattr(m_fp8, buffer_name)
for init_val in (E4M3_MAX_POS, E5M2_MAX_POS, E4M3_FNUZ_MAX_POS, E5M2_FNUZ_MAX_POS):
for init_val in (
E4M3_MAX_POS,
E5M2_MAX_POS,
E4M3_FNUZ_MAX_POS,
E5M2_FNUZ_MAX_POS,
):
assert torch.ne(
buffer_value, torch.tensor(init_val)
), f"{buffer_name} not filled, current value {buffer_value}"
@@ -147,10 +153,16 @@ def test_linear_nobias(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
)
pytest.skip()
fp8_dtypes = FP8Dtypes() if not IS_AMD else FP8Dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
fp8_dtypes = (
FP8Dtypes()
if not IS_AMD
else FP8Dtypes(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)
)
x = torch.randn(*x_shape, device="cuda")
m_ref = nn.Linear(16, 32, bias=False, device="cuda")
self._test_linear_impl(x, m_ref, linear_type, emulate, use_activation_hooks, fp8_dtypes)
self._test_linear_impl(
x, m_ref, linear_type, emulate, use_activation_hooks, fp8_dtypes
)

@pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
