diff --git a/README.md b/README.md index ff19b93..464e9b1 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ This is the most accurate recipe as every tensor is scaled dynamically. from float8_experimental.float8_linear_utils import ( swap_linear_with_float8_linear, ) +from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp from float8_experimental.float8_linear import Float8Linear # create model @@ -51,7 +52,18 @@ model = FSDP(model, use_orig_params=True) # optional: enable torch.compile for improved performance m = torch.compile(m) -# train/finetune (not shown) +# toy training loop +for _ in range(N_ITER): + optimizer.zero_grad() + y = m(x) + y.sum().backward() + optimizer.step() + + # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on + # this method is optional but is highly recommended for performance + # it calculates scales for all parameters in a single all-reduce + precompute_float8_dynamic_scale_for_fsdp(model) + ``` ## float8 linear with delayed scaling diff --git a/float8_experimental/float8_dynamic_utils.py b/float8_experimental/float8_dynamic_utils.py index 7f44363..b355098 100644 --- a/float8_experimental/float8_dynamic_utils.py +++ b/float8_experimental/float8_dynamic_utils.py @@ -82,7 +82,12 @@ def cast_to_float8_e5m2_dynamic_bw( class WeightWithDynamicFloat8CastTensor(torch.Tensor): @staticmethod - def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig): + def __new__( + cls, + tensor: torch.Tensor, + mm_config: ScaledMMConfig, + precomputed_scale: Optional[torch.Tensor] = None, + ): return torch.Tensor._make_wrapper_subclass( cls, tensor.size(), @@ -96,9 +101,18 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig): requires_grad=tensor.requires_grad, ) - def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig): + def __init__( + self, + tensor: torch.Tensor, + mm_config: ScaledMMConfig, + precomputed_scale: Optional[torch.Tensor] = None, + ): self._tensor = 
# Methods of WeightWithDynamicFloat8CastTensor that carry the optional
# precomputed scale through flatten/unflatten and into the FSDP pre-all-gather.


def __tensor_flatten__(self):
    """
    Describe this wrapper as (inner tensor attribute names, flatten spec).

    Returns:
        A pair whose first item lists the attribute names that hold inner
        tensors and whose second item is ``self._mm_config``. The
        ``"_precomputed_scale"`` name is listed only when a scale is set.
    """
    # Compare against None instead of relying on tensor truthiness: truthiness
    # raises for a multi-element tensor and is False for a zero-valued scale,
    # either of which would drop (or fail to report) an existing scale.
    if self._precomputed_scale is not None:
        return ["_tensor", "_precomputed_scale"], self._mm_config
    return ["_tensor"], self._mm_config


@staticmethod
def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
    """
    Rebuild the wrapper from the pieces produced by ``__tensor_flatten__``.

    Args:
        inner_tensors: mapping from attribute name to inner tensor; it is
            indexed with ``"_tensor"`` and optionally holds
            ``"_precomputed_scale"``.
        flatten_spec: the mm config returned as the flatten spec.
        outer_size: not used here.
        outer_stride: not used here.
    """
    mm_config = flatten_spec
    return WeightWithDynamicFloat8CastTensor(
        inner_tensors["_tensor"],
        mm_config,
        # `inner_tensors` is a mapping, so the key must be looked up with
        # `.get`; `getattr` on it always fell back to None and silently
        # discarded the precomputed scale.
        inner_tensors.get("_precomputed_scale"),
    )


def __repr__(self):
    """Return a debug string showing the wrapped tensor and the mm config."""
    return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"


def fsdp_pre_all_gather(self, mesh):
    """
    Cast the wrapped tensor to float8 and return its data and scale.

    When ``self._precomputed_scale`` is set it is used directly for the cast;
    otherwise the cast goes through ``cast_to_float8_e4m3_dynamic`` with
    ``reduce_amax=True``.

    Args:
        mesh: not used by this implementation.

    Returns:
        ``((float8_tensor._data,), (float8_tensor._scale,))``.
    """
    if self._precomputed_scale is not None:
        float8_tensor = Float8Tensor.to_float8(
            self._tensor,
            self._precomputed_scale,
            torch.float8_e4m3fn,
            mm_config=self._mm_config,
        )
    else:
        float8_tensor = cast_to_float8_e4m3_dynamic(
            self._tensor, self._mm_config, reduce_amax=True
        )
    return (float8_tensor._data,), (float8_tensor._scale,)
@torch.no_grad()
def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
    """
    Calculate scale dynamically for all float8 parameters.

    This should be run after the optimizer step. It performs a single
    all-reduce to compute the scales for all float8 weights, and stores each
    scale on the weight's local tensor as ``_precomputed_scale``.

    Example usage:
        model(input).sum().backward()
        optim.step()
        precompute_float8_dynamic_scale_for_fsdp(model)

    Args:
        module: root module whose ``Float8Linear`` submodules are scanned.

    Raises:
        NotImplementedError: if any ``Float8Linear`` submodule uses delayed
            scaling for its weight.
    """
    from torch.distributed._tensor import DTensor

    if any(
        isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED
        for m in module.modules()
    ):
        # The guard rejects delayed scaling, so the message names what IS
        # supported (the previous text said "delayed", inverting the meaning).
        raise NotImplementedError("Only supports dynamic scaling")

    # Keep only float8 linears whose weight is a DTensor wrapping the dynamic
    # float8 cast tensor; other modules are left untouched.
    float8_linears: List[Float8Linear] = [
        m
        for m in module.modules()
        if isinstance(m, Float8Linear)
        and isinstance(m.weight, DTensor)
        and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
    ]
    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

    if not weights:
        return

    # inf-norm is equivalent to max(abs(w))
    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
    amax_tensor = torch.vstack(max_weights)  # Partial
    # clamp is dispatched through DTensor
    # it will issue a single all-reduce
    amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
    if amax_tensor.dtype is torch.float16:
        # Cap the scale at the largest finite float16 value.
        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
    scales = torch.split(scale_tensor, 1)  # Replicate
    for scale, float8_linear in zip(scales, float8_linears):
        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
precompute_float8_dynamic_scale_for_fsdp def check_parity_no_mp( @@ -15,6 +16,7 @@ def check_parity_no_mp( fsdp_model: nn.Module, fsdp_optim: torch.optim.Optimizer, local_inp: torch.Tensor, + precompute: bool = False, ): for iter_idx in range(10): losses: List[torch.Tensor] = [] @@ -28,6 +30,8 @@ def check_parity_no_mp( param.grad.div_(dist.get_world_size()) # TODO(future): add amax syncing once delayed scaling is supported optim.step() + if model is fsdp_model and precompute: + precompute_float8_dynamic_scale_for_fsdp(model) test_cls.assertEqual(losses[0], losses[1]) diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py index 5ca483f..bdbc878 100644 --- a/test/test_fsdp2/test_fsdp2_eager.py +++ b/test/test_fsdp2/test_fsdp2_eager.py @@ -86,10 +86,21 @@ def world_size(self) -> int: @skip_if_lt_x_gpu(2) def test_transformer_parity_dynamic(self): - for enable_fsdp_fp8_all_gather in [False, True]: - self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather) + self.run_subtests( + { + "enable_fsdp_fp8_all_gather": [False, True], + "precompute": [False, True], + }, + self._test_transformer_parity_dynamic, + ) - def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool): + def _test_transformer_parity_dynamic( + self, + enable_fsdp_fp8_all_gather: bool, + precompute: bool, + ): + if not enable_fsdp_fp8_all_gather and precompute: + return # NOTE: Weight-tying does not compose with fp8 all-gather because the # embedding weight and output linear weight are tied but only the # latter uses fp8 compute. 
With fp8 all-gather, FSDP would pre-cast to @@ -109,7 +120,9 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool): local_inp = torch.randint( 0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda" ) - check_parity_no_mp(self, ref_module, ref_optim, module, optim, local_inp) + check_parity_no_mp( + self, ref_module, ref_optim, module, optim, local_inp, precompute + ) @skip_if_lt_x_gpu(2) def test_transformer_memory(self):