This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Reduced CPU overhead in precompute_float8_dynamic_scale_for_fsdp #331

Closed
wants to merge 1 commit
10 changes: 4 additions & 6 deletions float8_experimental/fsdp_utils.py
@@ -57,18 +57,16 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:

     # inf-norm is equivalent to max(abs(w))
     max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial
-    amax_tensor = torch.vstack(max_weights) # Partial
+    amax_tensor = torch.stack(max_weights) # Partial
Contributor

This is a TLDR for me: I like vstack because semantically I think of it as gluing Lego blocks together, lol. Does it assert some contiguousness that causes it to be less performant?

Contributor Author

From offline: vstack incurs a per-tensor reshape, each of which redispatches through DTensor dispatch, which is what makes it slow. stack goes through DTensor dispatch only once.
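
For intuition, here is a minimal sketch with plain tensors (no DTensor involved). torch.vstack promotes each input to 2-D via atleast_2d() before concatenating, so a list of N scalar amaxes turns into an (N, 1) tensor through N extra reshape ops, whereas torch.stack handles the whole list in one op. The variable names below are illustrative stand-ins, not the actual variables in fsdp_utils.py.

import torch

# Plain-tensor sketch: vstack reshapes each 0-dim input to (1, 1) before
# concatenating, while stack consumes the whole list in a single op.
amaxes = [torch.rand(()) for _ in range(8)]  # stand-ins for per-weight amax values

via_vstack = torch.vstack(amaxes)  # shape (8, 1): per-tensor atleast_2d, then cat
via_stack = torch.stack(amaxes)    # shape (8,): one stack over the whole list

assert torch.equal(via_vstack.flatten(), via_stack)

The values are identical; only the number of ops (and, with DTensor inputs, the number of dispatches) differs.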

     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate
     if amax_tensor.dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    scales = torch.split(scale_tensor, 1) # Replicate
-    for scale, float8_linear in zip(scales, float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = (
-            scale._local_tensor.squeeze()
-        )
+    local_scale_tensor = scale_tensor.to_local()
Contributor Author

We have to call to_local() here because DTensor does not support integer indexing with [i]. Integer indexing might not be semantically clear if the DTensor is sharded; I think the indexing should be done on the local tensor.

+    for i, float8_linear in enumerate(float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]


 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
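
To make the indexing change above concrete, here is a minimal sketch with plain tensors (no DTensor or FSDP): the old path split the replicated scale tensor into N size-1 pieces and squeezed each back to a scalar, while the new path indexes a single local tensor directly. The scale_tensor below is an illustrative stand-in for scale_tensor.to_local().

import torch

# Plain-tensor sketch of the old split-and-squeeze pattern vs. the new indexing pattern.
scale_tensor = torch.rand(4)  # stand-in for scale_tensor.to_local()

# old pattern: split into N size-1 chunks, then squeeze each back to a scalar
old_scales = [s.squeeze() for s in torch.split(scale_tensor, 1)]

# new pattern: index the single local tensor once per weight
new_scales = [scale_tensor[i] for i in range(scale_tensor.numel())]

assert all(torch.equal(a, b) for a, b in zip(old_scales, new_scales))

Both yield the same per-weight scalar scales; the new path just avoids the per-piece split and squeeze (and, with DTensor, the extra dispatches they imply).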