Adding support for sequence sharding with tensor parallelism #1136
Conversation
gobbleturk left a comment
Interesting, thank you for this contribution!
Can we either:
- Add an assert that checks that tensor_sequence is only used with the llama2 decoder block (a sketch of such a check follows this list), or
- Add support for tensor_sequence for all decoder blocks?
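For concreteness, option (1) could look roughly like the config-validation sketch below. This is a minimal sketch, not the actual pyconfig.py code: the `ici_tensor_sequence_parallelism` and `decoder_block` key names and the `"llama2"` value are assumptions mirroring the discussion, not verified MaxText identifiers.

```python
# Hedged sketch of option (1): reject configs that enable tensor_sequence
# sharding with an unsupported decoder block. Key names are assumptions,
# not the exact pyconfig.py identifiers.
def validate_tensor_sequence_support(raw_keys):
  uses_tensor_sequence = (
      raw_keys.get("ici_tensor_sequence_parallelism", 1) > 1
      or raw_keys.get("dcn_tensor_sequence_parallelism", 1) > 1
  )
  if uses_tensor_sequence:
    assert raw_keys.get("decoder_block") == "llama2", (
        "tensor_sequence sharding is currently only supported with the "
        "llama2 decoder block"
    )
```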
@gobbleturk done (2), i.e. added support for tensor_sequence for all decoder blocks. Thank you for your feedback.
gobbleturk left a comment
Awesome, thank you!
Fixed issues flagged by the linter.
gobbleturk left a comment
Looks like you need to add tensor_sequence in pyconfig.py in at least 3 places (you can search for tensor_parallelism to find the relevant places):
| raw_keys["dcn_fsdp_transpose_parallelism"], | ||
| raw_keys["dcn_sequence_parallelism"], | ||
| raw_keys["dcn_tensor_parallelism"], | ||
| raw_keys["dcn_tensor_sequence_parallelism"], |
We need to add "tensor_sequence" to the mesh_axes and data_sharding lists for pipeline parallelism (PP) right below:
Lines 599 to 600 in 30f3988
| mesh_axes = ["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"] | |
| data_sharding = [["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "expert", "autoregressive"]] |
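For reference, the requested edit amounts to adding the new axis name to both lists, roughly as below. The exact position of "tensor_sequence" in the ordering is an assumption; the requirement is only that it appears in both lists.

```python
# Sketch of the two lists with the new axis added; placement next to "tensor"
# is assumed for illustration.
mesh_axes = ["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "tensor_sequence", "expert", "autoregressive"]
data_sharding = [["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "tensor_sequence", "expert", "autoregressive"]]
```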
Merged 577ca8d into AI-Hypercomputer:main
Description
This change adds support for sequence sharding along the tensor-parallel axis. The benefits of this approach are described in this paper. It has performance benefits over the default tensor sharding approach because we do not need an all-reduce when computing LayerNorms or RMSNorms.
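To illustrate why the all-reduce disappears, here is a minimal JAX sketch (not MaxText code): RMSNorm reduces only over the hidden dimension, so when activations are sharded along the sequence dimension each device already holds the full hidden vector for its tokens and no cross-device reduction is needed. The mesh and axis names below are illustrative assumptions.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def rms_norm(x, scale, eps=1e-6):
  # The reduction runs over the last (hidden) axis only, which stays unsharded
  # under sequence sharding, so no all-reduce is required.
  var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * jax.lax.rsqrt(var + eps) * scale

devices = jax.devices()
mesh = Mesh(np.array(devices), ("tensor_sequence",))

seq_len, hidden = 8 * len(devices), 1024
x = jnp.ones((seq_len, hidden))
# Shard the sequence dimension across the tensor_sequence axis; keep hidden whole.
x = jax.device_put(x, NamedSharding(mesh, P("tensor_sequence", None)))
scale = jnp.ones((hidden,))

y = jax.jit(rms_norm)(x, scale)  # the norm itself needs no cross-device communication
```

By contrast, sharding the hidden dimension (the default tensor sharding) would require a cross-device reduction to compute the mean of squares, which is the all-reduce this change avoids.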
Tests
I have run the llama2-7b config with this change and observed a 7% end-to-end speedup on 8x H100 GPUs.
Baseline:
A snippet of output:
With changes:
Checklist
Before submitting this PR, please make sure (put X in square brackets):