From ba9b5ddfa86123f95d558b50a62799758ad1f440 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Wed, 6 Mar 2024 19:33:44 -0800
Subject: [PATCH 1/2] Skipping SAM for now since it hangs

---
 float8_experimental/float8_utils.py | 87 +++++++++++++++++++++++------
 1 file changed, 69 insertions(+), 18 deletions(-)

diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py
index 3145f21..d81250e 100644
--- a/float8_experimental/float8_utils.py
+++ b/float8_experimental/float8_utils.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Tuple
+from typing import Literal, Tuple
 
 import torch
 import torch.distributed as dist
@@ -14,7 +14,9 @@
 # define the e4m3/e5m2 constants
 E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
+E4M3_FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
 E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
+E5M2_FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
 
 FP16_MAX_POS = torch.finfo(torch.float16).max
 
@@ -22,14 +24,30 @@
 # TODO: align this value with NVIDIA's assumptions (current value is a guess)
 EPS = 1e-12
 
+IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
+
 
 @torch.no_grad()
-def amax_to_scale(amax, float8_dtype, orig_dtype):
+def amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+):
+    """Converts the amax value of a tensor to the fp8 scale.
+    Args:
+        amax: The amax value of the tensor.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+    """
     scale = torch.empty_like(amax, dtype=torch.float32)
     if float8_dtype == torch.float8_e4m3fn:
         res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
-    else:  # e5m2
+    elif float8_dtype == torch.float8_e4m3fnuz:
+        res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
+    elif float8_dtype == torch.float8_e5m2:
         res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
+    elif float8_dtype == torch.float8_e5m2fnuz:
+        res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
 
     # Ensure that the scale is representable in float16,
     # this helps when amax is small. We are assuming that we don't need
@@ -42,11 +60,18 @@
 
 @torch.no_grad()
 def amax_history_to_scale(
-    amax_history,
-    float8_dtype,
-    orig_dtype,
-    history_to_scale_fn_type,
+    amax_history: torch.Tensor,
+    float8_dtype: torch.Tensor,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: Literal["max"],
 ):
+    """Takes in a history of amax values and returns a scale tensor.
+    Args:
+        amax_history: A tensor containing the history of amax values.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+        history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+    """
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
         return amax_to_scale(amax, float8_dtype, orig_dtype)
@@ -58,9 +83,15 @@ def amax_history_to_scale_stack(
     amax_history: torch.Tensor,
     float8_dtype: torch.dtype,
     orig_dtype: torch.dtype,
-    history_to_scale_fn_type: str,
+    history_to_scale_fn_type: Literal["max"],
 ) -> torch.Tensor:
-    """Takes in a stack of amax_history tensors and returns a scale tensor."""
+    """Takes in a stack of amax_history tensors and returns a scale tensor.
+    Args:
+        amax_history: A 2D tensor containing a stack of amax histories.
+        float8_dtype: The float8 dtype.
+        orig_dtype: The original dtype of the tensor.
+        history_to_scale_fn_type: The type of function to use to convert the history to a scale.
+    """
     if history_to_scale_fn_type == "max":
         amax_stack = torch.max(amax_history, dim=1).values
         return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
@@ -90,21 +121,41 @@ def tensor_to_scale(
     return amax_to_scale(amax, float8_dtype, x.dtype)
 
 
-def to_fp8_saturated(x, float8_dtype: torch.dtype):
-    # The default behavior in PyTorch for casting to `float8_e4m3fn`
-    # and `e5m2` is to not saturate. In this context, we should saturate.
-    # A common case where we want to saturate is when the history of a
-    # tensor has a maximum value of `amax1`, and the current amax value
-    # is `amax2`, where `amax1 < amax2`. This is common when using delayed
-    # scaling.
+def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
+    """Converts a tensor to a saturated fp8 tensor.
+
+    Note:
+        The default behavior in PyTorch for casting to `float8_e4m3fn`
+        and `e5m2` is to not saturate. In this context, we should saturate.
+        A common case where we want to saturate is when the history of a
+        tensor has a maximum value of `amax1`, and the current amax value
+        is `amax2`, where `amax1 < amax2`. This is common when using delayed
+        scaling.
+    """
+
     if float8_dtype == torch.float8_e4m3fn:
         x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
-    else:
+    elif float8_dtype == torch.float8_e4m3fnuz:
+        x = x.clamp(min=-1 * E4M3_FNUZ_MAX_POS, max=E4M3_FNUZ_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2:
         x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
+    elif float8_dtype == torch.float8_e5m2fnuz:
+        x = x.clamp(min=-1 * E5M2_FNUZ_MAX_POS, max=E5M2_FNUZ_MAX_POS)
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
     return x.to(float8_dtype)
 
 
-def compute_error(x, y):
+def compute_error(x: torch.Tensor, y: torch.Tensor):
+    """Computes the error between two tensors in dB.
+
+    For more details see:
+        https://en.wikipedia.org/wiki/Signal-to-noise_ratio
+
+    Args:
+        x: The original tensor.
+        y: The tensor to compare to the original tensor.
+    """
     Ps = torch.norm(x)
     Pn = torch.norm(x - y)
     return 20 * torch.log10(Ps / Pn)
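
Editorial aside (not part of the patch series): a minimal usage sketch of the per-dtype handling added in PATCH 1/2 above. It assumes the helpers exactly as defined in that diff; the input tensor, its shape, and the choice of the e4m3fnuz dtype are illustrative only, and the fnuz dtypes are only usable on PyTorch builds that expose them (typically ROCm/AMD).

    # Sketch: convert a tensor's amax to a scale, then do a saturated cast
    # into one of the newly supported fnuz float8 dtypes.
    import torch
    from float8_experimental.float8_utils import amax_to_scale, to_fp8_saturated

    x = torch.randn(16, 16, dtype=torch.bfloat16)   # illustrative input
    amax = x.abs().max().float()                    # per-tensor amax
    scale = amax_to_scale(amax, torch.float8_e4m3fnuz, x.dtype)
    x_fp8 = to_fp8_saturated(x.float() * scale, torch.float8_e4m3fnuz)
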
From 4f304cc5ed4b718bbed6b30327346f767a29f721 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Tue, 4 Jun 2024 11:31:23 -0400
Subject: [PATCH 2/2] cleanup fp8 handling

---
 float8_experimental/float8_dynamic_linear.py |  2 +-
 float8_experimental/float8_linear.py         | 18 +++---
 float8_experimental/float8_utils.py          | 58 +++++++++----------
 test/test_base.py                            | 12 ++--
 4 files changed, 41 insertions(+), 49 deletions(-)

diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py
index caeb31c..701ae8a 100644
--- a/float8_experimental/float8_dynamic_linear.py
+++ b/float8_experimental/float8_dynamic_linear.py
@@ -88,7 +88,7 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
             "bias": False,
         }
         new_mod = cls(**super_kwargs)
-        new_mod.forward_config = ScaledMMConfig(emulate, True if not emulate else False)
+        new_mod.forward_config = ScaledMMConfig(emulate, not bool(emulate))
         new_mod.backward_config = ScaledMMConfig(emulate, False)
         if config.enable_fsdp_fp8_all_gather:
             new_mod.weight = nn.Parameter(
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 5120f36..568b36f 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -26,12 +26,7 @@
     to_fp8_no_autograd,
 )
 
-from float8_experimental.float8_utils import (
-    amax_history_to_scale,
-    E4M3_MAX_POS,
-    E5M2_MAX_POS,
-    tensor_to_amax,
-)
+from float8_experimental.float8_utils import amax_history_to_scale, tensor_to_amax
 
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -142,18 +137,23 @@ def __init__(self, *args, **kwargs):
         self.recipe = delayed_scaling_recipe
         history_len = self.recipe.history_len
 
-        self.register_always_float32_buffer("fp8_amax_x", torch.tensor([E4M3_MAX_POS]))
+        # Default values for history buffers, see above TODO
+        default_x = torch.finfo(torch.float8_e4m3fn).max
+        default_w = torch.finfo(torch.float8_e4m3fn).max
+        default_dl_dy = torch.finfo(torch.float8_e5m2).max
+
+        self.register_always_float32_buffer("fp8_amax_x", torch.tensor([default_x]))
         self.register_always_float32_buffer(
             "fp8_amax_history_x", torch.zeros(history_len)
         )
         self.register_always_float32_buffer("fp8_scale_x", torch.tensor([1.0]))
-        self.register_always_float32_buffer("fp8_amax_w", torch.tensor([E4M3_MAX_POS]))
+        self.register_always_float32_buffer("fp8_amax_w", torch.tensor([default_w]))
         self.register_always_float32_buffer(
             "fp8_amax_history_w", torch.zeros(history_len)
         )
         self.register_always_float32_buffer("fp8_scale_w", torch.tensor([1.0]))
         self.register_always_float32_buffer(
-            "fp8_amax_dL_dY", torch.tensor([E5M2_MAX_POS])
+            "fp8_amax_dL_dY", torch.tensor([default_dl_dy])
         )
         self.register_always_float32_buffer(
             "fp8_amax_history_dL_dY", torch.zeros(history_len)
diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py
index d81250e..8d898ea 100644
--- a/float8_experimental/float8_utils.py
+++ b/float8_experimental/float8_utils.py
@@ -12,19 +12,17 @@
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
 
-# define the e4m3/e5m2 constants
-E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
-E4M3_FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
-E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
-E5M2_FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
-
-FP16_MAX_POS = torch.finfo(torch.float16).max
-
 # avoid division by zero when calculating scale
 # TODO: align this value with NVIDIA's assumptions (current value is a guess)
 EPS = 1e-12
 
 IS_AMD = torch.cuda.is_available() and torch.version.hip is not None
+FP8_TYPES = {
+    torch.float8_e4m3fn,
+    torch.float8_e5m2,
+    torch.float8_e4m3fnuz,
+    torch.float8_e5m2fnuz,
+}
 
 
 @torch.no_grad()
@@ -38,14 +36,8 @@ def amax_to_scale(
         orig_dtype: The original dtype of the tensor.
     """
     scale = torch.empty_like(amax, dtype=torch.float32)
-    if float8_dtype == torch.float8_e4m3fn:
-        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
-    elif float8_dtype == torch.float8_e4m3fnuz:
-        res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
-    elif float8_dtype == torch.float8_e5m2:
-        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
-    elif float8_dtype == torch.float8_e5m2fnuz:
-        res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
+    if float8_dtype in FP8_TYPES:
+        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
 
@@ -53,7 +45,7 @@
     # this helps when amax is small. We are assuming that we don't need
     # to care about this for float32/bfloat16.
     if orig_dtype is torch.float16:
-        res = torch.clamp(res, max=FP16_MAX_POS)
+        res = torch.clamp(res, max=torch.finfo(torch.float16).max)
     scale.copy_(res)
     return scale
@@ -132,18 +124,12 @@ def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
         is `amax2`, where `amax1 < amax2`. This is common when using delayed
         scaling.
     """
-
-    if float8_dtype == torch.float8_e4m3fn:
-        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
-    elif float8_dtype == torch.float8_e4m3fnuz:
-        x = x.clamp(min=-1 * E4M3_FNUZ_MAX_POS, max=E4M3_FNUZ_MAX_POS)
-    elif float8_dtype == torch.float8_e5m2:
-        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
-    elif float8_dtype == torch.float8_e5m2fnuz:
-        x = x.clamp(min=-1 * E5M2_FNUZ_MAX_POS, max=E5M2_FNUZ_MAX_POS)
+    if float8_dtype in FP8_TYPES:
+        max_value = torch.finfo(float8_dtype).max
+        x = x.clamp(min=-max_value, max=max_value)
+        return x.to(float8_dtype)
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
-    return x.to(float8_dtype)
 
 
 def compute_error(x: torch.Tensor, y: torch.Tensor):
@@ -164,11 +150,19 @@
 def fp8_tensor_statistics(
     tensor: torch.Tensor, float8_dtype=torch.float8_e4m3fn
 ) -> Tuple[int, ...]:
-    """Calculate FP8 tensor stats"""
-    if float8_dtype == torch.float8_e4m3fn:
-        FP8_MAX = E4M3_MAX_POS
-    else:  # e5m2
-        FP8_MAX = E5M2_MAX_POS
+    """Calculate FP8 tensor stats
+
+    Args:
+        tensor: The tensor to calculate stats for.
+        float8_dtype: The float8 dtype.
+
+    Returns:
+        A tuple containing the number of zeros and the number of max values.
+    """
+    if float8_dtype in FP8_TYPES:
+        FP8_MAX = torch.finfo(float8_dtype).max
+    else:
+        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
     tensor_orig_type = tensor._data.to(dtype=tensor._orig_dtype)
     num_max = (torch.abs(tensor_orig_type) == FP8_MAX).sum().item()
     num_zero = (tensor_orig_type == 0).sum().item()
diff --git a/test/test_base.py b/test/test_base.py
index d4545ef..6e7a34c 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -31,10 +31,8 @@
 from float8_experimental.float8_utils import (
     amax_to_scale,
     compute_error,
-    E4M3_MAX_POS,
-    E5M2_MAX_POS,
-    FP16_MAX_POS,
     fp8_tensor_statistics,
+    FP8_TYPES,
     tensor_to_scale,
 )
 
@@ -118,9 +116,10 @@ def _test_linear_impl(
             "fp8_amax_w",
             "fp8_amax_dL_dY",
         ]
+        max_float8_pos = {torch.finfo(dtype).max for dtype in FP8_TYPES}
        for buffer_name in amax_buffer_names:
             buffer_value = getattr(m_fp8, buffer_name)
-            for init_val in (E4M3_MAX_POS, E5M2_MAX_POS):
+            for init_val in max_float8_pos:
                 assert torch.ne(
                     buffer_value, torch.tensor(init_val)
                 ), f"{buffer_name} not filled, current value {buffer_value}"
@@ -412,9 +411,8 @@ def test_small_amax_float16(self, float8_dtype):
         #
         #   amax + eps >= fp8_max_pos / fp16_max_pos
 
-        float8_max_pos = (
-            E4M3_MAX_POS if float8_dtype is torch.float8_e4m3fn else E5M2_MAX_POS
-        )
+        float8_max_pos = torch.finfo(float8_dtype).max
+        FP16_MAX_POS = torch.finfo(torch.float16).max
         target_amax = float8_max_pos / (FP16_MAX_POS + 1e-12)
         x = torch.tensor([target_amax], dtype=torch.float16, device="cuda")
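
Editorial aside (not part of the patches): a minimal round-trip sketch against the consolidated helpers from PATCH 2/2, where the per-dtype constants are replaced by the FP8_TYPES set and torch.finfo lookups. It assumes the functions exactly as they appear in the diff above; the shapes, values, and dtype choice below are illustrative only.

    # Sketch: dynamic scale -> saturated fp8 cast -> dequantize -> error in dB
    import torch
    from float8_experimental.float8_utils import (
        amax_to_scale,
        compute_error,
        to_fp8_saturated,
    )

    x = torch.randn(32, 32)                      # float32 input
    scale = amax_to_scale(x.abs().max(), torch.float8_e4m3fn, x.dtype)
    x_fp8 = to_fp8_saturated(x * scale, torch.float8_e4m3fn)
    x_dq = x_fp8.float() / scale                 # dequantize back to float32
    print(f"SQNR: {compute_error(x, x_dq).item():.2f} dB")
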