Reduced CPU overhead in precompute_float8_dynamic_scale_for_fsdp
#331
Conversation
[ghstack-poisoned]
ghstack-source-id: 9a22b865feaf67bb83910a99503975d09170fa06 Pull Request resolved: #331
-    float8_linear.weight._local_tensor._precomputed_scale = (
-        scale._local_tensor.squeeze()
-    )
+    local_scale_tensor = scale_tensor.to_local()
We have to call `to_local()` here because DTensor does not support `[i]` int indexing. Int indexing might not be semantically clear if the DTensor is sharded; I think the indexing should be on the local tensor.
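For context, a minimal, self-contained sketch of the index-on-the-local-shard pattern (the mesh setup, tensor sizes, and names here are assumptions for illustration, not the PR's code):

```python
# Sketch: shard a 1-D tensor of per-parameter scales, then do plain integer
# indexing on each rank's local shard rather than on the DTensor itself.
# Run with e.g.: torchrun --nproc_per_node=2 local_indexing_sketch.py
import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

# Stand-in for the stacked per-parameter scales (same values on every rank).
torch.manual_seed(0)
scales = torch.rand(8, device="cuda")
scale_dtensor = distribute_tensor(scales, mesh, [Shard(0)])

# Index on this rank's local shard (a plain torch.Tensor), not on the DTensor.
local_scales = scale_dtensor.to_local()
per_param_scales = [local_scales[i] for i in range(local_scales.numel())]

dist.destroy_process_group()
```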
@awgu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@@ -57,18 +57,16 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     # inf-norm is equivalent to max(abs(w))
     max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
-    amax_tensor = torch.vstack(max_weights)  # Partial
+    amax_tensor = torch.stack(max_weights)  # Partial
This is a TL;DR for me: I like `vstack` because semantically I think about gluing Lego blocks together lol. Does it assert some contiguousness that causes it to be less performant?
From offline: `vstack` incurs a per-tensor reshape, each of which redispatches through DTensor dispatch, which is what makes it slow. `stack` only goes through DTensor dispatch once.
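For context, a small self-contained comparison of the two calls on 0-D per-tensor amaxes (plain tensors here for illustration; in the PR these are DTensors with Partial placement):

```python
import math

import torch

weights = [torch.randn(16, 16) for _ in range(4)]
max_weights = torch._foreach_norm(weights, ord=math.inf)  # list of 0-D tensors

vstacked = torch.vstack(max_weights)  # atleast_2d on each tensor, then cat -> shape (4, 1)
stacked = torch.stack(max_weights)    # a single stack op -> shape (4,)

# The (4, 1) result needs a squeeze to get back per-tensor scalars; the (4,)
# result can be integer-indexed directly on the local shard.
assert torch.equal(vstacked.squeeze(1), stacked)
```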
Stack from ghstack (oldest at bottom):
- #331 Reduced CPU overhead in precompute_float8_dynamic_scale_for_fsdp

Description
For Llama3-8B on 8xH100, profiling with `with_stack=True` (which does add overhead), the `precompute_float8_dynamic_scale_for_fsdp` CPU time decreases from 24 ms to 15 ms.

Before:
After:
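For reference, a hedged sketch of how such a CPU-time measurement could be taken with `torch.profiler`; the model setup and the function being in scope are assumptions, and only the `with_stack=True` setting comes from the description above:

```python
# Assumes `model` is the FSDP2-wrapped Llama3-8B with float8 linears and that
# precompute_float8_dynamic_scale_for_fsdp is importable in this scope.
from torch.profiler import ProfilerActivity, profile

with profile(activities=[ProfilerActivity.CPU], with_stack=True) as prof:
    precompute_float8_dynamic_scale_for_fsdp(model)

# Inspect CPU time attributed to the call (and its stack, since with_stack=True).
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```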
Test Plan
Differential Revision: D60236258