
[BUG] [Fix Suggestion] Uneven head sequence parallelism #6774

Closed
@Eugene29

Description

Describe the bug

DeepSpeed 0.15.4 treats the run as uneven-head sequence parallelism (SP) even when the head count is evenly divisible by the SP size, and raises the following assertion:
assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"

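This happens because, by the second all2all, the head dimension has already been split across the SP ranks. num_heads (read from input.shape[2]) is therefore hc/sp rather than hc, and num_heads % seq_world_size != 0 can evaluate to true. Because the first all2all took the even path, set_num_kv_heads() was never called, so get_num_kv_heads() is still None and the assertion runs with num_heads = hc/sp.
The second all2all input has shape [B, s, hc/sp, hs] (B: batch size, s: sequence length, hc: head count, hs: head size), and hc/sp is not always divisible by sp.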
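To see why the check misfires, here is a minimal sketch. It assumes the layout described above ([B, s/sp, hc, hs] into the first all2all, [B, s, hc/sp, hs] into the second) and models only evenly divisible heads. For each all2all it computes the size of input.shape[2] and whether that value is divisible by sp. The helper heads_seen_by_all2all is hypothetical and not part of DeepSpeed.

```python
from typing import Tuple


def heads_seen_by_all2all(num_heads: int, sp: int) -> Tuple[int, int]:
    """Return input.shape[2] as seen by the first and second all2all.

    Hypothetical helper: assumes the Ulysses-style layout from the issue and
    only models the even case where num_heads is divisible by sp.
    """
    assert num_heads % sp == 0, "this sketch only models evenly divisible heads"
    first = num_heads          # heads not yet scattered across SP ranks
    second = num_heads // sp   # heads already split across SP ranks
    return first, second


for hc, sp in [(32, 4), (8, 8), (16, 8)]:
    first, second = heads_seen_by_all2all(hc, sp)
    print(f"hc={hc}, sp={sp}: first={first} (mod sp = {first % sp}), "
          f"second={second} (mod sp = {second % sp})")
```

Per the code excerpt below, any configuration where hc/sp is not a multiple of sp (for example hc=8, sp=8 or hc=16, sp=8) enters the uneven branch on the second all2all, and the assertion fires whenever hc/sp <= sp.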

def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
    seq_world_size = dist.get_world_size(group)
    # we only need num_heads once
    num_heads = input.shape[2]

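    # On the second all2all, input.shape[2] is already hc/sp, so this modulo
    # check can be true even when hc % sp == 0 (the bug described above).
    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0: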
        # Assuming here that the number of heads for q is consistent with kv
        # If not, additional logic is required for cases like GQA
        if get_num_kv_heads() is None:
            assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
            # set heads at first call by num_total_heads.
            # then use ``get_num_kv_heads() is not None`` to re-entry uneven path.
            set_num_kv_heads(num_heads)
        assert async_op == False, "uneven head sp does not support async op"
        return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)

To Reproduce

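To reproduce the error, set the SP size equal to the head count (SP = hc). The second all2all then sees hc/sp = 1 head, so 1 % sp != 0 sends it into the uneven path and the assertion 1 > sp fails.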
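The failure can be simulated without GPUs or DeepSpeed. The sketch below mirrors only the branch condition and assertion of single_all_to_all from the excerpt above and performs no communication. get_num_kv_heads, set_num_kv_heads, and dispatch_original are hypothetical stand-ins, not DeepSpeed's real module state.

```python
# Hypothetical stand-in for DeepSpeed's module-level kv-head state.
_num_kv_heads = None


def get_num_kv_heads():
    return _num_kv_heads


def set_num_kv_heads(n):
    global _num_kv_heads
    _num_kv_heads = n


def dispatch_original(num_heads, seq_world_size):
    """Mirror the 0.15.4 branch condition quoted in the issue.

    Returns "uneven" or "even" instead of calling any all2all.
    """
    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
        if get_num_kv_heads() is None:
            assert num_heads > seq_world_size, (
                f"Number of heads ({num_heads}) must be larger than "
                f"sequence parallel size ({seq_world_size})")
            set_num_kv_heads(num_heads)
        return "uneven"
    return "even"


hc = sp = 8
print(dispatch_original(hc, sp))      # first all2all: 8 % 8 == 0, even path
try:
    dispatch_original(hc // sp, sp)   # second all2all sees hc/sp = 1 head
except AssertionError as e:
    print("AssertionError:", e)
```

The first call prints "even", and the second raises the same assertion message as in the report, with num_heads = 1 and seq_world_size = 8.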
Fix Suggestion:
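Only rely on get_num_kv_heads() during the second all2all, by applying the divisibility check only when gather_idx < 2 (the first all2all, which gathers along the sequence dimension rather than the head dimension at index 2). Change: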

    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:

to

    if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and gather_idx < 2):
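A minimal sketch of the proposed condition as a pure function. It assumes the head dimension sits at index 2, so the first all2all gathers along the sequence dimension (index 0 or 1) and the second gathers along index 2. takes_uneven_path is hypothetical, and its num_kv_heads argument stands in for get_num_kv_heads(). The per-rank head count of 2 in the last case is only illustrative.

```python
def takes_uneven_path(num_heads, seq_world_size, gather_idx, num_kv_heads):
    """Proposed branch condition from the issue, written as a pure function.

    gather_idx < 2 is true on the first all2all (gathering the sequence dim)
    and false on the second (gathering the head dim at index 2), so the
    divisibility check only runs before the heads have been split.
    """
    return num_kv_heads is not None or (
        num_heads % seq_world_size != 0 and gather_idx < 2)


# Even heads with SP == head count: neither all2all takes the uneven path.
print(takes_uneven_path(8, 8, gather_idx=1, num_kv_heads=None))  # False
print(takes_uneven_path(1, 8, gather_idx=2, num_kv_heads=None))  # False

# Genuinely uneven heads (6 heads, SP=4): the first call takes the uneven
# path; after set_num_kv_heads(6), the second call re-enters it via the
# stored kv-head count rather than the modulo check.
print(takes_uneven_path(6, 4, gather_idx=1, num_kv_heads=None))  # True
print(takes_uneven_path(2, 4, gather_idx=2, num_kv_heads=6))     # True
```

Genuinely uneven configurations are still detected on the first call, and the second call reaches the uneven path only through get_num_kv_heads(), which is exactly what the suggested change intends.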

Labels: bug (Something isn't working), training
