This PR adds a new sharding strategy, `shard_sequence`, for the transformer processor. The current implementation (`shard_heads`) alternates between sharding across the sequence and sharding across the heads for the sliding-window attention mechanism, which requires two all-to-all communication steps per layer. The `shard_sequence` strategy simplifies this by keeping a sequence shard on each GPU and computing the sliding-window attention locally. This requires a halo exchange to transfer the overlapping window segments (halos) between neighbouring sequence shards.
Instead of two all-to-all communication steps per layer, the halo exchange requires only a single point-to-point communication between neighbouring GPUs, reducing communication time and improving the scalability of model sharding across multiple GPUs.
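The sketch below illustrates the idea with batched point-to-point sends/receives in `torch.distributed`. It is not the code in this PR: the function name `halo_exchange`, the `(sequence, channels)` shard layout, and the flat 1-D rank ordering of sequence shards are illustrative assumptions.

```python
import torch
import torch.distributed as dist


def halo_exchange(x: torch.Tensor, halo: int, group=None) -> torch.Tensor:
    """Pad a local sequence shard with halos from its neighbouring ranks.

    x:    local shard of shape (seq_shard, channels); ranks are assumed to
          hold consecutive, non-overlapping chunks of the full sequence.
    halo: number of overlapping tokens needed by the sliding window.
    Returns the shard concatenated with the received left/right halos.
    """
    rank = dist.get_rank(group)
    world = dist.get_world_size(group)
    left, right = rank - 1, rank + 1  # neighbouring shards, no wrap-around

    recv_left = x.new_empty(halo, x.size(1))
    recv_right = x.new_empty(halo, x.size(1))

    ops = []
    if left >= 0:
        # send my first `halo` tokens left, receive the left rank's last tokens
        ops.append(dist.P2POp(dist.isend, x[:halo].contiguous(), left, group))
        ops.append(dist.P2POp(dist.irecv, recv_left, left, group))
    if right < world:
        # send my last `halo` tokens right, receive the right rank's first tokens
        ops.append(dist.P2POp(dist.isend, x[-halo:].contiguous(), right, group))
        ops.append(dist.P2POp(dist.irecv, recv_right, right, group))

    if ops:  # world size 1: nothing to exchange
        for req in dist.batch_isend_irecv(ops):
            req.wait()

    parts = []
    if left >= 0:
        parts.append(recv_left)
    parts.append(x)
    if right < world:
        parts.append(recv_right)
    return torch.cat(parts, dim=0)
```

A full implementation would also handle the backward pass (reducing halo gradients back to the ranks that own those tokens), which is omitted here.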
The following benchmarking results show that a 2-neighbour all-to-all (orange) is the best communication strategy for implementing the halo exchange; it consistently outperforms the old head-sharding strategy (blue):
This is an isolated fwd+bwd pass of 16 transformer layers with o96 input shapes and 1024 channels.
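For reference, a hedged sketch of how the halo exchange could be expressed as a single `all_to_all` in which only the two neighbouring ranks exchange non-empty buffers (the "2-neighbour all-to-all" variant benchmarked above). The assumptions are the same as in the previous sketch and the names are illustrative, not the repository's code.

```python
import torch
import torch.distributed as dist


def halo_exchange_all_to_all(x: torch.Tensor, halo: int, group=None) -> torch.Tensor:
    """Halo exchange via one all_to_all; non-neighbours send/receive 0-length buffers."""
    rank = dist.get_rank(group)
    world = dist.get_world_size(group)
    channels = x.size(1)

    empty = x.new_empty(0, channels)
    send = [empty] * world
    recv = [x.new_empty(0, channels) for _ in range(world)]

    if rank > 0:  # left neighbour receives my first `halo` tokens
        send[rank - 1] = x[:halo].contiguous()
        recv[rank - 1] = x.new_empty(halo, channels)
    if rank < world - 1:  # right neighbour receives my last `halo` tokens
        send[rank + 1] = x[-halo:].contiguous()
        recv[rank + 1] = x.new_empty(halo, channels)

    dist.all_to_all(recv, send, group=group)

    parts = []
    if rank > 0:
        parts.append(recv[rank - 1])
    parts.append(x)
    if rank < world - 1:
        parts.append(recv[rank + 1])
    return torch.cat(parts, dim=0)
```

One possible reason this variant performs well is that both neighbour transfers are issued in a single collective call, though the actual scheduling depends on the NCCL version and network topology.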
For a full training run on n320 input with o96 hidden, we get the following increases in throughput (aligning with the benchmark results):
mlflow