[BUG] [Fix Suggestion] Uneven head sequence parallelism #6774

Closed
@Eugene29

Description

Describe the bug

DeepSpeed 0.15.4 takes the uneven-head sequence parallelism (SP) path even when the head count is evenly divisible by the SP size, and raises the following assert:
assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"

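This happens because, by the time of the second all2all, the head dimension has already been split across the SP ranks. The second all2all input has shape [B, s, hc/sp, hs], so num_heads is read as hc/sp rather than hc.
hc/sp is not always divisible by sp, so num_heads % seq_world_size != 0 can evaluate to true even when hc % sp == 0.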

def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
    seq_world_size = dist.get_world_size(group)
    # we only need num_heads once
    num_heads = input.shape[2]

    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
        # Assuming here that the number of heads for q is consistent with kv
        # If not, additional logic is required for cases like GQA
        if get_num_kv_heads() is None:
            assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
            # set heads at first call by num_total_heads.
            # then use ``get_num_kv_heads() is not None`` to re-entry uneven path.
            set_num_kv_heads(num_heads)
        assert async_op == False, "uneven head sp does not support async op"
        return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)
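The arithmetic behind the false positive can be checked without DeepSpeed. The following is a minimal sketch, not the library code: `original_check` and `local_heads_after_first_all2all` are hypothetical helpers, and the `kv_heads_set` flag stands in for `get_num_kv_heads() is not None`. The values 8 heads and SP size 8 are illustrative.

```python
def local_heads_after_first_all2all(total_heads: int, sp_size: int) -> int:
    """Heads held by each rank once the head dimension has been scattered."""
    return total_heads // sp_size


def original_check(num_heads: int, sp_size: int, kv_heads_set: bool = False) -> bool:
    """Mirrors the quoted condition; kv_heads_set replaces get_num_kv_heads() is not None."""
    return kv_heads_set or num_heads % sp_size != 0


total_heads, sp_size = 8, 8  # SP size equal to head count, as in the reproduction

# First all2all sees the full head count: 8 % 8 == 0, so the even path is taken.
first_call_uneven = original_check(total_heads, sp_size)

# Second all2all sees hc/sp = 1 head per rank: 1 % 8 != 0, so the uneven path is taken.
local_heads = local_heads_after_first_all2all(total_heads, sp_size)
second_call_uneven = original_check(local_heads, sp_size)

# Inside the uneven path, the assert requires num_heads > seq_world_size, and 1 > 8 fails.
assert_passes = local_heads > sp_size

print(first_call_uneven, second_call_uneven, assert_passes)
```

The first call is classified as even, the second as uneven, and the assert inside the uneven branch then fails because the local head count is smaller than the SP size.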

To Reproduce

To reproduce the error, set the SP size equal to the head count.

Fix Suggestion

During the second all2all, rely only on get_num_kv_heads(). That is, change

    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:

to

    if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and gather_idx < 2):
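A sketch of how the suggested condition behaves, under stated assumptions: `proposed_check` is a hypothetical stand-in for the condition above, `kv_heads_set` replaces `get_num_kv_heads() is not None`, and the first all2all is assumed to be called with `gather_idx=1` and the second with `gather_idx=2`. These index values are inferred from the `gather_idx < 2` test in the suggestion, not confirmed against the DeepSpeed source.

```python
def proposed_check(num_heads: int, sp_size: int, gather_idx: int,
                   kv_heads_set: bool = False) -> bool:
    """Suggested condition: the modulo test only counts when gather_idx < 2."""
    return kv_heads_set or (num_heads % sp_size != 0 and gather_idx < 2)


# Even case: 8 heads, SP size 8.
# First all2all (assumed gather_idx=1) sees 8 heads; second (assumed gather_idx=2) sees 1.
even_first = proposed_check(num_heads=8, sp_size=8, gather_idx=1)
even_second = proposed_check(num_heads=1, sp_size=8, gather_idx=2)

# Genuinely uneven case: 6 heads, SP size 4.
# The first call detects unevenness; the real code would then call set_num_kv_heads,
# which is modeled here by passing kv_heads_set=True on the second call.
uneven_first = proposed_check(num_heads=6, sp_size=4, gather_idx=1)
uneven_second = proposed_check(num_heads=2, sp_size=4, gather_idx=2, kv_heads_set=True)

print(even_first, even_second, uneven_first, uneven_second)
```

With this condition, the even case stays on the even path for both calls, while the genuinely uneven case still re-enters the uneven path on the second call through the stored head count.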

Metadata

Labels: bug, training
