Ensure capacity does not exceed number of tokens (deepspeedai#5353)
When fine-tuning, we were running into an issue where the computed capacity
would trigger the following error after some amount of training. This was
caused by the sizes of the inputs to top1gating not being aligned between
ranks.

```
...
  File "/shared/users/jrasley/DeepSpeed/deepspeed/moe/sharded_moe.py", line 427, in forward
    gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
  File "/shared/users/jrasley/DeepSpeed/deepspeed/moe/sharded_moe.py", line 240, in top1gating
    top_idx = _top_idx(mask1_rand, capacity)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/shared/users/jrasley/DeepSpeed/deepspeed/moe/sharded_moe.py", line 172, in _top_idx
@torch.jit.script
def _top_idx(source, k):
    return torch.topk(source, k=k, dim=0)[1]
           ~~~~~~~~~~ <--- HERE
RuntimeError: selected index k out of range
```
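A minimal, standalone sketch of the failure mode (illustrative only, with made-up sizes): when `drop_tokens=False`, `new_capacity` is all-reduced with MAX across ranks, so a rank that received fewer tokens can end up with a capacity larger than its local token count, and `torch.topk` then fails exactly as in the traceback above.

```python
import torch

tokens_on_this_rank = 2        # e.g. a rank that received only 2 tokens
capacity_from_all_reduce = 10  # MAX capacity reported by a larger rank
mask1_rand = torch.rand(tokens_on_this_rank, 2)  # (tokens, experts)

try:
    # k exceeds the number of rows along dim=0 -> "selected index k out of range"
    torch.topk(mask1_rand, k=capacity_from_all_reduce, dim=0)
except RuntimeError as err:
    print(err)

# The fix clamps capacity to the local token count before topk is called:
capacity = min(capacity_from_all_reduce, mask1_rand.size(0))
top_idx = torch.topk(mask1_rand, k=capacity, dim=0)[1]  # succeeds
```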

Co-authored with: @rajhans
Reviewed/approved by: @samyam, @yaozhewei 

Tagging @tohtana and @ykim362 to help review
jeffra authored and dbyoung18 committed Jun 11, 2024
1 parent fbcc895 commit 2a6bfb5
Showing 2 changed files with 26 additions and 1 deletion.
deepspeed/moe/sharded_moe.py: 3 additions & 1 deletion

```diff
@@ -209,13 +209,15 @@ def top1gating(logits: Tensor,
     # if we don't want to drop any tokens
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
+        # Communicate across all processes to pick the maximum capacity.
         dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
         if groups._get_expert_model_parallel_world_size() == 1:
             # If the non-expert is tensor-parallel, we need to pad the capacity to 'tp'.
             # This is since we are going to activate drop_tokens() to drop duplicate tokens.
             tp = 1 if groups.mpu is None else groups.mpu.get_tensor_model_parallel_world_size()
             new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
-        capacity = new_capacity
+        # Make sure the capacity value does not exceed the number of tokens.
+        capacity = min(new_capacity, torch.tensor(mask1.size(0)))
 
     # Compute l_aux
     me = torch.mean(gates, dim=0)
```
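To make the arithmetic in the hunk above concrete, here is an illustration with assumed values (not part of the commit): the capacity is first padded up to a multiple of the tensor-parallel world size, and the new `min()` clamp then keeps it from exceeding the local number of tokens.

```python
import torch

new_capacity = torch.tensor(7)  # assumed max expert count after the all-reduce
tp = 4                          # hypothetical tensor-model-parallel world size

# Pad capacity up to a multiple of tp (as in the existing code above).
padded = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
print(padded)  # tensor(8)

# New clamp: capacity must not exceed the number of local tokens (mask1.size(0)).
num_tokens = torch.tensor(6)
capacity = min(padded, num_tokens)
print(capacity)  # tensor(6)
```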
tests/unit/moe/test_moe.py: 23 additions & 0 deletions

```diff
@@ -9,6 +9,9 @@
 import gc
 from unit.common import DistributedTest
 from unit.simple_model import SimplePRMoEModel, SimpleMoEModel, sequence_dataloader
+import deepspeed.comm as dist
+from deepspeed import get_accelerator
+from deepspeed.moe.sharded_moe import top1gating
 from unit.common import DistributedTest
 from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer, is_moe_param
 from deepspeed.runtime.utils import required_torch_version
 
@@ -132,3 +135,23 @@ def test(self, ep_size, use_residual):
         loss = model(batch[0], batch[1])
         model.backward(loss)
         model.step()
+
+
+class TestTopk(DistributedTest):
+    world_size = 2
+
+    def test(self):
+        device = get_accelerator().current_device()
+        if dist.get_rank() == 0:
+            logits = torch.rand(2, 2, device=device)
+        elif dist.get_rank() == 1:
+            logits = torch.rand(10, 2, device=device)
+
+        output = top1gating(logits=logits,
+                            capacity_factor=1,
+                            min_capacity=0,
+                            used_token=None,
+                            noisy_gate_policy=None,
+                            drop_tokens=False,
+                            use_rts=True,
+                            use_tutel=False)
```
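For readers without the DeepSpeed test harness at hand, a rough standalone sketch of the same two-rank scenario (assumptions only: CPU, the `gloo` backend, and token counts standing in for the per-expert counts) shows why the MAX all-reduce needs the clamp:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29510"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    num_tokens = 2 if rank == 0 else 10  # mirrors the shapes in the test above
    new_capacity = torch.tensor(num_tokens)
    dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX)  # both ranks now hold 10

    capacity = min(new_capacity, torch.tensor(num_tokens))
    print(f"rank {rank}: capacity={int(capacity)} out of {num_tokens} tokens")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```

Without the clamp, rank 0 would proceed with a capacity of 10 for only 2 tokens, which is the condition the fix guards against.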
