[Feature] Add initial support for sequence parallelism #1436
base: main
Conversation
Force-pushed from c263cb3 to 71c8afe
From the code I see, in the prefill stage after attention the output shape is [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim], which then goes through RowSeqParallelLinear and therefore needs an all-reduce. The input of qkv_proj_linear is [padded_total_num_tokens, q_head_num, head_dim], which is not split by sp_size. I want to know why ring attention is not used here; ring attention seems better in both computation and communication.
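A minimal, single-process shape sketch of the flow described above (the names `hidden`, `attn_out`, and `w_o_shard`, and all sizes, are illustrative assumptions rather than the PR's actual API; the all-reduce is left as a comment since this runs on one rank):

```python
import torch

SP_SIZE = 2
q_head_num, head_dim = 32, 128
hidden_size = q_head_num * head_dim
padded_total_num_tokens = 4096

# Input to the QKV projection: the token dimension is NOT split by SP_SIZE.
hidden = torch.randn(padded_total_num_tokens, hidden_size)

# Attention output on one SP worker: only its shard of the q heads.
attn_out = torch.randn(padded_total_num_tokens, q_head_num // SP_SIZE, head_dim)

# Row-parallel output projection: each worker holds a slice of W_o, produces a
# full-hidden-size partial result, and the partials are summed across the SP
# group with an all-reduce.
w_o_shard = torch.randn((q_head_num // SP_SIZE) * head_dim, hidden_size)
partial_out = attn_out.reshape(padded_total_num_tokens, -1) @ w_o_shard
# torch.distributed.all_reduce(partial_out)  # sum over SP workers in a real multi-rank run
print(partial_out.shape)  # torch.Size([4096, 4096])
```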
For each SP worker, we have either (1) QKV of entire sequences:
    q tensor: [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim]
    k tensor: [padded_total_num_tokens, k_head_num, head_dim]
    v tensor: [padded_total_num_tokens, v_head_num, head_dim]
Or (2) Q of entire sequences and KV of the current SP shard:
    q tensor: [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim]
    k tensor: [padded_sp_shard_num_tokens, k_head_num, head_dim]
    v tensor: [padded_sp_shard_num_tokens, v_head_num, head_dim]

Case (1) saves cross-SP-worker communication, while case (2) saves the computation
to get K and V for the entire sequences but needs extra computation in SP attn.
"""
Case (2) seems to be able to split the workload and overlap even with a single query. Just curious, does anyone have opinions on TreeAttention (all-reducing the LSE instead of sending KV), which seems optimized for decoding?
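For reference, a minimal single-process sketch of the LSE-based merge that a TreeAttention-style scheme relies on: each SP worker computes attention over its own KV shard, producing a partial output and a per-query log-sum-exp, and the partials are combined using the global LSE instead of moving K/V around. Shapes and names below are illustrative assumptions.

```python
import torch

num_queries, num_heads, head_dim, num_shards = 16, 4, 64, 2

# Per-shard partial outputs and LSEs (would come from each SP worker's
# attention over its local KV shard).
partial_out = torch.randn(num_shards, num_queries, num_heads, head_dim)
partial_lse = torch.randn(num_shards, num_queries, num_heads)

# Merge: o = sum_i exp(lse_i - global_lse) * o_i. In a distributed setting this
# reduces to an all-reduce over (weighted output, lse) rather than sending KV.
global_lse = torch.logsumexp(partial_lse, dim=0)           # [Q, H]
weights = torch.exp(partial_lse - global_lse)              # [S, Q, H]
merged = (weights.unsqueeze(-1) * partial_out).sum(dim=0)  # [Q, H, D]
print(merged.shape)
```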
# TODO: in fact we can use all-to-all to gather the output and state here
# to collect only q head shards that are needed by the current SP worker.
# All-to-all will save communication and `merge_state` computation.
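A hedged sketch of the communication-volume argument in this TODO (names and sizes are assumptions): with all-gather, every worker ends up holding partial outputs for all q heads from every SP shard and must merge all of them, while with all-to-all each worker receives partials only for its own q-head shard.

```python
import torch

SP_SIZE = 4
num_tokens, q_head_num, head_dim = 1024, 32, 128
heads_per_worker = q_head_num // SP_SIZE

# Each worker holds partial attention outputs for all q heads, computed
# against its own KV shard.
local_partial = torch.randn(q_head_num, num_tokens, head_dim)

# all-gather: every worker ends up with all-heads partials from all shards,
# then runs the merge over every head.
allgather_elems = SP_SIZE * local_partial.numel()

# all-to-all: each worker receives one partial per SP shard, but only for the
# heads it owns, and only merges those.
alltoall_elems = SP_SIZE * heads_per_worker * num_tokens * head_dim

print(f"all-gather elems per worker: {allgather_elems}")
print(f"all-to-all elems per worker: {alltoall_elems} "
      f"({alltoall_elems / allgather_elems:.0%} of all-gather)")
```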
Later all-reduce in ColumnSeqParallelLinear? Thanks.
@ZYHowell @ivanium Moved from #1041