[BUG] [Fix Suggestion] Uneven head sequence parallelism #6774
Closed
Description
opened on Nov 21, 2024
Describe the bug
deepspeed 0.15.4 thinks you are using uneven-head sequence parallelism (SP) even though you aren't, and raises the following assert:

```python
assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
```
This happens because during the second all2all the head count has already been split across the SP ranks, so `num_heads % seq_world_size != 0` evaluates to true.
The input to the second all2all has shape `[B, s, hc/sp, hs]`, and `hc/sp % sp == 0` does not always hold. The relevant check in `single_all_to_all`:
```python
def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, async_op=False, handle=None, type=None):
    seq_world_size = dist.get_world_size(group)
    # we only need num_heads once
    num_heads = input.shape[2]
    if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
        # Assuming here that the number of heads for q is consistent with kv
        # If not, additional logic is required for cases like GQA
        if get_num_kv_heads() is None:
            assert num_heads > seq_world_size, f"Number of heads ({num_heads}) must be larger than sequence parallel size ({seq_world_size})"
            # set heads at first call by num_total_heads.
            # then use ``get_num_kv_heads() is not None`` to re-entry uneven path.
            set_num_kv_heads(num_heads)
        assert async_op == False, "uneven head sp does not support async op"
        return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)
```
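To make the misfire concrete, here is a minimal numeric walk-through (plain Python, no torch/deepspeed required) of the check above for the SP == head count case; the values are illustrative only:

```python
# Illustrative values only: hc = total attention heads, sp = SP world size.
hc, sp = 8, 8

# First all2all: input is [B, s/sp, hc, hs], so the head dim still holds all heads.
num_heads_first = hc
print(num_heads_first % sp != 0)   # False -> even path, get_num_kv_heads() stays None

# Second all2all: heads were already scattered, input is [B, s, hc/sp, hs].
num_heads_second = hc // sp        # = 1 per rank
print(num_heads_second % sp != 0)  # True  -> uneven path entered by mistake
print(num_heads_second > sp)       # False -> the assert above fires
```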
To Reproduce
To reproduce the error, set SP equal to the head count (sequence parallel size == number of attention heads).
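A rough repro sketch (not the original script; the import path, launch command, and tensor shapes are assumptions) that drives the two all2all calls directly, launched so that the world size equals the head count:

```python
# Hypothetical repro sketch; launch with e.g. `torchrun --nproc_per_node=8 repro.py`
# so that the SP world size equals the number of heads.
import torch
import torch.distributed as torch_dist
import deepspeed
from deepspeed.sequence.layer import single_all_to_all  # assumed import path

deepspeed.init_distributed()
torch.cuda.set_device(torch_dist.get_rank() % torch.cuda.device_count())
sp_group = torch_dist.new_group(range(torch_dist.get_world_size()))
sp = torch_dist.get_world_size(sp_group)

B, s, hc, hs = 2, 16 * sp, sp, 64  # head count == SP size, as in the bug report
x = torch.randn(B, s // sp, hc, hs, device='cuda')

# First all2all: [B, s/sp, hc, hs] -> [B, s, hc/sp, hs]; hc % sp == 0, even path is taken.
y = single_all_to_all(x, scatter_idx=2, gather_idx=1, batch_dim_idx=0, group=sp_group)

# Second all2all: [B, s, hc/sp, hs] -> [B, s/sp, hc, hs]; hc/sp == 1 here, so
# num_heads % seq_world_size != 0 and the uneven-head assert fires.
z = single_all_to_all(y, scatter_idx=1, gather_idx=2, batch_dim_idx=0, group=sp_group)
```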
Fix Suggestion:
Only rely on `get_num_kv_heads()` during the second all2all, i.e. change

```python
if get_num_kv_heads() is not None or num_heads % seq_world_size != 0:
```

to

```python
if get_num_kv_heads() is not None or (num_heads % seq_world_size != 0 and gather_idx < 2):
```
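For context on why `gather_idx < 2` distinguishes the two calls: under the usual Ulysses convention the first all2all uses scatter_idx=2 / gather_idx=1 (scatter heads, gather sequence) and the second uses scatter_idx=1 / gather_idx=2 (scatter sequence, gather heads). A minimal, self-contained sketch of the amended condition (hypothetical helper, not the actual DeepSpeed code), assuming that convention:

```python
def takes_uneven_path(num_kv_heads, num_heads, seq_world_size, gather_idx):
    # num_kv_heads stands in for get_num_kv_heads(): None until the uneven path has run once.
    # gather_idx == 1 -> first all2all (heads still whole): uneven detection applies.
    # gather_idx == 2 -> second all2all (heads already split): rely on num_kv_heads only.
    return num_kv_heads is not None or (num_heads % seq_world_size != 0 and gather_idx < 2)

# hc == sp == 8 (the repro) now stays on the even path for both calls:
assert takes_uneven_path(None, 8, 8, gather_idx=1) is False   # first all2all
assert takes_uneven_path(None, 1, 8, gather_idx=2) is False   # second all2all (previously misfired)
# A genuinely uneven case (hc = 12, sp = 8) is still detected on the first call:
assert takes_uneven_path(None, 12, 8, gather_idx=1) is True
```

With this guard, the second all2all only takes the uneven path when the first one already set the KV head count via `set_num_kv_heads`.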