diff --git a/float8_experimental/fsdp_utils.py b/float8_experimental/fsdp_utils.py
index d9fd200..eef9ec1 100644
--- a/float8_experimental/fsdp_utils.py
+++ b/float8_experimental/fsdp_utils.py
@@ -57,18 +57,16 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
 
     # inf-norm is equivalent to max(abs(w))
     max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
-    amax_tensor = torch.vstack(max_weights)  # Partial
+    amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
     if amax_tensor.dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    scales = torch.split(scale_tensor, 1)  # Replicate
-    for scale, float8_linear in zip(scales, float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = (
-            scale._local_tensor.squeeze()
-        )
+    local_scale_tensor = scale_tensor.to_local()
+    for i, float8_linear in enumerate(float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
 
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
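
For context, a minimal usage sketch of the function this hunk modifies. The training loop, `model`, `optimizer`, and `dataloader` below are placeholders and not part of the patch; only `precompute_float8_dynamic_scale_for_fsdp` comes from `float8_experimental/fsdp_utils.py`. The helper is intended to be called after the optimizer step so that the single all-reduce issued by the stacked amax tensor above replaces per-parameter all-reduces:

```python
# Hypothetical usage sketch: model, optimizer, and dataloader are assumptions.
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # After the weights are updated, recompute all float8 weight scales with a
    # single all-reduce instead of one all-reduce per parameter.
    precompute_float8_dynamic_scale_for_fsdp(model)
```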