This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Feature/transformer sequence sharding #90

Open
japols wants to merge 7 commits into develop from feature/transformer_sequence_sharding

Conversation

@japols (Member) commented Nov 28, 2024

This PR adds a new sharding strategy shard_sequence for the transformer processor.

The current implementation (shard_heads) alternates between sharding across the sequence and sharding across the heads for the sliding window attention mechanism. This requires two all-to-all communication steps per layer.

The shard_sequence strategy simplifies this process by keeping a sequence shard on each GPU and computing the sliding window attention locally. This requires a halo communication to exchange overlapping window segments (halos) between neighbouring sequence shards.

Instead of two all-to-all communication steps per layer, the halo exchange requires only a single point-to-point communication between neighbouring GPUs, reducing communication time and improving the scalability of model sharding across multiple GPUs.
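
For illustration, here is a minimal sketch of such a halo exchange in PyTorch, assuming one sequence shard of shape `(seq_local, heads, channels)` per rank and a symmetric halo of `halo` tokens. The function name, the zero-padded edges, and the forward-only communication are simplifications for the sketch, not the PR's actual implementation:

```python
# Minimal sketch of a halo exchange for sequence-sharded sliding-window
# attention (illustrative only). Each rank holds one sequence shard x of
# shape (seq_local, heads, channels) and pads it with `halo` tokens from
# its left and right neighbours via point-to-point communication.
import torch
import torch.distributed as dist


def halo_exchange(x: torch.Tensor, halo: int, group=None) -> torch.Tensor:
    rank = dist.get_rank(group)
    world = dist.get_world_size(group)

    # Buffers for the halos received from the neighbours.
    left_halo = torch.zeros_like(x[:halo])
    right_halo = torch.zeros_like(x[:halo])

    ops = []
    if rank > 0:  # exchange with the left neighbour
        ops.append(dist.P2POp(dist.isend, x[:halo].contiguous(), rank - 1, group))
        ops.append(dist.P2POp(dist.irecv, left_halo, rank - 1, group))
    if rank < world - 1:  # exchange with the right neighbour
        ops.append(dist.P2POp(dist.isend, x[-halo:].contiguous(), rank + 1, group))
        ops.append(dist.P2POp(dist.irecv, right_halo, rank + 1, group))

    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()

    # Local shard extended by its halos; sliding-window attention can now be
    # computed entirely on this rank. Edge ranks keep zero padding here for
    # simplicity; in practice the window is simply truncated at the ends.
    return torch.cat([left_halo, x, right_halo], dim=0)
```

A real implementation also needs the matching exchange in the backward pass (the timings below cover forward+backward), which this sketch omits.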

The following benchmark results show that a 2-neighbour all-to-all (orange) is the best communication strategy for implementing the halo exchange, consistently outperforming the old head-sharding strategy (blue); a sketch of this variant is given after the figure:

[Figure: sharding_strategies benchmark plot]

The benchmark is an isolated forward+backward pass of 16 transformer layers with o96 input shapes and 1024 channels.
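
The 2-neighbour all-to-all can be expressed as a single all_to_all_single call whose split sizes are zero for every rank except the two neighbours. The sketch below is illustrative only (hypothetical names, forward pass only), not the PR's code:

```python
# Illustrative sketch: the halo exchange expressed as a "2-neighbour
# all-to-all", i.e. one all_to_all_single call in which only the left and
# right neighbours exchange non-empty chunks along the sequence dimension.
import torch
import torch.distributed as dist


def halo_exchange_all_to_all(x: torch.Tensor, halo: int, group=None) -> torch.Tensor:
    rank = dist.get_rank(group)
    world = dist.get_world_size(group)

    in_splits = [0] * world   # tokens sent to each rank (along dim 0)
    out_splits = [0] * world  # tokens received from each rank
    send_chunks = []
    if rank > 0:
        in_splits[rank - 1] = out_splits[rank - 1] = halo
        send_chunks.append(x[:halo])    # our first tokens go to the left neighbour
    if rank < world - 1:
        in_splits[rank + 1] = out_splits[rank + 1] = halo
        send_chunks.append(x[-halo:])   # our last tokens go to the right neighbour

    empty = x.new_empty((0, *x.shape[1:]))
    send = torch.cat(send_chunks, dim=0).contiguous() if send_chunks else empty
    recv = x.new_empty((sum(out_splits), *x.shape[1:]))

    dist.all_to_all_single(recv, send, out_splits, in_splits, group=group)

    # Received chunks are ordered by source rank: left neighbour first.
    n_left = halo if rank > 0 else 0
    left_halo, right_halo = recv[:n_left], recv[n_left:]
    return torch.cat([left_halo, x, right_halo], dim=0)
```

The split-size bookkeeping is what keeps this a point-to-point pattern in practice: every non-neighbour contributes an empty chunk, so the call only moves the two halos.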

For a full training run (n320 input, o96 hidden) we get the following increase in throughput, in line with the benchmark results:

| GPUs/model | Sharding strategy | Avg. time/batch (s) |
|---|---|---|
| 2 | shard_heads | 1.38495 |
| 2 | shard_sequence | 1.29771 |
| 4 | shard_heads | 0.72034 |
| 4 | shard_sequence | 0.69254 |

[Figure: mlflow run comparison]

@FussyDuck commented Nov 28, 2024

CLA assistant check
All committers have signed the CLA.

@japols japols self-assigned this Nov 28, 2024
@codecov-commenter commented Nov 28, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.85%. Comparing base (225315e) to head (063844b).

Additional details and impacted files
@@           Coverage Diff            @@
##           develop      #90   +/-   ##
========================================
  Coverage    99.85%   99.85%           
========================================
  Files           23       23           
  Lines         1374     1374           
========================================
  Hits          1372     1372           
  Misses           2        2           

☔ View full report in Codecov by Sentry.

@japols japols force-pushed the feature/transformer_sequence_sharding branch from 4ab4205 to a847f1a on November 29, 2024 10:55
@japols japols requested review from ssmmnn11 and mishooax December 17, 2024 18:44
@japols japols marked this pull request as ready for review December 17, 2024 18:46