From 9d5595cf947674e8922bf985221291680db9a260 Mon Sep 17 00:00:00 2001 From: willfengg Date: Mon, 20 May 2024 17:00:47 -0700 Subject: [PATCH 01/27] [FSDP2] set vocab_size=32 to avoid must be divisible by 16 error Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/test_fsdp2/test_fsdp2_eager.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py index d9a0824..98ef92b 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2_eager.py @@ -57,7 +57,12 @@ def init_multi_module(self) -> nn.Module: def init_transformer(self, weight_tying: bool) -> nn.Module: torch.manual_seed(42) args = ModelArgs( - n_layers=3, dim=768, n_heads=12, dropout_p=0.0, weight_tying=weight_tying + n_layers=3, + dim=768, + n_heads=12, + dropout_p=0.0, + weight_tying=weight_tying, + vocab_size=32, ) module = Transformer(args).cuda() self.broadcast_module(module) From e7005c205763d2a58dc4f278e4fc43b938dba361 Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 21 May 2024 13:07:09 -0700 Subject: [PATCH 02/27] precast after optimizer.step and dump profiler traces Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_linear.py | 5 ++ float8_experimental/float8_linear_utils.py | 55 ++++++++++++++++- test/test_fsdp2/test_fsdp2_eager.py | 63 ++++++++++++++++---- 3 files changed, 111 insertions(+), 12 deletions(-) diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index caeb31c..ef491db 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -151,6 +151,9 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig): def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig): self._tensor = tensor self._mm_config = mm_config + # Optional cache for pre-computed fp8 data/scale + self._fp8_data: Optional[torch.Tensor] = None + self._fp8_scale: Optional[torch.Tensor] = None @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -190,6 +193,8 @@ def __repr__(self): return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})" def fsdp_pre_all_gather(self, mesh): + if self._fp8_data is not None and self._fp8_scale is not None: + return (self._fp8_data,), (self._fp8_scale,) float8_tensor = cast_to_float8_e4m3fn( self._tensor, self._mm_config, reduce_amax=True ) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 8568b51..b71043f 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -5,16 +5,25 @@ # LICENSE file in the root directory of this source tree. 
import copy import logging +import warnings from enum import auto, Enum from typing import Callable, List, Optional, Type import torch import torch.distributed as dist import torch.nn as nn -from float8_experimental.float8_dynamic_linear import Float8DynamicLinear +from float8_experimental.float8_dynamic_linear import ( + Float8DynamicLinear, + WeightWithDynamicFloat8CastTensor, +) from float8_experimental.float8_linear import Float8Linear -from float8_experimental.float8_utils import amax_history_to_scale_stack +from float8_experimental.float8_utils import ( + amax_history_to_scale_stack, + E4M3_MAX_POS, + EPS, + to_fp8_saturated, +) from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor log = logging.getLogger(__name__) @@ -322,3 +331,45 @@ def inner_func(): for child in fp8_layers: # Set a flag to signal amaxes/scales are ready child.amax_and_scale_synced = True + + +def precompute_float8_weights(module: nn.Module) -> None: + from torch.distributed._tensor import DTensor + + if any(isinstance(m, Float8Linear) for m in module.modules()): + raise NotImplementedError("Only supports Float8DynamicLinear, not Float8Linear") + float8_linears: List[Float8DynamicLinear] = [ + m + for m in module.modules() + if isinstance(m, Float8DynamicLinear) + and isinstance(m.weight, DTensor) + and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor) + ] + weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] + + def compute_weights_and_scales(weights: List[DTensor]): + abs_weights = torch._foreach_abs(weights) + # abs_weights = [torch.abs(w) for w in weights] + amax_tensor = torch.vstack([torch.max(a) for a in abs_weights]) + amax_tensor = torch.clamp(amax_tensor, EPS) + scales_tensor = E4M3_MAX_POS / amax_tensor + scales = torch.split(scales_tensor, 1) + weights_scaled = torch._foreach_mul(weights, scales) + datas = [to_fp8_saturated(w, torch.float8_e4m3fn) for w in weights_scaled] + # torch._foreach_clamp_min_(weights_scaled, -1 * E4M3_MAX_POS) + # torch._foreach_clamp_max_(weights_scaled, E4M3_MAX_POS) + # datas = [w.to(torch.float8_e4m3fn) for w in weights_scaled] + return datas, scales + + if weights: + datas, scales = compute_weights_and_scales(weights) + # datas, scales = torch.compile(compute_weights_and_scales, mode="reduce-overhead")(weights) + for data, scale, float8_linear in zip(datas, scales, float8_linears): + float8_linear.weight._local_tensor._fp8_data = data._local_tensor + float8_linear.weight._local_tensor._fp8_scale = ( + scale._local_tensor.squeeze() + ) + else: + warnings.warn( + "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!" 
+ ) diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py index 98ef92b..cc1a5a8 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2_eager.py @@ -11,7 +11,10 @@ Float8DynamicLinear, WeightWithDynamicFloat8CastTensor, ) -from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear +from float8_experimental.float8_linear_utils import ( + precompute_float8_weights, + swap_linear_with_float8_linear, +) from test_fsdp2_common import ( check_parity_bf16_mp, check_parity_no_mp, @@ -57,12 +60,13 @@ def init_multi_module(self) -> nn.Module: def init_transformer(self, weight_tying: bool) -> nn.Module: torch.manual_seed(42) args = ModelArgs( - n_layers=3, - dim=768, - n_heads=12, + n_layers=8, + dim=4096, + n_heads=32, dropout_p=0.0, weight_tying=weight_tying, - vocab_size=32, + vocab_size=4096, + max_seq_len=4096, ) module = Transformer(args).cuda() self.broadcast_module(module) @@ -78,6 +82,39 @@ def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Modul return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs) +def profiler(output_dir): + """ + Utility component that wraps around `torch.profiler` to profile model's operators. + See https://pytorch.org/docs/stable/profiler.html for more details. + The schedule for this profiler is wait 100 steps, warmup 5 steps, trace 5 steps + Note: Enabling pytorch profiler may have training speed reduction. + + Args: + enabled (Optional[bool]): Enable pytorch profiler. Default is False. + output_dir (Optional[str]): Tracing file output path. Default is "./torchtune_perf_tracing.json". + + Returns: + ContextManager: pytorch profiler context manager + """ + + def trace_handler(prof) -> None: + prof.export_chrome_trace(output_dir) + + return torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=0, warmup=1, active=2, repeat=1, skip_first=1 + ), + on_trace_ready=trace_handler, + record_shapes=True, + profile_memory=False, + with_stack=False, + ) + + class TestFloat8MultiProcess(FSDPTest, TestFloat8Common): @property def world_size(self) -> int: @@ -85,7 +122,7 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_transformer_parity_dynamic(self): - for enable_fsdp_fp8_all_gather in [False, True]: + for enable_fsdp_fp8_all_gather in [True]: self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather) def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool): @@ -106,11 +143,17 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool): ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2) optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True) local_inp = torch.randint( - 0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda" - ) - check_parity_no_mp( - self, ref_module, ref_optim, module, optim, local_inp, Float8DynamicLinear + 0, ref_module.tok_embeddings.weight.size(0), (4, 512), device="cuda" ) + with profiler( + output_dir=f"./test_fsdp2_eager_fp8_{enable_fsdp_fp8_all_gather}_precast_rank_{torch.distributed.get_rank()}.json" + ) as prof: + for i in range(5): + optim.zero_grad() + module(local_inp).sum().backward() + optim.step() + precompute_float8_weights(module) + prof.step() @skip_if_lt_x_gpu(2) def test_transformer_memory(self): From e0bee10898d994ac41e58a9fc787d1d648ae58c9 Mon Sep 17 00:00:00 2001 From: willfengg Date: Fri, 24 May 2024 
11:00:25 -0700 Subject: [PATCH 03/27] precast and preamax unit test Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_linear.py | 24 ++++++++-- float8_experimental/float8_linear_utils.py | 50 ++++++++++++++++---- float8_experimental/float8_utils.py | 12 +++-- test/test_fsdp2/test_fsdp2_eager.py | 26 +++++++--- 4 files changed, 90 insertions(+), 22 deletions(-) diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index ef491db..dc92fb0 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -22,7 +22,7 @@ tensor_already_casted_to_fp8, to_fp8_no_autograd, ) -from float8_experimental.float8_utils import tensor_to_scale +from float8_experimental.float8_utils import amax_to_scale, tensor_to_scale from torch._prims_common import suggest_memory_format @@ -144,7 +144,9 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig): dtype=tensor.dtype, layout=tensor.layout, device=tensor.device, - pin_memory=tensor.is_pinned(), + # TODO: workaround fake tensor not implementing is.pinned + # pin_memory=tensor.is_pinned(), + pin_memory=False, requires_grad=tensor.requires_grad, ) @@ -154,6 +156,7 @@ def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig): # Optional cache for pre-computed fp8 data/scale self._fp8_data: Optional[torch.Tensor] = None self._fp8_scale: Optional[torch.Tensor] = None + self._fp8_amax: Optional[torch.Tensor] = None @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -195,9 +198,20 @@ def __repr__(self): def fsdp_pre_all_gather(self, mesh): if self._fp8_data is not None and self._fp8_scale is not None: return (self._fp8_data,), (self._fp8_scale,) - float8_tensor = cast_to_float8_e4m3fn( - self._tensor, self._mm_config, reduce_amax=True - ) + if self._fp8_amax is not None: + scale = amax_to_scale( + self._fp8_amax, + torch.float8_e4m3fn, + self._fp8_amax.dtype, + clamp_amax=False, + ) + float8_tensor = Float8Tensor.to_float8( + self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config + ) + else: + float8_tensor = cast_to_float8_e4m3fn( + self._tensor, self._mm_config, reduce_amax=True + ) return (float8_tensor._data,), (float8_tensor._scale,) def fsdp_post_all_gather( diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index b71043f..ce02857 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -333,6 +333,39 @@ def inner_func(): child.amax_and_scale_synced = True +def precompute_float8_amax(module: nn.Module) -> None: + from torch.distributed._tensor import DTensor + + if any(isinstance(m, Float8Linear) for m in module.modules()): + raise NotImplementedError("Only supports Float8DynamicLinear, not Float8Linear") + float8_linears: List[Float8DynamicLinear] = [ + m + for m in module.modules() + if isinstance(m, Float8DynamicLinear) + and isinstance(m.weight, DTensor) + and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor) + ] + weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] + + def compute_amaxes(weights: List[DTensor]): + abs_weights = torch._foreach_abs(weights) # S0 + amax_tensor = torch.vstack([torch.max(a) for a in abs_weights]) # P + amax_tensor = torch.clamp(amax_tensor, EPS) # R + amaxes = torch.split(amax_tensor, 1) # R + return amaxes + + if weights: + # amaxes = compute_amaxes(weights) + # amaxes = 
torch.compile(compute_amaxes, mode="reduce-overhead")(weights) + amaxes = torch.compile(compute_amaxes)(weights) + for amax, float8_linear in zip(amaxes, float8_linears): + float8_linear.weight._local_tensor._fp8_amax = amax._local_tensor + else: + warnings.warn( + "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!" + ) + + def precompute_float8_weights(module: nn.Module) -> None: from torch.distributed._tensor import DTensor @@ -348,21 +381,22 @@ def precompute_float8_weights(module: nn.Module) -> None: weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] def compute_weights_and_scales(weights: List[DTensor]): - abs_weights = torch._foreach_abs(weights) + abs_weights = torch._foreach_abs(weights) # S0 # abs_weights = [torch.abs(w) for w in weights] - amax_tensor = torch.vstack([torch.max(a) for a in abs_weights]) - amax_tensor = torch.clamp(amax_tensor, EPS) - scales_tensor = E4M3_MAX_POS / amax_tensor - scales = torch.split(scales_tensor, 1) - weights_scaled = torch._foreach_mul(weights, scales) - datas = [to_fp8_saturated(w, torch.float8_e4m3fn) for w in weights_scaled] + amax_tensor = torch.vstack([torch.max(a) for a in abs_weights]) # P + amax_tensor = torch.clamp(amax_tensor, EPS) # R + scales_tensor = E4M3_MAX_POS / amax_tensor # R + scales = torch.split(scales_tensor, 1) # R + weights_scaled = torch._foreach_mul(weights, scales) # S0 + datas = [to_fp8_saturated(w, torch.float8_e4m3fn) for w in weights_scaled] # S0 # torch._foreach_clamp_min_(weights_scaled, -1 * E4M3_MAX_POS) # torch._foreach_clamp_max_(weights_scaled, E4M3_MAX_POS) # datas = [w.to(torch.float8_e4m3fn) for w in weights_scaled] return datas, scales if weights: - datas, scales = compute_weights_and_scales(weights) + # datas, scales = compute_weights_and_scales(weights) + datas, scales = torch.compile(compute_weights_and_scales)(weights) # datas, scales = torch.compile(compute_weights_and_scales, mode="reduce-overhead")(weights) for data, scale, float8_linear in zip(datas, scales, float8_linears): float8_linear.weight._local_tensor._fp8_data = data._local_tensor diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 3145f21..efea6ae 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -24,12 +24,18 @@ @torch.no_grad() -def amax_to_scale(amax, float8_dtype, orig_dtype): +def amax_to_scale(amax, float8_dtype, orig_dtype, clamp_amax=True): scale = torch.empty_like(amax, dtype=torch.float32) if float8_dtype == torch.float8_e4m3fn: - res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) + if clamp_amax: + res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) + else: + res = E4M3_MAX_POS / amax else: # e5m2 - res = E5M2_MAX_POS / torch.clamp(amax, min=EPS) + if clamp_amax: + res = E5M2_MAX_POS / torch.clamp(amax, min=EPS) + else: + res = E5M2_MAX_POS / amax # Ensure that the scale is representable in float16, # this helps when amax is small. 
We are assuming that we don't need diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py index 09d261b..5afb810 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2_eager.py @@ -1,7 +1,7 @@ import copy import threading import unittest -from typing import Any, List +from typing import Any, List, Union import torch import torch._dynamo.testing @@ -12,6 +12,7 @@ WeightWithDynamicFloat8CastTensor, ) from float8_experimental.float8_linear_utils import ( + precompute_float8_amax, precompute_float8_weights, swap_linear_with_float8_linear, ) @@ -124,9 +125,14 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_transformer_parity_dynamic(self): for enable_fsdp_fp8_all_gather in [True]: - self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather) + for pre_compute in [None, "cast", "amax"]: + self._test_transformer_parity_dynamic( + enable_fsdp_fp8_all_gather, pre_compute + ) - def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool): + def _test_transformer_parity_dynamic( + self, enable_fsdp_fp8_all_gather: bool, pre_compute: Union[str, None] + ): # NOTE: Weight-tying does not compose with fp8 all-gather because the # embedding weight and output linear weight are tied but only the # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to @@ -147,13 +153,21 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool): 0, ref_module.tok_embeddings.weight.size(0), (4, 512), device="cuda" ) with profiler( - output_dir=f"./test_fsdp2_eager_fp8_{enable_fsdp_fp8_all_gather}_precast_rank_{torch.distributed.get_rank()}.json" + output_dir=f"./test_fsdp2_eager_fp8_{enable_fsdp_fp8_all_gather}_{pre_compute}_rank_{torch.distributed.get_rank()}.json" ) as prof: for i in range(5): optim.zero_grad() - module(local_inp).sum().backward() + loss = module(local_inp).sum() + # if torch.distributed.get_rank() == 0: + # print(f"{pre_compute=} {i=} {loss=}") + loss.backward() optim.step() - precompute_float8_weights(module) + if pre_compute is None: + pass + elif pre_compute == "cast": + precompute_float8_weights(module) + elif pre_compute == "amax": + precompute_float8_amax(module) prof.step() @skip_if_lt_x_gpu(2) From c0ba5a24190e8fece42db5ec2ff23e5aa91bf359 Mon Sep 17 00:00:00 2001 From: willfengg Date: Fri, 24 May 2024 11:04:00 -0700 Subject: [PATCH 04/27] remove duplicate vocab Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/test_fsdp2/test_fsdp2_eager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py index 5afb810..6366076 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2_eager.py @@ -68,7 +68,6 @@ def init_transformer(self, weight_tying: bool) -> nn.Module: weight_tying=weight_tying, vocab_size=4096, max_seq_len=4096, - vocab_size=32, ) module = Transformer(args).cuda() self.broadcast_module(module) From 8da238e5adea8e6ddc1e841a326362c87107d22d Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 30 May 2024 00:30:01 -0700 Subject: [PATCH 05/27] fused amax Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_linear_utils.py | 45 +++++++++++++++++++--- test/test_fsdp2/test_fsdp2_eager.py | 6 ++- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index ce02857..65d6a1f 100644 --- a/float8_experimental/float8_linear_utils.py +++ 
b/float8_experimental/float8_linear_utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import copy import logging +import math import warnings from enum import auto, Enum from typing import Callable, List, Optional, Type @@ -349,15 +350,49 @@ def precompute_float8_amax(module: nn.Module) -> None: def compute_amaxes(weights: List[DTensor]): abs_weights = torch._foreach_abs(weights) # S0 - amax_tensor = torch.vstack([torch.max(a) for a in abs_weights]) # P + max_weights = [torch.max(a) for a in abs_weights] + amax_tensor = torch.vstack(max_weights) # P + amax_tensor = torch.clamp(amax_tensor, EPS) # R + amaxes = torch.split(amax_tensor, 1) # R + return amaxes + + if weights: + amaxes = compute_amaxes(weights) + # amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights) + # amaxes = torch.compile(compute_amaxes)(weights) + for amax, float8_linear in zip(amaxes, float8_linears): + float8_linear.weight._local_tensor._fp8_amax = amax._local_tensor + else: + warnings.warn( + "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!" + ) + + +def precompute_float8_amax_fused(module: nn.Module) -> None: + from torch.distributed._tensor import DTensor + + if any(isinstance(m, Float8Linear) for m in module.modules()): + raise NotImplementedError("Only supports Float8DynamicLinear, not Float8Linear") + float8_linears: List[Float8DynamicLinear] = [ + m + for m in module.modules() + if isinstance(m, Float8DynamicLinear) + and isinstance(m.weight, DTensor) + and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor) + ] + weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] + + def compute_amaxes(weights: List[DTensor]): + max_weights = torch._foreach_norm(weights, ord=math.inf) + amax_tensor = torch.vstack(max_weights) amax_tensor = torch.clamp(amax_tensor, EPS) # R amaxes = torch.split(amax_tensor, 1) # R return amaxes if weights: - # amaxes = compute_amaxes(weights) + amaxes = compute_amaxes(weights) # amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights) - amaxes = torch.compile(compute_amaxes)(weights) + # amaxes = torch.compile(compute_amaxes)(weights) for amax, float8_linear in zip(amaxes, float8_linears): float8_linear.weight._local_tensor._fp8_amax = amax._local_tensor else: @@ -395,8 +430,8 @@ def compute_weights_and_scales(weights: List[DTensor]): return datas, scales if weights: - # datas, scales = compute_weights_and_scales(weights) - datas, scales = torch.compile(compute_weights_and_scales)(weights) + datas, scales = compute_weights_and_scales(weights) + # datas, scales = torch.compile(compute_weights_and_scales)(weights) # datas, scales = torch.compile(compute_weights_and_scales, mode="reduce-overhead")(weights) for data, scale, float8_linear in zip(datas, scales, float8_linears): float8_linear.weight._local_tensor._fp8_data = data._local_tensor diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py index 6366076..9d41bb0 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2_eager.py @@ -13,6 +13,7 @@ ) from float8_experimental.float8_linear_utils import ( precompute_float8_amax, + precompute_float8_amax_fused, precompute_float8_weights, swap_linear_with_float8_linear, ) @@ -124,7 +125,8 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_transformer_parity_dynamic(self): for enable_fsdp_fp8_all_gather in [True]: - for pre_compute in [None, "cast", "amax"]: + # for pre_compute in [None, 
"cast", "amax", "amax_fused"]: + for pre_compute in [None, "amax", "amax_fused"]: self._test_transformer_parity_dynamic( enable_fsdp_fp8_all_gather, pre_compute ) @@ -167,6 +169,8 @@ def _test_transformer_parity_dynamic( precompute_float8_weights(module) elif pre_compute == "amax": precompute_float8_amax(module) + elif pre_compute == "amax_fused": + precompute_float8_amax_fused(module) prof.step() @skip_if_lt_x_gpu(2) From aefa21b539da7f563d81d1cf9f075b728312d2ea Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 5 Jun 2024 17:31:27 -0700 Subject: [PATCH 06/27] use FP8_TYPES and max Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_linear_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 65d6a1f..d40afa4 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -21,7 +21,6 @@ from float8_experimental.float8_utils import ( amax_history_to_scale_stack, - E4M3_MAX_POS, EPS, to_fp8_saturated, ) @@ -30,6 +29,8 @@ log = logging.getLogger(__name__) log.addHandler(logging.NullHandler()) +E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max + class LinearType(Enum): DELAYED = auto() From d4a1db7c0c6c67a0b52fb53234a871d7bf68707f Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 5 Jun 2024 23:04:47 -0700 Subject: [PATCH 07/27] commit all changes before cleaning Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/test_fsdp2/test_fsdp2_common.py | 20 +++++++++- test/test_fsdp2/test_fsdp2_eager.py | 60 +++++++++++++++++----------- 2 files changed, 54 insertions(+), 26 deletions(-) diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index 4f20fd5..da625ba 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -1,13 +1,19 @@ import contextlib -from typing import List, Type +from typing import List, Optional, Type, Union import float8_experimental.config as config import torch import torch.distributed as dist import torch.nn as nn +from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear -from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history +from float8_experimental.float8_linear_utils import ( + precompute_float8_amax, + precompute_float8_amax_fused, + precompute_float8_weights, + sync_float8_amax_and_scale_history, +) def check_parity_no_mp( @@ -18,6 +24,7 @@ def check_parity_no_mp( fsdp_optim: torch.optim.Optimizer, local_inp: torch.Tensor, module_cls: Type, + pre_compute: Optional[Union[str, None]], ): for iter_idx in range(10): losses: List[torch.Tensor] = [] @@ -32,6 +39,15 @@ def check_parity_no_mp( if module_cls is Float8Linear: sync_float8_amax_and_scale_history(model) optim.step() + if module_cls is Float8DynamicLinear and model is fsdp_model: + if pre_compute is None: + pass + elif pre_compute == "cast": + precompute_float8_weights(model) + elif pre_compute == "amax": + precompute_float8_amax(model) + elif pre_compute == "amax_fused": + precompute_float8_amax_fused(model) test_cls.assertEqual(losses[0], losses[1]) diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py index 9d41bb0..1cc3e80 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2_eager.py @@ -12,9 +12,9 @@ WeightWithDynamicFloat8CastTensor, ) from 
float8_experimental.float8_linear_utils import ( - precompute_float8_amax, - precompute_float8_amax_fused, - precompute_float8_weights, + # precompute_float8_amax, + # precompute_float8_amax_fused, + # precompute_float8_weights, swap_linear_with_float8_linear, ) from test_fsdp2_common import ( @@ -126,7 +126,9 @@ def world_size(self) -> int: def test_transformer_parity_dynamic(self): for enable_fsdp_fp8_all_gather in [True]: # for pre_compute in [None, "cast", "amax", "amax_fused"]: - for pre_compute in [None, "amax", "amax_fused"]: + # for pre_compute in [None, "amax", "amax_fused"]: + for pre_compute in [None, "cast", "amax_fused"]: + # for pre_compute in ["amax_fused"]: self._test_transformer_parity_dynamic( enable_fsdp_fp8_all_gather, pre_compute ) @@ -151,27 +153,37 @@ def _test_transformer_parity_dynamic( ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2) optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True) local_inp = torch.randint( - 0, ref_module.tok_embeddings.weight.size(0), (4, 512), device="cuda" + 0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda" ) - with profiler( - output_dir=f"./test_fsdp2_eager_fp8_{enable_fsdp_fp8_all_gather}_{pre_compute}_rank_{torch.distributed.get_rank()}.json" - ) as prof: - for i in range(5): - optim.zero_grad() - loss = module(local_inp).sum() - # if torch.distributed.get_rank() == 0: - # print(f"{pre_compute=} {i=} {loss=}") - loss.backward() - optim.step() - if pre_compute is None: - pass - elif pre_compute == "cast": - precompute_float8_weights(module) - elif pre_compute == "amax": - precompute_float8_amax(module) - elif pre_compute == "amax_fused": - precompute_float8_amax_fused(module) - prof.step() + check_parity_no_mp( + self, + ref_module, + ref_optim, + module, + optim, + local_inp, + Float8DynamicLinear, + pre_compute, + ) + # with profiler( + # output_dir=f"./test_fsdp2_eager_fp8_{enable_fsdp_fp8_all_gather}_{pre_compute}_rank_{torch.distributed.get_rank()}.json" + # ) as prof: + # for i in range(5): + # optim.zero_grad() + # loss = module(local_inp).sum() + # # if torch.distributed.get_rank() == 0: + # # print(f"{pre_compute=} {i=} {loss=}") + # loss.backward() + # optim.step() + # if pre_compute is None: + # pass + # elif pre_compute == "cast": + # precompute_float8_weights(module) + # elif pre_compute == "amax": + # precompute_float8_amax(module) + # elif pre_compute == "amax_fused": + # precompute_float8_amax_fused(module) + # prof.step() @skip_if_lt_x_gpu(2) def test_transformer_memory(self): From d36e79b95f7bdc34b55abd7b96c8c52807e61c8d Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 6 Jun 2024 00:02:31 -0700 Subject: [PATCH 08/27] pre_compute and flatten / unflatten Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_linear.py | 15 ++-- float8_experimental/float8_linear_utils.py | 85 +------------------- test/test_fsdp2/test_fsdp2_common.py | 23 +++--- test/test_fsdp2/test_fsdp2_eager.py | 30 ++++--- 4 files changed, 31 insertions(+), 122 deletions(-) diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index 2b224eb..d39985f 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -153,10 +153,7 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig): def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig): self._tensor = tensor self._mm_config = mm_config - # Optional cache for pre-computed fp8 data/scale 
- self._fp8_data: Optional[torch.Tensor] = None - self._fp8_scale: Optional[torch.Tensor] = None - self._fp8_amax: Optional[torch.Tensor] = None + self._pre_computed_amax = None @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -185,7 +182,7 @@ def unwrap(t): ) def __tensor_flatten__(self): - return ["_tensor"], self._mm_config + return ["_tensor", "_pre_computed_amax"], self._mm_config @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): @@ -196,13 +193,11 @@ def __repr__(self): return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})" def fsdp_pre_all_gather(self, mesh): - if self._fp8_data is not None and self._fp8_scale is not None: - return (self._fp8_data,), (self._fp8_scale,) - if self._fp8_amax is not None: + if self._pre_computed_amax is not None: scale = amax_to_scale( - self._fp8_amax, + self._pre_computed_amax, torch.float8_e4m3fn, - self._fp8_amax.dtype, + self._pre_computed_amax.dtype, clamp_amax=False, ) float8_tensor = Float8Tensor.to_float8( diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index d40afa4..5b539a3 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -19,11 +19,7 @@ ) from float8_experimental.float8_linear import Float8Linear -from float8_experimental.float8_utils import ( - amax_history_to_scale_stack, - EPS, - to_fp8_saturated, -) +from float8_experimental.float8_utils import amax_history_to_scale_stack, EPS from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor log = logging.getLogger(__name__) @@ -349,40 +345,6 @@ def precompute_float8_amax(module: nn.Module) -> None: ] weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] - def compute_amaxes(weights: List[DTensor]): - abs_weights = torch._foreach_abs(weights) # S0 - max_weights = [torch.max(a) for a in abs_weights] - amax_tensor = torch.vstack(max_weights) # P - amax_tensor = torch.clamp(amax_tensor, EPS) # R - amaxes = torch.split(amax_tensor, 1) # R - return amaxes - - if weights: - amaxes = compute_amaxes(weights) - # amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights) - # amaxes = torch.compile(compute_amaxes)(weights) - for amax, float8_linear in zip(amaxes, float8_linears): - float8_linear.weight._local_tensor._fp8_amax = amax._local_tensor - else: - warnings.warn( - "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!" 
- ) - - -def precompute_float8_amax_fused(module: nn.Module) -> None: - from torch.distributed._tensor import DTensor - - if any(isinstance(m, Float8Linear) for m in module.modules()): - raise NotImplementedError("Only supports Float8DynamicLinear, not Float8Linear") - float8_linears: List[Float8DynamicLinear] = [ - m - for m in module.modules() - if isinstance(m, Float8DynamicLinear) - and isinstance(m.weight, DTensor) - and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor) - ] - weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] - def compute_amaxes(weights: List[DTensor]): max_weights = torch._foreach_norm(weights, ord=math.inf) amax_tensor = torch.vstack(max_weights) @@ -395,50 +357,7 @@ def compute_amaxes(weights: List[DTensor]): # amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights) # amaxes = torch.compile(compute_amaxes)(weights) for amax, float8_linear in zip(amaxes, float8_linears): - float8_linear.weight._local_tensor._fp8_amax = amax._local_tensor - else: - warnings.warn( - "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!" - ) - - -def precompute_float8_weights(module: nn.Module) -> None: - from torch.distributed._tensor import DTensor - - if any(isinstance(m, Float8Linear) for m in module.modules()): - raise NotImplementedError("Only supports Float8DynamicLinear, not Float8Linear") - float8_linears: List[Float8DynamicLinear] = [ - m - for m in module.modules() - if isinstance(m, Float8DynamicLinear) - and isinstance(m.weight, DTensor) - and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor) - ] - weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] - - def compute_weights_and_scales(weights: List[DTensor]): - abs_weights = torch._foreach_abs(weights) # S0 - # abs_weights = [torch.abs(w) for w in weights] - amax_tensor = torch.vstack([torch.max(a) for a in abs_weights]) # P - amax_tensor = torch.clamp(amax_tensor, EPS) # R - scales_tensor = E4M3_MAX_POS / amax_tensor # R - scales = torch.split(scales_tensor, 1) # R - weights_scaled = torch._foreach_mul(weights, scales) # S0 - datas = [to_fp8_saturated(w, torch.float8_e4m3fn) for w in weights_scaled] # S0 - # torch._foreach_clamp_min_(weights_scaled, -1 * E4M3_MAX_POS) - # torch._foreach_clamp_max_(weights_scaled, E4M3_MAX_POS) - # datas = [w.to(torch.float8_e4m3fn) for w in weights_scaled] - return datas, scales - - if weights: - datas, scales = compute_weights_and_scales(weights) - # datas, scales = torch.compile(compute_weights_and_scales)(weights) - # datas, scales = torch.compile(compute_weights_and_scales, mode="reduce-overhead")(weights) - for data, scale, float8_linear in zip(datas, scales, float8_linears): - float8_linear.weight._local_tensor._fp8_data = data._local_tensor - float8_linear.weight._local_tensor._fp8_scale = ( - scale._local_tensor.squeeze() - ) + float8_linear.weight._local_tensor._pre_computed_amax = amax._local_tensor else: warnings.warn( "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!" 
diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index da625ba..3cb4d36 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -1,5 +1,7 @@ import contextlib -from typing import List, Optional, Type, Union + +# from typing import List, Optional, Type, Union +from typing import List, Type import float8_experimental.config as config @@ -10,8 +12,6 @@ from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_linear_utils import ( precompute_float8_amax, - precompute_float8_amax_fused, - precompute_float8_weights, sync_float8_amax_and_scale_history, ) @@ -24,7 +24,7 @@ def check_parity_no_mp( fsdp_optim: torch.optim.Optimizer, local_inp: torch.Tensor, module_cls: Type, - pre_compute: Optional[Union[str, None]], + pre_compute: bool = False, ): for iter_idx in range(10): losses: List[torch.Tensor] = [] @@ -39,15 +39,12 @@ def check_parity_no_mp( if module_cls is Float8Linear: sync_float8_amax_and_scale_history(model) optim.step() - if module_cls is Float8DynamicLinear and model is fsdp_model: - if pre_compute is None: - pass - elif pre_compute == "cast": - precompute_float8_weights(model) - elif pre_compute == "amax": - precompute_float8_amax(model) - elif pre_compute == "amax_fused": - precompute_float8_amax_fused(model) + if ( + model is fsdp_model + and module_cls is Float8DynamicLinear + and pre_compute + ): + precompute_float8_amax(model) test_cls.assertEqual(losses[0], losses[1]) diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py index 1cc3e80..ada751a 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2_eager.py @@ -1,7 +1,7 @@ import copy import threading import unittest -from typing import Any, List, Union +from typing import Any, List import torch import torch._dynamo.testing @@ -11,12 +11,7 @@ Float8DynamicLinear, WeightWithDynamicFloat8CastTensor, ) -from float8_experimental.float8_linear_utils import ( - # precompute_float8_amax, - # precompute_float8_amax_fused, - # precompute_float8_weights, - swap_linear_with_float8_linear, -) +from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear from test_fsdp2_common import ( check_parity_bf16_mp, check_parity_no_mp, @@ -124,18 +119,21 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_transformer_parity_dynamic(self): - for enable_fsdp_fp8_all_gather in [True]: - # for pre_compute in [None, "cast", "amax", "amax_fused"]: - # for pre_compute in [None, "amax", "amax_fused"]: - for pre_compute in [None, "cast", "amax_fused"]: - # for pre_compute in ["amax_fused"]: - self._test_transformer_parity_dynamic( - enable_fsdp_fp8_all_gather, pre_compute - ) + self.run_subtests( + { + "enable_fsdp_fp8_all_gather": [False, True], + "pre_compute": [False, True], + }, + self._test_transformer_parity_dynamic, + ) def _test_transformer_parity_dynamic( - self, enable_fsdp_fp8_all_gather: bool, pre_compute: Union[str, None] + self, + enable_fsdp_fp8_all_gather: bool, + pre_compute: bool, ): + if not enable_fsdp_fp8_all_gather and pre_compute: + return # NOTE: Weight-tying does not compose with fp8 all-gather because the # embedding weight and output linear weight are tied but only the # latter uses fp8 compute. 
With fp8 all-gather, FSDP would pre-cast to From 6f244a2a3adf9a7ed70f233ad62a2a1941747010 Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 6 Jun 2024 00:06:23 -0700 Subject: [PATCH 09/27] remove unused constant Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_linear_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 5b539a3..a3f5283 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -25,8 +25,6 @@ log = logging.getLogger(__name__) log.addHandler(logging.NullHandler()) -E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max - class LinearType(Enum): DELAYED = auto() From dc5eab02042686a7e84016e9518e7af7a474f409 Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 6 Jun 2024 00:57:39 -0700 Subject: [PATCH 10/27] torch.compile works Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_linear.py | 2 +- float8_experimental/float8_linear_utils.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index d39985f..5a36140 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -182,7 +182,7 @@ def unwrap(t): ) def __tensor_flatten__(self): - return ["_tensor", "_pre_computed_amax"], self._mm_config + return ["_tensor"], self._mm_config @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index a3f5283..0afd114 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -5,7 +5,8 @@ # LICENSE file in the root directory of this source tree. 
import copy import logging -import math + +# import math import warnings from enum import auto, Enum from typing import Callable, List, Optional, Type @@ -344,15 +345,17 @@ def precompute_float8_amax(module: nn.Module) -> None: weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] def compute_amaxes(weights: List[DTensor]): - max_weights = torch._foreach_norm(weights, ord=math.inf) + abs_weights = torch._foreach_abs(weights) # S0 + max_weights = [torch.max(a) for a in abs_weights] + # max_weights = torch._foreach_norm(weights, ord=math.inf) amax_tensor = torch.vstack(max_weights) amax_tensor = torch.clamp(amax_tensor, EPS) # R amaxes = torch.split(amax_tensor, 1) # R return amaxes if weights: - amaxes = compute_amaxes(weights) - # amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights) + # amaxes = compute_amaxes(weights) + amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights) # amaxes = torch.compile(compute_amaxes)(weights) for amax, float8_linear in zip(amaxes, float8_linears): float8_linear.weight._local_tensor._pre_computed_amax = amax._local_tensor From 546e9790c15eddbaadb1648f0db0c65e1820bee3 Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 6 Jun 2024 01:03:59 -0700 Subject: [PATCH 11/27] eager ready Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_linear.py | 4 +- float8_experimental/float8_linear_utils.py | 10 +--- test/test_fsdp2/test_fsdp2_common.py | 1 - test/test_fsdp2/test_fsdp2_eager.py | 61 ++------------------ 4 files changed, 8 insertions(+), 68 deletions(-) diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index 5a36140..1ec8b48 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -144,9 +144,7 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig): dtype=tensor.dtype, layout=tensor.layout, device=tensor.device, - # TODO: workaround fake tensor not implementing is.pinned - # pin_memory=tensor.is_pinned(), - pin_memory=False, + pin_memory=tensor.is_pinned(), requires_grad=tensor.requires_grad, ) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 0afd114..71a786c 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -6,7 +6,7 @@ import copy import logging -# import math +import math import warnings from enum import auto, Enum from typing import Callable, List, Optional, Type @@ -345,18 +345,14 @@ def precompute_float8_amax(module: nn.Module) -> None: weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] def compute_amaxes(weights: List[DTensor]): - abs_weights = torch._foreach_abs(weights) # S0 - max_weights = [torch.max(a) for a in abs_weights] - # max_weights = torch._foreach_norm(weights, ord=math.inf) + max_weights = torch._foreach_norm(weights, ord=math.inf) amax_tensor = torch.vstack(max_weights) amax_tensor = torch.clamp(amax_tensor, EPS) # R amaxes = torch.split(amax_tensor, 1) # R return amaxes if weights: - # amaxes = compute_amaxes(weights) - amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights) - # amaxes = torch.compile(compute_amaxes)(weights) + amaxes = compute_amaxes(weights) for amax, float8_linear in zip(amaxes, float8_linears): float8_linear.weight._local_tensor._pre_computed_amax = amax._local_tensor else: diff --git a/test/test_fsdp2/test_fsdp2_common.py 
b/test/test_fsdp2/test_fsdp2_common.py index 3cb4d36..18a36fa 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -1,6 +1,5 @@ import contextlib -# from typing import List, Optional, Type, Union from typing import List, Type import float8_experimental.config as config diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py index ada751a..3c16fad 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2_eager.py @@ -57,13 +57,12 @@ def init_multi_module(self) -> nn.Module: def init_transformer(self, weight_tying: bool) -> nn.Module: torch.manual_seed(42) args = ModelArgs( - n_layers=8, - dim=4096, - n_heads=32, + n_layers=3, + dim=768, + n_heads=12, dropout_p=0.0, weight_tying=weight_tying, - vocab_size=4096, - max_seq_len=4096, + vocab_size=32, ) module = Transformer(args).cuda() self.broadcast_module(module) @@ -79,39 +78,6 @@ def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Modul return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs) -def profiler(output_dir): - """ - Utility component that wraps around `torch.profiler` to profile model's operators. - See https://pytorch.org/docs/stable/profiler.html for more details. - The schedule for this profiler is wait 100 steps, warmup 5 steps, trace 5 steps - Note: Enabling pytorch profiler may have training speed reduction. - - Args: - enabled (Optional[bool]): Enable pytorch profiler. Default is False. - output_dir (Optional[str]): Tracing file output path. Default is "./torchtune_perf_tracing.json". - - Returns: - ContextManager: pytorch profiler context manager - """ - - def trace_handler(prof) -> None: - prof.export_chrome_trace(output_dir) - - return torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - schedule=torch.profiler.schedule( - wait=0, warmup=1, active=2, repeat=1, skip_first=1 - ), - on_trace_ready=trace_handler, - record_shapes=True, - profile_memory=False, - with_stack=False, - ) - - class TestFloat8MultiProcess(FSDPTest, TestFloat8Common): @property def world_size(self) -> int: @@ -163,25 +129,6 @@ def _test_transformer_parity_dynamic( Float8DynamicLinear, pre_compute, ) - # with profiler( - # output_dir=f"./test_fsdp2_eager_fp8_{enable_fsdp_fp8_all_gather}_{pre_compute}_rank_{torch.distributed.get_rank()}.json" - # ) as prof: - # for i in range(5): - # optim.zero_grad() - # loss = module(local_inp).sum() - # # if torch.distributed.get_rank() == 0: - # # print(f"{pre_compute=} {i=} {loss=}") - # loss.backward() - # optim.step() - # if pre_compute is None: - # pass - # elif pre_compute == "cast": - # precompute_float8_weights(module) - # elif pre_compute == "amax": - # precompute_float8_amax(module) - # elif pre_compute == "amax_fused": - # precompute_float8_amax_fused(module) - # prof.step() @skip_if_lt_x_gpu(2) def test_transformer_memory(self): From 229ede60874e55ae6d6057cdeece08145a429ec0 Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 6 Jun 2024 01:09:50 -0700 Subject: [PATCH 12/27] linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/test_fsdp2/test_fsdp2_common.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index 18a36fa..aa44204 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -1,5 +1,4 @@ import contextlib - from typing import List, Type import 
float8_experimental.config as config From d5b3ff6a1910478d407d0a0a72bf0dff619c25cc Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 6 Jun 2024 01:18:17 -0700 Subject: [PATCH 13/27] linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 6b6520b..035e403 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -27,7 +27,11 @@ @torch.no_grad() def amax_to_scale( - amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype, clamp_amax: bool = True): + amax: torch.Tensor, + float8_dtype: torch.dtype, + orig_dtype: torch.dtype, + clamp_amax: bool = True, +): """Converts the amax value of a tensor to the fp8 scale. Args: amax: The amax value of the tensor. From 4f05e0411a6dda31e8090515e3f345cd2e16fe62 Mon Sep 17 00:00:00 2001 From: willfengg Date: Mon, 24 Jun 2024 23:39:24 -0700 Subject: [PATCH 14/27] flatten tensor Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_linear.py | 32 +++++++++++++++----- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index 1ec8b48..6294e17 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -134,7 +134,12 @@ def cast_to_float8_e5m2_bw( class WeightWithDynamicFloat8CastTensor(torch.Tensor): @staticmethod - def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig): + def __new__( + cls, + tensor: torch.Tensor, + mm_config: ScaledMMConfig, + amax: Optional[torch.Tensor] = None, + ): return torch.Tensor._make_wrapper_subclass( cls, tensor.size(), @@ -144,14 +149,20 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig): dtype=tensor.dtype, layout=tensor.layout, device=tensor.device, - pin_memory=tensor.is_pinned(), + pin_memory=False, + # pin_memory=tensor.is_pinned(), requires_grad=tensor.requires_grad, ) - def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig): + def __init__( + self, + tensor: torch.Tensor, + mm_config: ScaledMMConfig, + amax: Optional[torch.Tensor] = None, + ): self._tensor = tensor self._mm_config = mm_config - self._pre_computed_amax = None + self._pre_computed_amax = amax @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -180,15 +191,22 @@ def unwrap(t): ) def __tensor_flatten__(self): - return ["_tensor"], self._mm_config + if self._pre_computed_amax: + return ["_tensor", "_pre_computed_amax"], self._mm_config + else: + return ["_tensor"], self._mm_config @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): mm_config = flatten_spec - return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config) + return WeightWithDynamicFloat8CastTensor( + inner_tensors["_tensor"], + mm_config, + getattr(inner_tensors, "_pre_computed_amax", None), + ) def __repr__(self): - return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})" + return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config}, pre_computed_amax={self._pre_computed_amax})" def fsdp_pre_all_gather(self, mesh): if self._pre_computed_amax is not None: From 3de59af25046c7fd101f0203f8209ccf0e711039 Mon Sep 17 00:00:00 2001 From: willfengg Date: Mon, 8 Jul 2024 
11:37:25 -0700 Subject: [PATCH 15/27] commit all changes for review before rebasing Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_linear.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/float8_experimental/float8_dynamic_linear.py b/float8_experimental/float8_dynamic_linear.py index 6294e17..ae860d1 100644 --- a/float8_experimental/float8_dynamic_linear.py +++ b/float8_experimental/float8_dynamic_linear.py @@ -149,8 +149,7 @@ def __new__( dtype=tensor.dtype, layout=tensor.layout, device=tensor.device, - pin_memory=False, - # pin_memory=tensor.is_pinned(), + pin_memory=tensor.is_pinned(), requires_grad=tensor.requires_grad, ) @@ -206,7 +205,7 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): ) def __repr__(self): - return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config}, pre_computed_amax={self._pre_computed_amax})" + return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})" def fsdp_pre_all_gather(self, mesh): if self._pre_computed_amax is not None: From 562424c7c365748a1fdc5c6386ea1640e9552c95 Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 9 Jul 2024 15:55:25 -0700 Subject: [PATCH 16/27] move precompute to fsdp_utils.py Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_linear_utils.py | 44 ---------------------- test/test_fsdp2/test_fsdp2_common.py | 4 +- 2 files changed, 2 insertions(+), 46 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index bfc846c..e3af6f8 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -4,21 +4,16 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
import logging - -import math -import warnings from typing import Callable, List, Optional import torch import torch.distributed as dist import torch.nn as nn -from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_utils import ( amax_history_to_scale_stack, e4m3_dtype, e5m2_dtype, - EPS, ) from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor @@ -352,42 +347,3 @@ def inner_func(): for child in fp8_layers: # Set a flag to signal amaxes/scales are ready child.amax_and_scale_synced = True - - -def precompute_float8_amax(module: nn.Module) -> None: - from torch.distributed._tensor import DTensor - - if any( - isinstance(m, Float8Linear) - and not ( - m.scaling_type_x == TensorScalingType.DYNAMIC - and m.scaling_type_w == TensorScalingType.DYNAMIC - and m.scaling_type_dL_dY == TensorScalingType.DYNAMIC - ) - for m in module.modules() - ): - raise NotImplementedError("Only supports delayed scaling") - float8_linears: List[Float8Linear] = [ - m - for m in module.modules() - if isinstance(m, Float8Linear) - and isinstance(m.weight, DTensor) - and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor) - ] - weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] - - def compute_amaxes(weights: List[DTensor]): - max_weights = torch._foreach_norm(weights, ord=math.inf) - amax_tensor = torch.vstack(max_weights) - amax_tensor = torch.clamp(amax_tensor, EPS) # R - amaxes = torch.split(amax_tensor, 1) # R - return amaxes - - if weights: - amaxes = compute_amaxes(weights) - for amax, float8_linear in zip(amaxes, float8_linears): - float8_linear.weight._local_tensor._pre_computed_amax = amax._local_tensor - else: - warnings.warn( - "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!" 
- ) diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index c0045df..9290bab 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -6,7 +6,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from float8_experimental.float8_linear_utils import precompute_float8_amax +from float8_experimental.fsdp_utils import precompute_float8_amax_for_fsdp def check_parity_no_mp( @@ -31,7 +31,7 @@ def check_parity_no_mp( # TODO(future): add amax syncing once delayed scaling is supported optim.step() if model is fsdp_model and pre_compute: - precompute_float8_amax(model) + precompute_float8_amax_for_fsdp(model) test_cls.assertEqual(losses[0], losses[1]) From 75e0e4573e757394b235e7b11a045b228944e9df Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 9 Jul 2024 16:15:10 -0700 Subject: [PATCH 17/27] simplify amax calc Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 14c9f01..9f3d243 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -48,10 +48,8 @@ def amax_to_scale( """ scale = torch.empty_like(amax, dtype=torch.float32) if float8_dtype in FP8_TYPES: - if clamp_amax: - res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) - else: - res = torch.finfo(float8_dtype).max / amax + amax = torch.clamp(amax, min=EPS) if clamp_amax else amax + res = torch.finfo(float8_dtype).max / amax else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") From fe95f8b6586c6d94d1635fb30a355f1249348b6e Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 9 Jul 2024 16:19:12 -0700 Subject: [PATCH 18/27] explain _pre_computed_amax Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py index 48f1031..325db82 100644 --- a/float8_experimental/float8_dynamic_utils.py +++ b/float8_experimental/float8_dynamic_utils.py @@ -114,6 +114,9 @@ def __init__( ): self._tensor = tensor self._mm_config = mm_config + # for dynamic scaling + # `precompute_float8_amax_for_fsdp` calculates amax + # for all float8 parameters after optimizer step self._pre_computed_amax = amax @classmethod From 1cbaa13aceb81194ceeb9c01f585ec93a5edc19d Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 9 Jul 2024 16:28:43 -0700 Subject: [PATCH 19/27] fix linter Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/fsdp_utils.py | 48 +++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 float8_experimental/fsdp_utils.py diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py new file mode 100644 index 0000000..aab37ba --- /dev/null +++ b/float8_experimental/fsdp_utils.py @@ -0,0 +1,48 @@ +import math +import warnings +from typing import List + +import torch +import torch.nn as nn +from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor +from float8_experimental.float8_linear import Float8Linear +from float8_experimental.float8_linear_utils import linear_requires_sync +from float8_experimental.float8_utils import EPS + + +def precompute_float8_amax_for_fsdp(module: nn.Module) -> None: + from 
torch.distributed._tensor import DTensor + + if any( + isinstance(m, Float8Linear) + and linear_requires_sync( + m.scaling_type_x, m.scaling_type_w, m.scaling_type_dL_dY + ) + for m in module.modules() + ): + raise NotImplementedError("Only supports delayed scaling") + float8_linears: List[Float8Linear] = [ + m + for m in module.modules() + if isinstance(m, Float8Linear) + and isinstance(m.weight, DTensor) + and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor) + ] + weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] + + def compute_amaxes(weights: List[DTensor]): + # inf-norm is equivalent to max(abs(w)) + max_weights = torch._foreach_norm(weights, ord=math.inf) + amax_tensor = torch.vstack(max_weights) + amax_tensor = torch.clamp(amax_tensor, EPS) # R + amaxes = torch.split(amax_tensor, 1) # R + return amaxes + + if weights: + amaxes = compute_amaxes(weights) + for amax, float8_linear in zip(amaxes, float8_linears): + float8_linear.weight._local_tensor._pre_computed_amax = amax._local_tensor + else: + warnings.warn( + "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!" + ) From fe2e0a0e78ac9993c334603610b13daf8abd4330 Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 9 Jul 2024 16:35:38 -0700 Subject: [PATCH 20/27] document precompute_float8_amax_for_fsdp Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/fsdp_utils.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index aab37ba..096b13f 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -11,6 +11,14 @@ def precompute_float8_amax_for_fsdp(module: nn.Module) -> None: + """ + Calculate amax for all float8 parameters after optimizer step + It performs a single all-reduce instead of many all-reduces for each parameter + Exmaple usage: + model(input).sum().backward() + optim.step() + precompute_float8_amax_for_fsdp(model) + """ from torch.distributed._tensor import DTensor if any( @@ -32,10 +40,12 @@ def precompute_float8_amax_for_fsdp(module: nn.Module) -> None: def compute_amaxes(weights: List[DTensor]): # inf-norm is equivalent to max(abs(w)) - max_weights = torch._foreach_norm(weights, ord=math.inf) - amax_tensor = torch.vstack(max_weights) - amax_tensor = torch.clamp(amax_tensor, EPS) # R - amaxes = torch.split(amax_tensor, 1) # R + max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial + amax_tensor = torch.vstack(max_weights) # Partial + # clamp is dispatched through DTensor + # it will issue a single all-reduce + amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate + amaxes = torch.split(amax_tensor, 1) # Replicate return amaxes if weights: From e4eaa2a0f0b9bd8a889d96cbd01ad8afd204987a Mon Sep 17 00:00:00 2001 From: willfengg Date: Tue, 9 Jul 2024 16:48:16 -0700 Subject: [PATCH 21/27] rename pre_compute to precompute Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_utils.py | 14 +++++++------- float8_experimental/fsdp_utils.py | 2 +- test/test_fsdp2/test_fsdp2_common.py | 4 ++-- test/test_fsdp2/test_fsdp2_eager.py | 8 ++++---- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py index 325db82..e277806 100644 --- a/float8_experimental/float8_dynamic_utils.py +++ b/float8_experimental/float8_dynamic_utils.py @@ -117,7 
+117,7 @@ def __init__( # for dynamic scaling # `precompute_float8_amax_for_fsdp` calculates amax # for all float8 parameters after optimizer step - self._pre_computed_amax = amax + self._precomputed_amax = amax @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -146,8 +146,8 @@ def unwrap(t): ) def __tensor_flatten__(self): - if self._pre_computed_amax: - return ["_tensor", "_pre_computed_amax"], self._mm_config + if self._precomputed_amax: + return ["_tensor", "_precomputed_amax"], self._mm_config else: return ["_tensor"], self._mm_config @@ -157,18 +157,18 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): return WeightWithDynamicFloat8CastTensor( inner_tensors["_tensor"], mm_config, - getattr(inner_tensors, "_pre_computed_amax", None), + getattr(inner_tensors, "_precomputed_amax", None), ) def __repr__(self): return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})" def fsdp_pre_all_gather(self, mesh): - if self._pre_computed_amax is not None: + if self._precomputed_amax is not None: scale = amax_to_scale( - self._pre_computed_amax, + self._precomputed_amax, torch.float8_e4m3fn, - self._pre_computed_amax.dtype, + self._precomputed_amax.dtype, clamp_amax=False, ) float8_tensor = Float8Tensor.to_float8( diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index 096b13f..bb21e88 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -51,7 +51,7 @@ def compute_amaxes(weights: List[DTensor]): if weights: amaxes = compute_amaxes(weights) for amax, float8_linear in zip(amaxes, float8_linears): - float8_linear.weight._local_tensor._pre_computed_amax = amax._local_tensor + float8_linear.weight._local_tensor._precomputed_amax = amax._local_tensor else: warnings.warn( "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!" 
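Note on the precomputed-amax path in patches 19 through 21: `precompute_float8_amax_for_fsdp` caches a clamped amax on each weight's local tensor after `optimizer.step()`, so `fsdp_pre_all_gather` only has to turn that amax into a scale; this is why it can pass `clamp_amax=False` (and why patch 22 can drop the flag entirely, since re-clamping an already-clamped amax is a no-op). Below is a minimal plain-tensor sketch of that branch, not the library code; `E4M3_MAX` and `EPS` stand in for `torch.finfo(torch.float8_e4m3fn).max` and `float8_utils.EPS`, and `cast_for_all_gather` is a hypothetical helper.

```python
from typing import Optional

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
EPS = 1e-12  # stand-in for float8_experimental.float8_utils.EPS


def cast_for_all_gather(w: torch.Tensor, precomputed_amax: Optional[torch.Tensor] = None):
    if precomputed_amax is not None:
        # amax was already clamped to EPS by the precompute pass
        scale = E4M3_MAX / precomputed_amax
    else:
        # fall back to computing the amax from the local shard on the fly
        scale = E4M3_MAX / torch.clamp(w.abs().max(), min=EPS)
    data = torch.clamp(w * scale, -E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return data, scale


w = torch.randn(32, 32)
cached_amax = torch.clamp(w.abs().max(), min=EPS)  # what the precompute pass stores
# both paths derive the same scale because they start from the same amax
assert torch.equal(cast_for_all_gather(w, cached_amax)[1], cast_for_all_gather(w)[1])
```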
diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index 9290bab..55d7681 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -16,7 +16,7 @@ def check_parity_no_mp( fsdp_model: nn.Module, fsdp_optim: torch.optim.Optimizer, local_inp: torch.Tensor, - pre_compute: bool = False, + precompute: bool = False, ): for iter_idx in range(10): losses: List[torch.Tensor] = [] @@ -30,7 +30,7 @@ def check_parity_no_mp( param.grad.div_(dist.get_world_size()) # TODO(future): add amax syncing once delayed scaling is supported optim.step() - if model is fsdp_model and pre_compute: + if model is fsdp_model and precompute: precompute_float8_amax_for_fsdp(model) test_cls.assertEqual(losses[0], losses[1]) diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py index 48be9b2..bdbc878 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2_eager.py @@ -89,7 +89,7 @@ def test_transformer_parity_dynamic(self): self.run_subtests( { "enable_fsdp_fp8_all_gather": [False, True], - "pre_compute": [False, True], + "precompute": [False, True], }, self._test_transformer_parity_dynamic, ) @@ -97,9 +97,9 @@ def test_transformer_parity_dynamic(self): def _test_transformer_parity_dynamic( self, enable_fsdp_fp8_all_gather: bool, - pre_compute: bool, + precompute: bool, ): - if not enable_fsdp_fp8_all_gather and pre_compute: + if not enable_fsdp_fp8_all_gather and precompute: return # NOTE: Weight-tying does not compose with fp8 all-gather because the # embedding weight and output linear weight are tied but only the @@ -121,7 +121,7 @@ def _test_transformer_parity_dynamic( 0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda" ) check_parity_no_mp( - self, ref_module, ref_optim, module, optim, local_inp, pre_compute + self, ref_module, ref_optim, module, optim, local_inp, precompute ) @skip_if_lt_x_gpu(2) From e12c9731d709145f1deaebeb074b5f20d812acbe Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 10 Jul 2024 14:31:46 -0700 Subject: [PATCH 22/27] remove clamp_amax=True/False Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_utils.py | 1 - float8_experimental/float8_linear_utils.py | 1 + float8_experimental/float8_utils.py | 5 +---- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py index e277806..ecd64fd 100644 --- a/float8_experimental/float8_dynamic_utils.py +++ b/float8_experimental/float8_dynamic_utils.py @@ -169,7 +169,6 @@ def fsdp_pre_all_gather(self, mesh): self._precomputed_amax, torch.float8_e4m3fn, self._precomputed_amax.dtype, - clamp_amax=False, ) float8_tensor = Float8Tensor.to_float8( self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index e3af6f8..5d49e65 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -10,6 +10,7 @@ import torch.distributed as dist import torch.nn as nn from float8_experimental.float8_linear import Float8Linear, TensorScalingType + from float8_experimental.float8_utils import ( amax_history_to_scale_stack, e4m3_dtype, diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 9f3d243..ad5ffe1 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ 
-37,19 +37,16 @@ def amax_to_scale( amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype, - clamp_amax: bool = True, ): """Converts the amax value of a tensor to the fp8 scale. Args: amax: The amax value of the tensor. float8_dtype: The float8 dtype. orig_dtype: The original dtype of the tensor. - clamp_amax: default is True. False for FSDP fp8 all-gather since FSDP applied `torch.clamp` during pre-compute after optimizer.step """ scale = torch.empty_like(amax, dtype=torch.float32) if float8_dtype in FP8_TYPES: - amax = torch.clamp(amax, min=EPS) if clamp_amax else amax - res = torch.finfo(float8_dtype).max / amax + res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS) else: raise ValueError(f"Unsupported float8_dtype: {float8_dtype}") From 9ef67fb1aa676ecf9cad80becf293abe8686867f Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 10 Jul 2024 16:09:37 -0700 Subject: [PATCH 23/27] precompute scale Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_utils.py | 33 ++++++++------------- float8_experimental/float8_utils.py | 4 +-- float8_experimental/fsdp_utils.py | 21 +++++++------ 3 files changed, 26 insertions(+), 32 deletions(-) diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py index ecd64fd..3cdbaf9 100644 --- a/float8_experimental/float8_dynamic_utils.py +++ b/float8_experimental/float8_dynamic_utils.py @@ -19,12 +19,7 @@ tensor_already_casted_to_fp8, to_fp8_no_autograd, ) -from float8_experimental.float8_utils import ( - amax_to_scale, - e4m3_dtype, - e5m2_dtype, - tensor_to_scale, -) +from float8_experimental.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_scale from torch._prims_common import suggest_memory_format @@ -91,7 +86,7 @@ def __new__( cls, tensor: torch.Tensor, mm_config: ScaledMMConfig, - amax: Optional[torch.Tensor] = None, + precomputed_scale: Optional[torch.Tensor] = None, ): return torch.Tensor._make_wrapper_subclass( cls, @@ -110,14 +105,14 @@ def __init__( self, tensor: torch.Tensor, mm_config: ScaledMMConfig, - amax: Optional[torch.Tensor] = None, + precomputed_scale: Optional[torch.Tensor] = None, ): self._tensor = tensor self._mm_config = mm_config # for dynamic scaling - # `precompute_float8_amax_for_fsdp` calculates amax + # `precompute_float8_scale_for_fsdp` calculates scales # for all float8 parameters after optimizer step - self._precomputed_amax = amax + self._precomputed_scale = precomputed_scale @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -146,8 +141,8 @@ def unwrap(t): ) def __tensor_flatten__(self): - if self._precomputed_amax: - return ["_tensor", "_precomputed_amax"], self._mm_config + if self._precomputed_scale: + return ["_tensor", "_precomputed_scale"], self._mm_config else: return ["_tensor"], self._mm_config @@ -157,21 +152,19 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): return WeightWithDynamicFloat8CastTensor( inner_tensors["_tensor"], mm_config, - getattr(inner_tensors, "_precomputed_amax", None), + getattr(inner_tensors, "_precomputed_scale", None), ) def __repr__(self): return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})" def fsdp_pre_all_gather(self, mesh): - if self._precomputed_amax is not None: - scale = amax_to_scale( - self._precomputed_amax, - torch.float8_e4m3fn, - self._precomputed_amax.dtype, - ) + if self._precomputed_scale is not None: float8_tensor = Float8Tensor.to_float8( - 
self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config + self._tensor, + self._precomputed_scale, + torch.float8_e4m3fn, + mm_config=self._mm_config, ) else: float8_tensor = cast_to_float8_e4m3_dynamic( diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index ad5ffe1..2be568e 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -34,9 +34,7 @@ @torch.no_grad() def amax_to_scale( - amax: torch.Tensor, - float8_dtype: torch.dtype, - orig_dtype: torch.dtype, + amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype ): """Converts the amax value of a tensor to the fp8 scale. Args: diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index bb21e88..e06ec66 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -10,14 +10,14 @@ from float8_experimental.float8_utils import EPS -def precompute_float8_amax_for_fsdp(module: nn.Module) -> None: +def precompute_float8_scale_for_fsdp(module: nn.Module) -> None: """ - Calculate amax for all float8 parameters after optimizer step + Calculate scale for all float8 parameters after optimizer step It performs a single all-reduce instead of many all-reduces for each parameter Exmaple usage: model(input).sum().backward() optim.step() - precompute_float8_amax_for_fsdp(model) + precompute_float8_scale_for_fsdp(model) """ from torch.distributed._tensor import DTensor @@ -38,20 +38,23 @@ def precompute_float8_amax_for_fsdp(module: nn.Module) -> None: ] weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] - def compute_amaxes(weights: List[DTensor]): + def compute_scales(weights: List[DTensor]): # inf-norm is equivalent to max(abs(w)) max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial amax_tensor = torch.vstack(max_weights) # Partial # clamp is dispatched through DTensor # it will issue a single all-reduce amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate - amaxes = torch.split(amax_tensor, 1) # Replicate - return amaxes + scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate + if amax_tensor.dtype is torch.float16: + scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) + scales = torch.split(scale_tensor, 1) # Replicate + return scales if weights: - amaxes = compute_amaxes(weights) - for amax, float8_linear in zip(amaxes, float8_linears): - float8_linear.weight._local_tensor._precomputed_amax = amax._local_tensor + scales = compute_scales(weights) + for scale, float8_linear in zip(scales, float8_linears): + float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor else: warnings.warn( "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!" 
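Patch 23 shifts the cached quantity from the amax to the scale itself. The core of `compute_scales` reads the same on plain tensors; the sketch below uses no FSDP and no DTensor, so no collective is issued (in the real code, the clamp on the stacked DTensor amax is what triggers the single all-reduce), and `EPS` plus the shard shapes are placeholders for illustration.

```python
import math

import torch

EPS = 1e-12  # stand-in for float8_experimental.float8_utils.EPS
weights = [torch.randn(256, 256), torch.randn(512, 256)]  # toy local weight shards

max_weights = torch._foreach_norm(weights, ord=math.inf)  # per-weight max(abs(w))
amax_tensor = torch.vstack(max_weights)                   # batch all amaxes into one tensor
amax_tensor = torch.clamp(amax_tensor, EPS)               # with DTensor inputs: one all-reduce
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor
if amax_tensor.dtype is torch.float16:
    # keep the scale representable when the amax itself is fp16
    scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
scales = torch.split(scale_tensor, 1)                     # one (1, 1) scale per weight
```

Each resulting scale is then stashed as `_precomputed_scale` on the corresponding `WeightWithDynamicFloat8CastTensor`, so `fsdp_pre_all_gather` can cast the shard without recomputing (or re-reducing) the amax.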
From fa2f08a3f49aa31c1d1e12b939e9bc40293e73d9 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 10 Jul 2024 16:11:24 -0700 Subject: [PATCH 24/27] unit test for precomputing scales Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/test_fsdp2/test_fsdp2_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index 55d7681..5c7d21a 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -6,7 +6,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from float8_experimental.fsdp_utils import precompute_float8_amax_for_fsdp +from float8_experimental.fsdp_utils import precompute_float8_scale_for_fsdp def check_parity_no_mp( @@ -31,7 +31,7 @@ def check_parity_no_mp( # TODO(future): add amax syncing once delayed scaling is supported optim.step() if model is fsdp_model and precompute: - precompute_float8_amax_for_fsdp(model) + precompute_float8_scale_for_fsdp(model) test_cls.assertEqual(losses[0], losses[1]) From ba085e53dc6025739ede8d3aad5166134a0e5a79 Mon Sep 17 00:00:00 2001 From: willfengg Date: Wed, 10 Jul 2024 16:36:53 -0700 Subject: [PATCH 25/27] add precompute scale in README Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- README.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ff19b93..1ebab81 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,18 @@ model = FSDP(model, use_orig_params=True) # optional: enable torch.compile for improved performance m = torch.compile(m) -# train/finetune (not shown) +# toy training loop +for _ in range(N_ITER): + optimizer.zero_grad() + y = m(x) + y.sum().backward() + optimizer.step() + + # specific to fsdp2 + float8 with dynamic scaling + # this method is optional but is highly recommended for performance + # it calcuclates scales for all parameters in a single all-reduce + precompute_float8_scale_for_fsdp(model) + ``` ## float8 linear with delayed scaling From ac0afb0507828e3ee1e0e0dc8aec07f4412d87c1 Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 11 Jul 2024 14:24:03 -0700 Subject: [PATCH 26/27] rename to precompute_float8_dynamic_scale_for_fsdp Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- README.md | 5 ++- float8_experimental/fsdp_utils.py | 55 ++++++++++++---------------- test/test_fsdp2/test_fsdp2_common.py | 4 +- 3 files changed, 28 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 1ebab81..464e9b1 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ This is the most accurate recipe as every tensor is scaled dynamically. 
from float8_experimental.float8_linear_utils import ( swap_linear_with_float8_linear, ) +from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp from float8_experimental.float8_linear import Float8Linear # create model @@ -58,10 +59,10 @@ for _ in range(N_ITER): y.sum().backward() optimizer.step() - # specific to fsdp2 + float8 with dynamic scaling + # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on # this method is optional but is highly recommended for performance # it calcuclates scales for all parameters in a single all-reduce - precompute_float8_scale_for_fsdp(model) + precompute_float8_dynamic_scale_for_fsdp(model) ``` diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py index e06ec66..0ade173 100644 --- a/float8_experimental/fsdp_utils.py +++ b/float8_experimental/fsdp_utils.py @@ -1,31 +1,28 @@ import math -import warnings from typing import List import torch import torch.nn as nn from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor -from float8_experimental.float8_linear import Float8Linear -from float8_experimental.float8_linear_utils import linear_requires_sync +from float8_experimental.float8_linear import Float8Linear, TensorScalingType from float8_experimental.float8_utils import EPS -def precompute_float8_scale_for_fsdp(module: nn.Module) -> None: +@torch.no_grad() +def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None: """ - Calculate scale for all float8 parameters after optimizer step - It performs a single all-reduce instead of many all-reduces for each parameter - Exmaple usage: + Calculate scale dynamically for all float8 parameters. + This should be run after the optimizer step. It performs a single all-reduce to compute the + scales for all float8 weights. + Example usage: model(input).sum().backward() optim.step() - precompute_float8_scale_for_fsdp(model) + precompute_float8_dynamic_scale_for_fsdp(model) """ from torch.distributed._tensor import DTensor if any( - isinstance(m, Float8Linear) - and linear_requires_sync( - m.scaling_type_x, m.scaling_type_w, m.scaling_type_dL_dY - ) + isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED for m in module.modules() ): raise NotImplementedError("Only supports delayed scaling") @@ -38,24 +35,18 @@ def precompute_float8_scale_for_fsdp(module: nn.Module) -> None: ] weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears] - def compute_scales(weights: List[DTensor]): - # inf-norm is equivalent to max(abs(w)) - max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial - amax_tensor = torch.vstack(max_weights) # Partial - # clamp is dispatched through DTensor - # it will issue a single all-reduce - amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate - scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate - if amax_tensor.dtype is torch.float16: - scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) - scales = torch.split(scale_tensor, 1) # Replicate - return scales + if not weights: + return - if weights: - scales = compute_scales(weights) - for scale, float8_linear in zip(scales, float8_linears): - float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor - else: - warnings.warn( - "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!" 
- ) + # inf-norm is equivalent to max(abs(w)) + max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial + amax_tensor = torch.vstack(max_weights) # Partial + # clamp is dispatched through DTensor + # it will issue a single all-reduce + amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate + scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate + if amax_tensor.dtype is torch.float16: + scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max) + scales = torch.split(scale_tensor, 1) # Replicate + for scale, float8_linear in zip(scales, float8_linears): + float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor diff --git a/test/test_fsdp2/test_fsdp2_common.py b/test/test_fsdp2/test_fsdp2_common.py index 5c7d21a..af57871 100644 --- a/test/test_fsdp2/test_fsdp2_common.py +++ b/test/test_fsdp2/test_fsdp2_common.py @@ -6,7 +6,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from float8_experimental.fsdp_utils import precompute_float8_scale_for_fsdp +from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp def check_parity_no_mp( @@ -31,7 +31,7 @@ def check_parity_no_mp( # TODO(future): add amax syncing once delayed scaling is supported optim.step() if model is fsdp_model and precompute: - precompute_float8_scale_for_fsdp(model) + precompute_float8_dynamic_scale_for_fsdp(model) test_cls.assertEqual(losses[0], losses[1]) From 8e56dfcc48c558d3e34cc531c94b2d195007da58 Mon Sep 17 00:00:00 2001 From: willfengg Date: Thu, 11 Jul 2024 14:25:50 -0700 Subject: [PATCH 27/27] rename to precompute_float8_dynamic_scale_for_fsdp Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- float8_experimental/float8_dynamic_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py index 3cdbaf9..b355098 100644 --- a/float8_experimental/float8_dynamic_utils.py +++ b/float8_experimental/float8_dynamic_utils.py @@ -110,7 +110,7 @@ def __init__( self._tensor = tensor self._mm_config = mm_config # for dynamic scaling - # `precompute_float8_scale_for_fsdp` calculates scales + # `precompute_float8_dynamic_scale_for_fsdp` calculates scales # for all float8 parameters after optimizer step self._precomputed_scale = precomputed_scale
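One property the final API relies on: a cached `_precomputed_scale` is only correct for the weight value it was computed from, which is why the README example calls `precompute_float8_dynamic_scale_for_fsdp(model)` inside the training loop, immediately after `optimizer.step()`. Below is a plain-tensor sketch of the staleness that placement avoids; it is illustrative only, not library code, with `E4M3_MAX` and `EPS` as stand-ins and `saturated_cast` an assumed helper mirroring `to_fp8_saturated`.

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
EPS = 1e-12  # stand-in for float8_experimental.float8_utils.EPS


def saturated_cast(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # scale, clamp into the e4m3 range, then cast (mirrors to_fp8_saturated)
    return torch.clamp(w * scale, -E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)


w = torch.randn(64, 64)
cached_scale = E4M3_MAX / torch.clamp(w.abs().max(), min=EPS)  # precomputed after step N

w = w * 4.0  # a later optimizer step changes the weight; the cache is not refreshed

fresh_scale = E4M3_MAX / torch.clamp(w.abs().max(), min=EPS)
stale = saturated_cast(w, cached_scale)  # over-scaled: many values saturate at +/-E4M3_MAX
fresh = saturated_cast(w, fresh_scale)
match = (stale.to(torch.float32) == fresh.to(torch.float32)).float().mean()
print(f"fraction of fp8 values that match the fresh cast: {match.item():.2f}")
```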