
fix dist attn reshape error #5366

Closed
wants to merge 5 commits

Conversation

tkdcjf159

By default, DeepSpeed's DistributedAttention uses scatter_idx = 2 and gather_idx = 0. However, if I set gather_idx to 1 and use a batch size greater than 1, an error occurs during the output all-to-all operation, as illustrated below. To fix this, replace seq_world_size with -1 in the reshape.

import torch
# deepspeed.comm is assumed here (the surrounding module imports it as dist); its API mirrors torch.distributed.
import deepspeed.comm as dist

def single_all_to_all(input, scatter_idx, gather_idx, group):
    # Example: input shape [2, 1024, 8, 16], scatter_idx = 1, gather_idx = 2, seq_world_size = 8
    seq_world_size = dist.get_world_size(group)
    inp_shape = list(input.shape)  # inp_shape = [2, 1024, 8, 16]
    inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size  # inp_shape = [2, 128, 8, 16]
    if scatter_idx < 2:
        # Old code reshaped [2, 1024, 8, 16] to [seq_world_size, 128, 8, 16] = [8, 128, 8, 16]:
        # ERROR, since 2 * 1024 * 8 * 16 != 8 * 128 * 8 * 16.
        # Use -1 so the leading dimension absorbs the batch as well.
        input_t = input.reshape(
            [-1, inp_shape[scatter_idx]] +
            # [seq_world_size, inp_shape[scatter_idx]] +
            inp_shape[scatter_idx + 1:]
        ).contiguous()
    else:
        # Transpose groups of heads with the seq-len parallel dimension to scatter them
        input_t = input.reshape(
            [-1, seq_world_size, inp_shape[scatter_idx]] +
            inp_shape[scatter_idx + 1:]
        ).transpose(0, 1).contiguous()

    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)

    # If scattering the seq-dim, transpose the heads back to the original dimension
    if scatter_idx < 2:
        output = output.transpose(0, 1).contiguous()

    return output.reshape(
        inp_shape[:gather_idx] +
        [inp_shape[gather_idx] * seq_world_size] +
        inp_shape[gather_idx + 1:]).contiguous()
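
For reference, here is a minimal single-process sketch (shapes taken from the comment above; no distributed setup needed, and the tensor values are hypothetical) showing why the original reshape fails once the batch size exceeds 1, and why -1 resolves it:

import torch

# Hypothetical example tensor: [batch=2, seq=1024, heads=8, head_dim=16],
# with scatter_idx = 1 and a sequence-parallel world size of 8.
x = torch.randn(2, 1024, 8, 16)
seq_world_size = 8
scatter_idx = 1

inp_shape = list(x.shape)
inp_shape[scatter_idx] //= seq_world_size  # -> [2, 128, 8, 16]

# Original code: target shape [seq_world_size, 128, 8, 16] = [8, 128, 8, 16]
# holds 8 * 128 * 8 * 16 = 131,072 elements, but x has 2 * 1024 * 8 * 16 = 262,144,
# so reshape raises a RuntimeError whenever batch > 1.
try:
    x.reshape([seq_world_size, inp_shape[scatter_idx]] + inp_shape[scatter_idx + 1:])
except RuntimeError as e:
    print("old reshape fails:", e)

# Proposed fix: -1 lets the leading dimension absorb the batch as well,
# giving [batch * seq_world_size, 128, 8, 16] = [16, 128, 8, 16].
x_fixed = x.reshape([-1, inp_shape[scatter_idx]] + inp_shape[scatter_idx + 1:])
print(x_fixed.shape)  # torch.Size([16, 128, 8, 16])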

@tkdcjf159 tkdcjf159 requested a review from mrwyattii as a code owner April 5, 2024 00:12
@tkdcjf159 tkdcjf159 requested review from awan-10 and arashb as code owners April 5, 2024 00:41
@tkdcjf159
Author

@microsoft-github-policy-service agree company="Upstage"

@loadams
Collaborator

loadams commented Jan 7, 2025

@tkdcjf159 - if you're still interested in this PR, could you resolve the conflicts so we can get it reviewed?

@loadams loadams self-assigned this Jan 7, 2025
@loadams loadams requested review from loadams and removed request for arashb, awan-10 and mrwyattii January 7, 2025 16:51
@loadams
Collaborator

loadams commented Jan 21, 2025

@tkdcjf159 - closing this PR as this code has been refactored. If you believe this bug still remains, please comment and we can re-open it.

@loadams loadams closed this Jan 21, 2025