This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Adds utilities for AMD fp8 dtype support, follow-up PR to add option to the configs #235

Closed · wants to merge 2 commits

Changes from 1 commit
87 changes: 69 additions & 18 deletions float8_experimental/float8_utils.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

-from typing import Tuple
+from typing import Literal, Tuple

import torch
import torch.distributed as dist
@@ -14,22 +14,40 @@

# define the e4m3/e5m2 constants
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
+E4M3_FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
+E5M2_FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max

FP16_MAX_POS = torch.finfo(torch.float16).max

# avoid division by zero when calculating scale
# TODO: align this value with NVIDIA's assumptions (current value is a guess)
EPS = 1e-12

+IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
+

@torch.no_grad()
-def amax_to_scale(amax, float8_dtype, orig_dtype):
+def amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+):
+    """Converts the amax value of a tensor to the fp8 scale.
+    Args:
+        amax: The amax value of the tensor.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+    """
    scale = torch.empty_like(amax, dtype=torch.float32)
    if float8_dtype == torch.float8_e4m3fn:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
-    else:  # e5m2
+    elif float8_dtype == torch.float8_e4m3fnuz:
+        res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
+    elif float8_dtype == torch.float8_e5m2:
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
+    elif float8_dtype == torch.float8_e5m2fnuz:
+        res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
Review comment: Please avoid code duplication.

Suggested change:
-    elif float8_dtype == torch.float8_e4m3fnuz:
-        res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
-    elif float8_dtype == torch.float8_e5m2:
-        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
-    elif float8_dtype == torch.float8_e5m2fnuz:
-        res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
+    elif float8_dtype in [torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]:
+        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)

Review comment: Actually, you don't even need the ifs there; just assert that float8_dtype is indeed one of the fp8 dtypes:

    assert float8_dtype.itemsize == 1 and float8_dtype.is_floating_point
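Taken together, the two suggestions collapse the whole branch into a single `torch.finfo` lookup. A minimal runnable sketch of that shape (illustrative only; `amax_to_scale_generic` is a hypothetical name, not code from this PR):

```python
import torch

EPS = 1e-12  # same division-by-zero guard as in float8_utils.py

def amax_to_scale_generic(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Any 1-byte floating-point dtype is an fp8 dtype, so one finfo lookup
    # replaces the per-dtype elif chain shown in the diff above.
    assert float8_dtype.itemsize == 1 and float8_dtype.is_floating_point
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)

amax = torch.tensor(10.0)
print(amax_to_scale_generic(amax, torch.float8_e4m3fn))    # 448 / 10 -> tensor(44.8000)
print(amax_to_scale_generic(amax, torch.float8_e4m3fnuz))  # 240 / 10 -> tensor(24.)
```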

+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

    # Ensure that the scale is representable in float16,
    # this helps when amax is small. We are assuming that we don't need
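The new constants and the `IS_AMD` flag exist because the fnuz formats used by AMD fp8 hardware have no infinity encodings, repurpose the negative-zero bit pattern as their single NaN, and shift the exponent bias, so their representable maxima differ from the fn/e5m2 formats. A quick sanity check, assuming a PyTorch build recent enough (roughly 2.1+) to ship all four fp8 dtypes:

```python
import torch

# The shifted bias costs e4m3fnuz range (240 vs 448), while e5m2fnuz regains
# the top exponent that e5m2 reserves for inf/NaN and keeps the same maximum.
for dtype in (
    torch.float8_e4m3fn,    # max = 448.0
    torch.float8_e4m3fnuz,  # max = 240.0
    torch.float8_e5m2,      # max = 57344.0
    torch.float8_e5m2fnuz,  # max = 57344.0
):
    print(dtype, torch.finfo(dtype).max)
```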
@@ -42,11 +60,18 @@ def amax_to_scale(amax, float8_dtype, orig_dtype):

@torch.no_grad()
def amax_history_to_scale(
-    amax_history,
-    float8_dtype,
-    orig_dtype,
-    history_to_scale_fn_type,
+    amax_history: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: Literal["max"],
):
"""Takes in a history of amax values and returns a scale tensor.
Args:
amax_history: A tensor containing the history of amax values.
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
history_to_scale_fn_type: The type of function to use to convert the history to a scale.
"""
if history_to_scale_fn_type == "max":
amax = torch.max(amax_history)
return amax_to_scale(amax, float8_dtype, orig_dtype)
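As a usage sketch with made-up numbers (assuming the functions above are in scope, e.g. imported from float8_experimental.float8_utils): with delayed scaling, the current scale is derived from a rolling buffer of past amax values rather than from the live tensor.

```python
import torch

# Rolling buffer of the last few iterations' amax values (hypothetical numbers).
amax_history = torch.tensor([0.5, 3.2, 1.1, 2.7])

scale = amax_history_to_scale(
    amax_history, torch.float8_e4m3fn, torch.bfloat16, "max"
)
# The "max" policy reduces the history first, so this equals
# amax_to_scale(tensor(3.2), ...) == 448 / 3.2
print(scale)  # ~tensor(140.)
```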
@@ -58,9 +83,15 @@ def amax_history_to_scale_stack(
    amax_history: torch.Tensor,
    float8_dtype: torch.dtype,
    orig_dtype: torch.dtype,
-    history_to_scale_fn_type: str,
+    history_to_scale_fn_type: Literal["max"],
) -> torch.Tensor:
"""Takes in a stack of amax_history tensors and returns a scale tensor."""
"""Takes in a stack of amax_history tensors and returns a scale tensor.
Args:
amax_history: A 2D tensor containing a stack of amax histories.
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
history_to_scale_fn_type: The type of function to use to convert the history to a scale.
"""
if history_to_scale_fn_type == "max":
amax_stack = torch.max(amax_history, dim=1).values
return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
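The stacked variant runs the same reduction per row, producing scales for many tracked tensors at once. A hedged sketch, again with hypothetical numbers:

```python
import torch

# One amax history per tracked tensor, stacked into shape (num_tensors, history_len).
histories = torch.tensor([
    [0.5, 3.2, 1.1],  # tensor 0: row max 3.2
    [7.0, 0.2, 0.9],  # tensor 1: row max 7.0
])

scales = amax_history_to_scale_stack(
    histories, torch.float8_e4m3fn, torch.bfloat16, "max"
)
print(scales)  # ~tensor([140., 64.]), i.e. 448 / per-row max
```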
@@ -90,21 +121,41 @@ def tensor_to_scale(
    return amax_to_scale(amax, float8_dtype, x.dtype)


-def to_fp8_saturated(x, float8_dtype: torch.dtype):
-    # The default behavior in PyTorch for casting to `float8_e4m3fn`
-    # and `e5m2` is to not saturate. In this context, we should saturate.
-    # A common case where we want to saturate is when the history of a
-    # tensor has a maximum value of `amax1`, and the current amax value
-    # is `amax2`, where `amax1 < amax2`. This is common when using delayed
-    # scaling.
+def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
+    """Converts a tensor to a saturated fp8 tensor.
+
+    Note:
+        The default behavior in PyTorch for casting to `float8_e4m3fn`
+        and `e5m2` is to not saturate. In this context, we should saturate.
+        A common case where we want to saturate is when the history of a
+        tensor has a maximum value of `amax1`, and the current amax value
+        is `amax2`, where `amax1 < amax2`. This is common when using delayed
+        scaling.
+    """
+
    if float8_dtype == torch.float8_e4m3fn:
        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
-    else:
+    elif float8_dtype == torch.float8_e4m3fnuz:
+        x = x.clamp(min=-1 * E4M3_FNUZ_MAX_POS, max=E4M3_FNUZ_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2:
        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2fnuz:
+        x = x.clamp(min=-1 * E5M2_FNUZ_MAX_POS, max=E5M2_FNUZ_MAX_POS)
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
    return x.to(float8_dtype)
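To see why the clamp matters, compare a saturated cast against the default cast for values outside the e4m3fn range (a sketch; the exact overflow behavior of the default cast is version-dependent, but e4m3fn has no infinity encoding, so overflow generally lands on NaN):

```python
import torch

x = torch.tensor([1000.0, -1000.0])  # outside e4m3fn's ±448 range

# Default (non-saturating) cast: out-of-range values overflow; since e4m3fn
# has no inf encoding they typically come back as NaN.
print(x.to(torch.float8_e4m3fn).float())  # expected: tensor([nan, nan])

# Saturated cast: clamp to ±E4M3_MAX_POS first, keeping the values finite.
print(to_fp8_saturated(x, torch.float8_e4m3fn).float())  # tensor([ 448., -448.])
```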


-def compute_error(x, y):
+def compute_error(x: torch.Tensor, y: torch.Tensor):
+    """Computes the error between two tensors in dB.
+
+    For more details see:
+    https://en.wikipedia.org/wiki/Signal-to-noise_ratio
+
+    Args:
+        x: The original tensor.
+        y: The tensor to compare to the original tensor.
+    """
    Ps = torch.norm(x)
    Pn = torch.norm(x - y)
    return 20 * torch.log10(Ps / Pn)
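Putting the pieces together, a hedged end-to-end sketch (assuming `tensor_to_scale(x, float8_dtype)` as shown in the hunk header above): quantize a tensor to e4m3fn, dequantize, and report the round-trip signal-to-noise ratio.

```python
import torch

x = torch.randn(1024, dtype=torch.float32)

scale = tensor_to_scale(x, torch.float8_e4m3fn)           # 448 / amax(x)
x_fp8 = to_fp8_saturated(x * scale, torch.float8_e4m3fn)  # quantize
x_back = x_fp8.float() / scale                            # dequantize

# Higher is better; a few tens of dB is typical for a well-scaled e4m3 cast
# (indicative, not a guaranteed figure).
print(compute_error(x, x_back))
```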