This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Reduced CPU overhead in precompute_float8_dynamic_scale_for_fsdp #331

Closed
wants to merge 1 commit

Conversation

awgu (Contributor) commented Jul 25, 2024:

Stack from ghstack (oldest at bottom):

Description
For Llama3-8B on 8xH100, profiling with with_stack=True (which does add overhead), the precompute_float8_dynamic_scale_for_fsdp CPU time decreases from 24 ms to 15 ms.

Before:
[Screenshot 2024-07-25 at 10:16:38 AM: profiler trace before]

After:
[Screenshot 2024-07-25 at 10:17:00 AM: profiler trace after]
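
For reference, here is a minimal sketch of how such a CPU trace can be captured with torch.profiler; the toy train_step below is a hypothetical stand-in, not the actual Llama3-8B setup used for the numbers above.

import torch
from torch.profiler import ProfilerActivity, profile

def train_step() -> None:
    # Hypothetical stand-in workload; replace with a real training step.
    x = torch.randn(1024, 1024)
    _ = (x @ x).sum()

# with_stack=True records Python call stacks per op (this adds CPU overhead,
# as noted above); the resulting trace is where a function like
# precompute_float8_dynamic_scale_for_fsdp shows up with attributed CPU time.
with profile(activities=[ProfilerActivity.CPU], with_stack=True) as prof:
    train_step()
prof.export_chrome_trace("trace.json")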

Test Plan

(pytorch-3.10) [[email protected] /data/users/andgu/float8_experimental (precompute_float8)]$ pytest test/test_fsdp2/test_fsdp2.py 
========================================================= test session starts =========================================================
platform linux -- Python 3.10.13, pytest-7.3.2, pluggy-1.3.0
rootdir: /data/users/andgu/float8_experimental
plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, shard-0.1.2, rerunfailures-13.0, flakefinder-1.1.0, cpp-2.3.0
collected 8 items                                                                                                                     
Running 8 items in this shard

test/test_fsdp2/test_fsdp2.py ........                                                                                          [100%]

========================================================== warnings summary ===========================================================
test/test_fsdp2/test_fsdp2.py::TestFloat8MultiThread::test_fp32_fp8_multi_module_parity
test/test_fsdp2/test_fsdp2.py::TestFloat8MultiThread::test_fp32_fp8_single_module_parity
  /data/users/andgu/float8_experimental/float8_experimental/float8_linear_utils.py:272: FutureWarning: The combination of ranks + tag as process group identifier has been deprecated. Please switch to using ProcessGroup, DeviceMesh, or group name instead.
    all_reduced_amax_tensor = all_reduce(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================== 8 passed, 2 warnings in 121.90s (0:02:01) ==============================================

Differential Revision: D60236258

awgu added a commit that referenced this pull request Jul 25, 2024
ghstack-source-id: 9a22b865feaf67bb83910a99503975d09170fa06
Pull Request resolved: #331
@facebook-github-bot added the CLA Signed label on Jul 25, 2024. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
float8_linear.weight._local_tensor._precomputed_scale = (
    scale._local_tensor.squeeze()
)
local_scale_tensor = scale_tensor.to_local()
awgu (Contributor Author) commented:
We have to call to_local() here because DTensor does not support [i] int indexing. Int indexing might not be semantically clear if the DTensor is sharded; I think indexing should be on the local tensor.
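
As a side note, a minimal single-rank sketch of the to_local() pattern (this assumes the torch.distributed._tensor API available at the time; the shapes and names below are illustrative, not the actual FSDP2 scale layout):

import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# One scale per weight, held as a DTensor (illustrative only).
scale_tensor = distribute_tensor(torch.rand(4), mesh, [Shard(0)])

# Index the local shard instead of the DTensor itself, since int indexing
# on a (possibly sharded) DTensor is unsupported / semantically unclear.
local_scale_tensor = scale_tensor.to_local()
scale_for_weight_0 = local_scale_tensor[0]

dist.destroy_process_group()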

@awgu marked this pull request as ready for review on July 25, 2024, 14:20.
awgu (Contributor Author) commented Jul 25, 2024:
@awgu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -57,18 +57,16 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:

     # inf-norm is equivalent to max(abs(w))
     max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
-    amax_tensor = torch.vstack(max_weights)  # Partial
+    amax_tensor = torch.stack(max_weights)  # Partial
A reviewer (Contributor) commented:
This is a TIL for me; I like vstack because semantically I think about gluing lego blocks together lol. Does it assert some contiguousness that causes it to be less performant?

awgu (Contributor Author) replied:
From offline: vstack incurs a per-tensor reshape, each of which redispatches through DTensor dispatch, which is what makes it slow. stack goes through DTensor dispatch only once.
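
To illustrate on plain tensors (a sketch; the DTensor dispatch cost itself is not reproduced here, but the extra per-tensor promotion that vstack performs is):

import math
import torch

weights = [torch.randn(16, 16) for _ in range(4)]

# Per-weight amax via the inf-norm; each result is a 0-D tensor.
max_weights = torch._foreach_norm(weights, ord=math.inf)

# vstack promotes every 0-D amax to 2-D first (one reshape per tensor);
# under DTensor, each of those reshapes is a separate dispatch.
amax_vstack = torch.vstack(max_weights)  # shape (4, 1)

# stack combines the list along a new dim in a single op -> one dispatch.
amax_stack = torch.stack(max_weights)    # shape (4,)

print(amax_vstack.shape, amax_stack.shape)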

facebook-github-bot (Contributor) commented:
@awgu merged this pull request in 701647b.

Labels: CLA Signed, Merged
4 participants