Add Context Parallelism support to cudnn Flash Attention #1133
Conversation
gobbleturk left a comment
Awesome, this looks great! Just one change to the sharding rules (with the current state of this PR, it looks like the old sequence parallelism is broken).
I'd appreciate it if you could also run bash code_style.sh to run our linter (I recommend saving the branch before running this just in case...)
A9isha left a comment
This is great! Thank you so much.
Curious, what kind of improvement are you observing with load balancing enabled?
Oh, also please merge all commits into one before pushing.
For llama3-8b, with load balancing and cp=2, we see roughly a 10% performance improvement.
@gobbleturk @A9isha once I get an LGTM from you, I can work on resolving the merge conflicts and also merge all the commits into one.
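For context on why load balancing helps here: with a causal mask, a naive contiguous split of the sequence across CP ranks gives the later ranks far more key positions to attend to than the earlier ones. The sketch below (plain Python, not MaxText code; the cp=2 and seq_len values are illustrative) quantifies that imbalance.

```python
# Rough illustration (not MaxText code) of the causal-mask work imbalance that
# load balancing removes when the sequence is split contiguously across cp ranks.
def causal_work(q_positions):
    # Under a causal mask, query position i attends to keys 0..i,
    # so its cost is proportional to (i + 1).
    return sum(i + 1 for i in q_positions)

seq_len, cp = 8192, 2
chunk = seq_len // cp
per_rank = [causal_work(range(r * chunk, (r + 1) * chunk)) for r in range(cp)]
total = sum(per_rank)
print([round(w / total, 3) for w in per_rank])  # ~[0.25, 0.75]: rank 1 does ~3x rank 0's work
```

Since the training step waits on the slowest rank, rebalancing the split so each rank does an equal share of attention work can plausibly account for an end-to-end gain on the order of the ~10% reported above.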
gobbleturk left a comment
LGTM! I would like either @khatwanimohit or @aireenmei to review the data loading change (enable_packing=True).
A9isha left a comment
LGTM - thank you!
Force-pushed a3e1739 to 5016d35
Hi @gobbleturk, could you please approve so that the unit tests can run?
Force-pushed 5016d35 to dbf3d70
1. Implemented input sequence re-ordering at the data loading stage.
2. Added a context_parallel_load_balance flag in base.yml to turn load balancing on/off.
3. Added the sequence packing flag enable_packing in base.yml and modified the associated data processing files.
4. Added/modified unit tests for context parallelism and GPU flash attention (cudnn_flash_te).
5. Added sharding of parameters across CP ranks (modified the logical axis sharding rules).
6. Fixed the q, k, v and out_proj sharding axis names by adding MODEL_MODE_TRAIN in attention.py.
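Item 1 above is the key piece for load balancing. A common approach, sketched below in plain Python with NumPy as an illustration of the general technique rather than the PR's exact implementation, is to cut each sequence into 2*cp chunks at the data loading stage and give rank r the pair of chunks r and 2*cp-1-r, so every rank holds one cheap (early) chunk and one expensive (late) chunk under the causal mask.

```python
import numpy as np

def reorder_for_cp_load_balance(tokens: np.ndarray, cp: int) -> np.ndarray:
    """Reorder a [batch, seq_len] token array so that a contiguous split across
    cp context-parallel ranks is balanced under a causal mask.

    The sequence is cut into 2*cp equal chunks; rank r receives chunks
    r and (2*cp - 1 - r), pairing an early (cheap) chunk with a late
    (expensive) chunk.
    """
    _, seq_len = tokens.shape
    assert seq_len % (2 * cp) == 0, "seq_len must be divisible by 2*cp"
    chunks = np.split(tokens, 2 * cp, axis=1)
    order = []
    for r in range(cp):
        order += [r, 2 * cp - 1 - r]
    return np.concatenate([chunks[i] for i in order], axis=1)

# Tiny example: seq_len=8, cp=2 -> chunk order [0, 3, 1, 2].
x = np.arange(8)[None, :]
print(reorder_for_cp_load_balance(x, cp=2))  # [[0 1 6 7 2 3 4 5]]
```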
Force-pushed dbf3d70 to 336fb38
A9isha left a comment
Thank you!
aireenmei left a comment
LGTM for data input changes!
All tests passed; manually adding the pull ready label.
The pull ready label may also have been applied automatically after I allowed the next workflow steps to run; I'm not sure.
Merged commit c4f7060 into AI-Hypercomputer:main
Description
This PR adds Context Parallelism support to GPU Flash Attention. It is necessary to support large sequence lengths in MaxText. Right now, the support is offered through Transformer-Engine and uses an All-Gather type implementation. Note that it requires the mask type to be causal and does not work with sliding window attention yet. It also requires transformer-engine==1.13 or above.
NEW
- context_parallel_load_balance flag in base.yml to turn load balancing on/off.
- enable_packing flag in base.yml, plus modifications to the associated data processing files.
- Unit tests for context parallelism and GPU flash attention (cudnn_flash_te).
Tests
A unit test is included with the PR, run with the base model on 4 x A100 GPUs.
Checklist
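As a usage sketch (not part of the PR text): the new flags would be combined with an ordinary MaxText launch roughly as follows. context_parallel_load_balance and enable_packing come from this PR, and attention=cudnn_flash_te selects the TE cuDNN flash attention path it targets; the run name and the ici_context_parallelism mesh override are assumptions for illustration.

```python
# Minimal launch sketch under the stated assumptions; run from a MaxText checkout.
import subprocess

overrides = [
    "run_name=cp_llama3_8b_test",          # illustrative run name
    "attention=cudnn_flash_te",            # TE cuDNN flash attention backend
    "ici_context_parallelism=2",           # assumed flag name for the CP mesh dimension
    "context_parallel_load_balance=True",  # new flag introduced by this PR
    "enable_packing=True",                 # new flag introduced by this PR
]

subprocess.run(
    ["python3", "MaxText/train.py", "MaxText/configs/base.yml", *overrides],
    check=True,
)
```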