From e12c9731d709145f1deaebeb074b5f20d812acbe Mon Sep 17 00:00:00 2001
From: willfengg
Date: Wed, 10 Jul 2024 14:31:46 -0700
Subject: [PATCH] remove clamp_amax=True/False

Summary: remove the `clamp_amax` flag from `amax_to_scale`; amax is now
always clamped to `EPS` before computing the scale, and the FSDP fp8
all-gather call site no longer passes `clamp_amax=False`.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 float8_experimental/float8_dynamic_utils.py | 1 -
 float8_experimental/float8_linear_utils.py  | 1 +
 float8_experimental/float8_utils.py         | 5 +----
 3 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py
index e277806..ecd64fd 100644
--- a/float8_experimental/float8_dynamic_utils.py
+++ b/float8_experimental/float8_dynamic_utils.py
@@ -169,7 +169,6 @@ def fsdp_pre_all_gather(self, mesh):
             self._precomputed_amax,
             torch.float8_e4m3fn,
             self._precomputed_amax.dtype,
-            clamp_amax=False,
         )
         float8_tensor = Float8Tensor.to_float8(
             self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index e3af6f8..5d49e65 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -10,6 +10,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+
 from float8_experimental.float8_utils import (
     amax_history_to_scale_stack,
     e4m3_dtype,
diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py
index 9f3d243..ad5ffe1 100644
--- a/float8_experimental/float8_utils.py
+++ b/float8_experimental/float8_utils.py
@@ -37,19 +37,16 @@ def amax_to_scale(
     amax: torch.Tensor,
     float8_dtype: torch.dtype,
     orig_dtype: torch.dtype,
-    clamp_amax: bool = True,
 ):
     """Converts the amax value of a tensor to the fp8 scale.

     Args:
         amax: The amax value of the tensor.
         float8_dtype: The float8 dtype.
         orig_dtype: The original dtype of the tensor.
-        clamp_amax: default is True. False for FSDP fp8 all-gather since FSDP applied `torch.clamp` during pre-compute after optimizer.step
     """
     scale = torch.empty_like(amax, dtype=torch.float32)
     if float8_dtype in FP8_TYPES:
-        amax = torch.clamp(amax, min=EPS) if clamp_amax else amax
-        res = torch.finfo(float8_dtype).max / amax
+        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
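
Note for reviewers: a minimal sketch of how `amax_to_scale` behaves after this
patch. The `EPS` and `FP8_TYPES` values and the function's tail (the hunk stops
at the `raise`) are not shown in the diff, so they are assumptions here; this
sketch returns the computed scale directly and omits any orig_dtype-specific
handling the real function may do.

import torch

# Assumed constants mirroring float8_experimental.float8_utils;
# their definitions are not shown in the hunk.
EPS = 1e-12
FP8_TYPES = {torch.float8_e4m3fn, torch.float8_e5m2}


def amax_to_scale(
    amax: torch.Tensor,
    float8_dtype: torch.dtype,
    orig_dtype: torch.dtype,
) -> torch.Tensor:
    """Converts the amax value of a tensor to the fp8 scale.

    Post-patch, amax is always clamped to EPS, so a zero amax (e.g. an
    all-zero weight shard in the FSDP fp8 all-gather path) can no longer
    produce an inf scale.
    """
    if float8_dtype not in FP8_TYPES:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)


# A zero amax now yields a large but finite scale instead of inf.
scale = amax_to_scale(torch.tensor(0.0), torch.float8_e4m3fn, torch.float32)
assert torch.isfinite(scale)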