From ac0afb0507828e3ee1e0e0dc8aec07f4412d87c1 Mon Sep 17 00:00:00 2001
From: willfengg
Date: Thu, 11 Jul 2024 14:24:03 -0700
Subject: [PATCH] rename to precompute_float8_dynamic_scale_for_fsdp

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 README.md                            |  5 ++-
 float8_experimental/fsdp_utils.py    | 55 ++++++++++++----------
 test/test_fsdp2/test_fsdp2_common.py |  4 +-
 3 files changed, 28 insertions(+), 36 deletions(-)

diff --git a/README.md b/README.md
index 1ebab81..464e9b1 100644
--- a/README.md
+++ b/README.md
@@ -37,6 +37,7 @@ This is the most accurate recipe as every tensor is scaled dynamically.
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from float8_experimental.float8_linear import Float8Linear
 
 # create model
@@ -58,10 +59,10 @@ for _ in range(N_ITER):
     y.sum().backward()
     optimizer.step()
 
-    # specific to fsdp2 + float8 with dynamic scaling
+    # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
     # this method is optional but is highly recommended for performance
     # it calcuclates scales for all parameters in a single all-reduce
-    precompute_float8_scale_for_fsdp(model)
+    precompute_float8_dynamic_scale_for_fsdp(model)
 ```
diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index e06ec66..0ade173 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -1,31 +1,28 @@
 import math
-import warnings
 from typing import List
 
 import torch
 import torch.nn as nn
 from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
-from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import linear_requires_sync
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_utils import EPS
 
 
-def precompute_float8_scale_for_fsdp(module: nn.Module) -> None:
+@torch.no_grad()
+def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     """
-    Calculate scale for all float8 parameters after optimizer step
-    It performs a single all-reduce instead of many all-reduces for each parameter
-    Exmaple usage:
+    Calculate scale dynamically for all float8 parameters.
+    This should be run after the optimizer step. It performs a single all-reduce to compute the
+    scales for all float8 weights.
+    Example usage:
         model(input).sum().backward()
         optim.step()
-        precompute_float8_scale_for_fsdp(model)
+        precompute_float8_dynamic_scale_for_fsdp(model)
     """
     from torch.distributed._tensor import DTensor
 
     if any(
-        isinstance(m, Float8Linear)
-        and linear_requires_sync(
-            m.scaling_type_x, m.scaling_type_w, m.scaling_type_dL_dY
-        )
+        isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED
         for m in module.modules()
     ):
         raise NotImplementedError("Only supports delayed scaling")
@@ -38,24 +35,18 @@ def precompute_float8_scale_for_fsdp(module: nn.Module) -> None:
     ]
     weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
 
-    def compute_scales(weights: List[DTensor]):
-        # inf-norm is equivalent to max(abs(w))
-        max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
-        amax_tensor = torch.vstack(max_weights)  # Partial
-        # clamp is dispatched through DTensor
-        # it will issue a single all-reduce
-        amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
-        scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
-        if amax_tensor.dtype is torch.float16:
-            scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-        scales = torch.split(scale_tensor, 1)  # Replicate
-        return scales
+    if not weights:
+        return
 
-    if weights:
-        scales = compute_scales(weights)
-        for scale, float8_linear in zip(scales, float8_linears):
-            float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
-    else:
-        warnings.warn(
-            "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
-        )
+    # inf-norm is equivalent to max(abs(w))
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    amax_tensor = torch.vstack(max_weights)  # Partial
+    # clamp is dispatched through DTensor
+    # it will issue a single all-reduce
+    amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
+    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    if amax_tensor.dtype is torch.float16:
+        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
+    scales = torch.split(scale_tensor, 1)  # Replicate
+    for scale, float8_linear in zip(scales, float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py
index 5c7d21a..af57871 100644
--- a/test/test_fsdp2/test_fsdp2_common.py
+++ b/test/test_fsdp2/test_fsdp2_common.py
@@ -6,7 +6,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.fsdp_utils import precompute_float8_scale_for_fsdp
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 
 
 def check_parity_no_mp(
@@ -31,7 +31,7 @@ def check_parity_no_mp(
         # TODO(future): add amax syncing once delayed scaling is supported
         optim.step()
         if model is fsdp_model and precompute:
-            precompute_float8_scale_for_fsdp(model)
+            precompute_float8_dynamic_scale_for_fsdp(model)
 
     test_cls.assertEqual(losses[0], losses[1])
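
For readers skimming the patch, here is a minimal single-process sketch of the amax-to-scale math that `precompute_float8_dynamic_scale_for_fsdp` runs in the fsdp_utils.py hunk above. It is not part of the patch: plain tensors stand in for the sharded DTensor weights, `compute_dynamic_scales` is a hypothetical helper name, and `eps` is a local stand-in for `float8_experimental.float8_utils.EPS`. In the real function, the same `torch.clamp` dispatching through DTensor is what collapses all per-parameter amax reductions into a single all-reduce.

```python
import math

import torch


def compute_dynamic_scales(weights, eps=1e-12):
    # hypothetical helper; mirrors the math in precompute_float8_dynamic_scale_for_fsdp
    # inf-norm of each weight is equivalent to max(abs(w))
    max_weights = torch._foreach_norm(weights, ord=math.inf)
    amax_tensor = torch.vstack(max_weights)
    # clamp guards against division by zero for all-zero weights
    # (in the patch, this clamp on a DTensor is what issues the single all-reduce)
    amax_tensor = torch.clamp(amax_tensor, eps)
    # scale maps the observed amax onto the largest representable e4m3 value
    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor
    if amax_tensor.dtype is torch.float16:
        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
    return torch.split(scale_tensor, 1)  # one scale per weight


# example: three dummy weights, one precomputed scale each
weights = [torch.randn(16, 16) for _ in range(3)]
scales = compute_dynamic_scales(weights)
```

Because the per-weight amax values are stacked into one tensor before the clamp, the scales for every float8 weight come out of one collective rather than one all-reduce per parameter, which is the performance benefit the README note refers to.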