Add support for EP to context parallelism in self-attention #2023
Conversation
RissyRan left a comment:
Thanks Shuning! Great work!
e943bff to
b5241dc
Compare
gobbleturk left a comment:
Have you considered an approach like conditionally modifying the rules (instead of creating new ones)? This is the approach used for pipeline parallelism:
Line 788 in fdf479f:
def modify_activation_embed_and_logits_batch(logical_axis_rules):
There are pros and cons to both; both are pretty ugly IMO, but at least when modifying rules there are:
- fewer rules
- no if statements deciding which set of rules to use
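For concreteness, here is a minimal sketch of the rule-modification approach gobbleturk describes, following the shape of modify_activation_embed_and_logits_batch. The helper name, option values, and the specific logical axis names are illustrative assumptions, not the PR's actual code.

```python
# A minimal sketch, assuming MaxText-style logical_axis_rules as a list of
# (logical_axis_name, mesh_axes) pairs. The helper name and the logical axis
# names ("activation_length", "activation_batch") are illustrative
# assumptions, not this PR's actual code.
def modify_attention_rules_for_ep(logical_axis_rules, expert_shard_attention_option):
  """Rewrites the existing rule set in place of defining a second one: the
  'expert' mesh axis shards either the sequence (CP-like) or the batch
  (FSDP-like) dimension in attention."""
  new_rules = []
  for logical_axis, mesh_axes in logical_axis_rules:
    mesh_axes = list(mesh_axes)
    if expert_shard_attention_option == "context" and logical_axis == "activation_length":
      mesh_axes.append("expert")  # EP joins CP on the sequence dimension.
    elif expert_shard_attention_option == "fsdp" and logical_axis == "activation_batch":
      mesh_axes.append("expert")  # Default: EP shards the batch, like FSDP.
    new_rules.append((logical_axis, tuple(mesh_axes)))
  return new_rules
```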
richjames0 left a comment:
Really impressive that you understood this and got it working!
Resolved merge conflict with the nnx migration for the attention layer. Re-testing on local v5p-8; diff (before vs. after nnx migration).
Description
Goal
For mixture-of-experts models, we may use expert parallelism (EP). In the attention layer, EP currently acts as FSDP. Building on the previous context parallelism (CP) work, this PR introduces the option of using EP as CP in attention. This is a joint effort with @RissyRan.
FIXES: b/418396648
Main code changes
- attentions.py
- base.yml: new option expert_shard_attention_option: fsdp or context (see the sketch below for how this option might be consumed)
- unit tests: tests.attention_test — AttentionTest.test_tpu_flash_attention_cp_and_ep and MLATest.test_tpu_flash_attention_cp_and_ep (extended from the CP test)
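Below is a hypothetical illustration of how the new base.yml option could gate which mesh axes shard the sequence dimension of the attention inputs. The mesh axis names and the exact partition specs are assumptions based on the description above, not the PR's actual code in attentions.py.

```python
# A hedged sketch, not the PR's actual code: the mesh axis names and the
# exact PartitionSpec layout are assumptions.
from jax.sharding import PartitionSpec as P

def attention_input_spec(expert_shard_attention_option: str) -> P:
  """Returns a spec for attention inputs shaped (batch, seq, heads, head_dim);
  trailing dims left unspecified are replicated."""
  if expert_shard_attention_option == "context":
    # EP acts as CP: the sequence dimension is sharded over context and expert.
    return P(("data", "fsdp"), ("context", "expert"))
  # Default "fsdp": the expert axis shards the batch dimension instead.
  return P(("data", "fsdp", "expert"), "context")
```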
Use case
- ici_expert_parallelism=4, ici_context_parallelism=1, expert_shard_attention_option=context: shard context by 4 in attention; shard expert by 4 for MoE
- ici_expert_parallelism=2, ici_context_parallelism=2, expert_shard_attention_option=context: shard context by 4 in attention; shard expert by 2 and context by 2 for MoE
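The effective sequence sharding for these two configurations can be sanity-checked with a line of arithmetic; the helper below is illustrative, not part of the PR.

```python
# Effective number of shards of the sequence dimension in attention.
# Illustrative helper; the option and flag names mirror the PR description.
def attention_context_shards(ici_context_parallelism, ici_expert_parallelism, option):
  return ici_context_parallelism * (ici_expert_parallelism if option == "context" else 1)

assert attention_context_shards(1, 4, "context") == 4  # first use case above
assert attention_context_shards(2, 2, "context") == 4  # second use case above
```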
Tests
Tested on v5p-8.
Verify sharding shape
- context_parallel_load_balance=True; parallelism: ici_expert_parallelism=2, ici_context_parallelism=2, expert_shard_attention_option=context and ici_expert_parallelism=2, ici_context_parallelism=2, expert_shard_attention_option=fsdp
Verify attention output logits against dot product
- context_parallel_load_balance={True, False}; parallelism: {ici_expert_parallelism=4, expert_shard_attention_option=context} and {ici_expert_parallelism=2, expert_shard_attention_option=context, ici_context_parallelism=2}
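For reference, a rough sketch of the "attention output vs. dot product" style of check, in the spirit of AttentionTest.test_tpu_flash_attention_cp_and_ep; the shapes, tolerances, and the stand-in for the sharded attention op are assumptions, not the actual test.

```python
# A rough sketch, not the actual unit test: `sharded_attention` below is a
# placeholder for the attention op run under the EP-as-CP mesh.
import jax
import jax.numpy as jnp
import numpy as np

def reference_attention(q, k, v):
  # Plain (unsharded) dot-product attention as the ground truth.
  logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)

key = jax.random.PRNGKey(0)
b, s, h, d = 2, 128, 4, 64
q, k, v = (jax.random.normal(jax.random.fold_in(key, i), (b, s, h, d)) for i in range(3))

expected = reference_attention(q, k, v)
sharded_attention = reference_attention  # placeholder for the sharded op
actual = sharded_attention(q, k, v)
np.testing.assert_allclose(np.asarray(actual), np.asarray(expected), rtol=1e-2, atol=1e-2)
```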
Checklist
Before submitting this PR, please make sure (put X in square brackets):