From 6fc14a3c1a1400b2efc8a1d0da34c1a97a720d8a Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Thu, 25 Jul 2024 07:13:37 -0700
Subject: [PATCH] Reduced CPU overhead in `precompute_float8_dynamic_scale_for_fsdp`

[ghstack-poisoned]
---
 float8_experimental/fsdp_utils.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index d9fd200..eef9ec1 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -57,18 +57,16 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
 
     # inf-norm is equivalent to max(abs(w))
     max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
-    amax_tensor = torch.vstack(max_weights)  # Partial
+    amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
     if amax_tensor.dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    scales = torch.split(scale_tensor, 1)  # Replicate
-    for scale, float8_linear in zip(scales, float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = (
-            scale._local_tensor.squeeze()
-        )
+    local_scale_tensor = scale_tensor.to_local()
+    for i, float8_linear in enumerate(float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
 
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
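
For reference, a minimal sketch of what the change does, not part of the patch: it uses plain CPU tensors instead of DTensors, no FSDP setup, and an illustrative EPS value, and it shows that indexing the stacked scale tensor once per parameter yields the same per-parameter scales as the old split-and-squeeze path while issuing far fewer small tensor ops.

import math

import torch

# Stand-ins for the per-parameter fp32 weights (DTensors in the real code).
weights = [torch.randn(16, 16) for _ in range(4)]
EPS = 1e-12  # illustrative; the real module defines its own EPS

# inf-norm per weight, stacked into a single amax tensor (as in the patched code).
max_weights = torch._foreach_norm(weights, ord=math.inf)
amax_tensor = torch.clamp(torch.stack(max_weights), EPS)
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor

# Old path: split into one length-1 tensor per parameter, then squeeze each.
scales_old = [s.squeeze() for s in torch.split(scale_tensor, 1)]

# New path: index the tensor once per parameter.
scales_new = [scale_tensor[i] for i in range(len(weights))]

for old, new in zip(scales_old, scales_new):
    torch.testing.assert_close(old, new)

In the patched function the indexing happens on scale_tensor.to_local(), so each per-parameter assignment touches a plain local tensor rather than going through DTensor dispatch for every split-off scale.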