Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
cleanup fp8 handling
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed Jun 4, 2024
1 parent ba9b5dd commit 4f304cc
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 49 deletions.
2 changes: 1 addition & 1 deletion float8_experimental/float8_dynamic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
"bias": False,
}
new_mod = cls(**super_kwargs)
new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
new_mod.forward_config = ScaledMMConfig(emulate, not bool(emulate))
new_mod.backward_config = ScaledMMConfig(emulate, False)
if config.enable_fsdp_fp8_all_gather:
new_mod.weight = nn.Parameter(
Expand Down
18 changes: 9 additions & 9 deletions float8_experimental/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,7 @@
to_fp8_no_autograd,
)

from float8_experimental.float8_utils import (
amax_history_to_scale,
E4M3_MAX_POS,
E5M2_MAX_POS,
tensor_to_amax,
)
from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax


def _maybe_initialize_amaxes_scales_for_float8_cast(
Expand Down Expand Up @@ -142,18 +137,23 @@ def __init__(self, *args, **kwargs):
self.recipe = delayed_scaling_recipe
history_len = self.recipe.history_len

self.register_always_float32_buffer("fp8_amax_x", torch.tensor([E4M3_MAX_POS]))
# Default values for history buffers, see above TODO
default_x = torch.finfo(torch.float8_e4m3fn).max
default_w = torch.finfo(torch.float8_e4m3fn).max
default_dl_dy = torch.finfo(torch.float8_e5m2).max

self.register_always_float32_buffer("fp8_amax_x", torch.tensor([default_x]))
self.register_always_float32_buffer(
"fp8_amax_history_x", torch.zeros(history_len)
)
self.register_always_float32_buffer("fp8_scale_x", torch.tensor([1.0]))
self.register_always_float32_buffer("fp8_amax_w", torch.tensor([E4M3_MAX_POS]))
self.register_always_float32_buffer("fp8_amax_w", torch.tensor([default_w]))
self.register_always_float32_buffer(
"fp8_amax_history_w", torch.zeros(history_len)
)
self.register_always_float32_buffer("fp8_scale_w", torch.tensor([1.0]))
self.register_always_float32_buffer(
"fp8_amax_dL_dY", torch.tensor([E5M2_MAX_POS])
"fp8_amax_dL_dY", torch.tensor([default_dl_dy])
)
self.register_always_float32_buffer(
"fp8_amax_history_dL_dY", torch.zeros(history_len)
Expand Down
58 changes: 26 additions & 32 deletions float8_experimental/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,17 @@
# Helpful visualizer for debugging (only supports fp32):
# https://www.h-schmidt.net/FloatConverter/IEEE754.html

# define the e4m3/e5m2 constants
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E4M3_FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
E5M2_FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max

FP16_MAX_POS = torch.finfo(torch.float16).max

# avoid division by zero when calculating scale
# TODO: align this value with NVIDIA's assumptions (current value is a guess)
EPS = 1e-12

IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
FP8_TYPES = {
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
}


@torch.no_grad()
Expand All @@ -38,22 +36,16 @@ def amax_to_scale(
orig_dtype: The original dtype of the tensor.
"""
scale = torch.empty_like(amax, dtype=torch.float32)
if float8_dtype == torch.float8_e4m3fn:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == torch.float8_e4m3fnuz:
res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == torch.float8_e5m2:
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == torch.float8_e5m2fnuz:
res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
if float8_dtype in FP8_TYPES:
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

# Ensure that the scale is representable in float16,
# this helps when amax is small. We are assuming that we don't need
# to care about this for float32/bfloat16.
if orig_dtype is torch.float16:
res = torch.clamp(res, max=FP16_MAX_POS)
res = torch.clamp(res, max=torch.finfo(torch.float16).max)
scale.copy_(res)
return scale

Expand Down Expand Up @@ -132,18 +124,12 @@ def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
is `amax2`, where `amax1 < amax2`. This is common when using delayed
scaling.
"""

if float8_dtype == torch.float8_e4m3fn:
x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
elif float8_dtype == torch.float8_e4m3fnuz:
x = x.clamp(min=-1 * E4M3_FNUZ_MAX_POS, max=E4M3_FNUZ_MAX_POS)
elif float8_dtype == torch.float8_e5m2:
x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
elif float8_dtype == torch.float8_e5m2fnuz:
x = x.clamp(min=-1 * E5M2_FNUZ_MAX_POS, max=E5M2_FNUZ_MAX_POS)
if float8_dtype in FP8_TYPES:
max_value = torch.finfo(float8_dtype).max
x = x.clamp(min=-max_value, max=max_value)
return x.to(float8_dtype)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
return x.to(float8_dtype)


def compute_error(x: torch.Tensor, y: torch.Tensor):
Expand All @@ -164,11 +150,19 @@ def compute_error(x: torch.Tensor, y: torch.Tensor):
def fp8_tensor_statistics(
tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn
) -> Tuple[int, ...]:
"""Calculate FP8 tensor stats"""
if float8_dtype == torch.float8_e4m3fn:
FP8_MAX = E4M3_MAX_POS
else: # e5m2
FP8_MAX = E5M2_MAX_POS
"""Calculate FP8 tensor stats
Args:
tensor: The tensor to calculate stats for.
float8_dtype: The float8 dtype.
Returns:
A tuple containing the number of zeros and the number of max values.
"""
if float8_dtype in FP8_TYPES:
FP8_MAX = torch.finfo(float8_dtype).max
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
tensor_orig_type = tensor._data.to(dtype=tensor._orig_dtype)
num_max = (torch.abs(tensor_orig_type) == FP8_MAX).sum().item()
num_zero = (tensor_orig_type == 0).sum().item()
Expand Down
12 changes: 5 additions & 7 deletions test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@
from float8_experimental.float8_utils import (
amax_to_scale,
compute_error,
E4M3_MAX_POS,
E5M2_MAX_POS,
FP16_MAX_POS,
fp8_tensor_statistics,
FP8_TYPES,
tensor_to_scale,
)

Expand Down Expand Up @@ -118,9 +116,10 @@ def _test_linear_impl(
"fp8_amax_w",
"fp8_amax_dL_dY",
]
max_float8_pos = {torch.finfo(dtype).max for dtype in FP8_TYPES}
for buffer_name in amax_buffer_names:
buffer_value = getattr(m_fp8, buffer_name)
for init_val in (E4M3_MAX_POS, E5M2_MAX_POS):
for init_val in max_float8_pos:
assert torch.ne(
buffer_value, torch.tensor(init_val)
), f"{buffer_name} not filled, current value {buffer_value}"
Expand Down Expand Up @@ -412,9 +411,8 @@ def test_small_amax_float16(self, float8_dtype):
#
# amax + eps >= fp8_max_pos / fp16_max_pos

float8_max_pos = (
E4M3_MAX_POS if float8_dtype is torch.float8_e4m3fn else E5M2_MAX_POS
)
float8_max_pos = torch.finfo(float8_dtype).max
FP16_MAX_POS = torch.finfo(torch.float16).max

target_amax = float8_max_pos / (FP16_MAX_POS + 1e-12)
x = torch.tensor([target_amax], dtype=torch.float16, device="cuda")
Expand Down

0 comments on commit 4f304cc

Please sign in to comment.