Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
remove clamp_amax=True/False
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
weifengpy committed Jul 10, 2024
1 parent e4245e4 commit e12c973
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 5 deletions.
1 change: 0 additions & 1 deletion float8_experimental/float8_dynamic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def fsdp_pre_all_gather(self, mesh):
self._precomputed_amax,
torch.float8_e4m3fn,
self._precomputed_amax.dtype,
clamp_amax=False,
)
float8_tensor = Float8Tensor.to_float8(
self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config
Expand Down
1 change: 1 addition & 0 deletions float8_experimental/float8_linear_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear, TensorScalingType

from float8_experimental.float8_utils import (
amax_history_to_scale_stack,
e4m3_dtype,
Expand Down
5 changes: 1 addition & 4 deletions float8_experimental/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,16 @@ def amax_to_scale(
amax: torch.Tensor,
float8_dtype: torch.dtype,
orig_dtype: torch.dtype,
clamp_amax: bool = True,
):
"""Converts the amax value of a tensor to the fp8 scale.
Args:
amax: The amax value of the tensor.
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
clamp_amax: default is True. False for FSDP fp8 all-gather since FSDP applied `torch.clamp` during pre-compute after optimizer.step
"""
scale = torch.empty_like(amax, dtype=torch.float32)
if float8_dtype in FP8_TYPES:
amax = torch.clamp(amax, min=EPS) if clamp_amax else amax
res = torch.finfo(float8_dtype).max / amax
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

Expand Down

0 comments on commit e12c973

Please sign in to comment.