Ensure capacity does not exceed number of tokens (deepspeedai#5353)
When fine-tuning, we were running into an issue where, after some amount of training time, the capacity would trigger the error below. This happened when the sizes of the inputs to top1gating were not aligned across ranks.

```
...
  File "/shared/users/jrasley/DeepSpeed/deepspeed/moe/sharded_moe.py", line 427, in forward
    gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor,
  File "/shared/users/jrasley/DeepSpeed/deepspeed/moe/sharded_moe.py", line 240, in top1gating
    top_idx = _top_idx(mask1_rand, capacity)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/shared/users/jrasley/DeepSpeed/deepspeed/moe/sharded_moe.py", line 172, in _top_idx
    @torch.jit.script
    def _top_idx(source, k):
        return torch.topk(source, k=k, dim=0)[1]
               ~~~~~~~~~~ <--- HERE
RuntimeError: selected index k out of range
```

Co-authored with: @rajhans
Reviewed/approved by: @samyam, @yaozhewei
Tagging @tohtana and @ykim362 to help review
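For context, `torch.topk` raises `selected index k out of range` whenever `k` exceeds the size of the reduction dimension, so the guard named in the commit title is to cap the computed capacity at the number of tokens before calling topk. A minimal sketch of that idea (the shapes and the capacity value here are illustrative assumptions, not the actual patch):

```python
import torch

@torch.jit.script
def _top_idx(source, k: int):
    # Fails with "selected index k out of range" if k > source.shape[0].
    return torch.topk(source, k=k, dim=0)[1]

# Hypothetical gating scores: (num_tokens, num_experts)
mask1_rand = torch.rand(8, 4)
num_tokens = mask1_rand.shape[0]

capacity = 16                        # e.g. derived from the capacity factor
capacity = min(capacity, num_tokens) # the guard: never exceed num_tokens
top_idx = _top_idx(mask1_rand, capacity)
```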