Add context parallelism #1445
Conversation
Force-pushed from b342894 to 6b9c29d
Force-pushed from 3f55480 to 3ba3d54
Force-pushed from 4c4ef47 to 1b65e2e
gobbleturk left a comment:
go anisha go
Force-pushed from f6d6f39 to 4169d41
Force-pushed from 9f18043 to c6ea0d7
richjames0 left a comment:
lgtm
Force-pushed from c6ea0d7 to f4926bf
richjames0 left a comment:
thanks for the change!
Force-pushed from f4926bf to c4ead2f
richjames0 left a comment:
lgtm again
Force-pushed from c4ead2f to fbe8c2d
Force-pushed from fbe8c2d to 4e7d847
Description
Adding context parallelism to MaxText to support long context lengths when training on TPUs, following PR 1133, which adds it for GPUs.

Previously, MaxText implemented sequence parallelism by sharding the attention operations on heads (with the FF operation sharded by sequence); that decision was made because the heads dim acts like a batch dimension. This PR instead supports the common context parallelism paradigm of all-gathering (AG) keys and values and sharding on the query sequence, which frees the heads dim to be used for other model parallelism.
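As a rough illustration (not the code in this PR), the sketch below shows the "AG keys/values, shard on query sequence" idea with a plain `shard_map` over a 1-D mesh. The mesh axis name `context`, the function name `context_attention`, and the unmasked attention are illustrative assumptions only.

```python
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# 1-D mesh whose only axis ("context") is the context-parallel axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("context",))

@partial(shard_map, mesh=mesh,
         in_specs=(P(None, "context", None),) * 3,
         out_specs=P(None, "context", None))
def context_attention(q, k, v):
  # Each device keeps only its query-sequence shard, but all-gathers the full
  # keys and values so it can attend over the whole sequence.
  k_full = jax.lax.all_gather(k, "context", axis=1, tiled=True)
  v_full = jax.lax.all_gather(v, "context", axis=1, tiled=True)
  scores = jnp.einsum("bqd,bkd->bqk", q, k_full) / jnp.sqrt(q.shape[-1])
  # (Causal masking and the load-balancing reorder are omitted; see below.)
  return jnp.einsum("bqk,bkd->bqd", jax.nn.softmax(scores, axis=-1), v_full)
```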
We also employ "load balancing" in the context parallelism: we reorder the tokens and the attention mask so that all devices do a roughly similar amount of attention work (and hence take a similar amount of time). Otherwise, with, say, context parallelism = 2, the device that receives the shard with tokens 0,1,2,3 has much less work than the device that receives the shard with tokens 4,5,6,7 (because, with a causal mask, later tokens attend to more positions), so we load balance by dividing the tokens as 0,1,6,7 and 2,3,4,5.
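A minimal sketch of that reorder (the helper names here are illustrative, not MaxText's): the sequence is split into 2 * cp chunks and shard i is given chunk i together with chunk 2*cp - 1 - i, so every shard mixes early and late tokens.

```python
import jax.numpy as jnp

def load_balance_indices(seq_len: int, cp: int) -> jnp.ndarray:
  """Reorder indices so each context-parallel shard mixes early and late tokens."""
  chunk = seq_len // (2 * cp)
  chunks = jnp.arange(seq_len).reshape(2 * cp, chunk)
  # Shard i gets chunk i paired with chunk 2*cp - 1 - i.
  return jnp.concatenate(
      [jnp.concatenate([chunks[i], chunks[2 * cp - 1 - i]]) for i in range(cp)])

def reorder(x: jnp.ndarray, cp: int, seq_axis: int = 1) -> jnp.ndarray:
  # Apply the same reorder to the tokens and to the attention mask along the
  # sequence axis, before sharding that axis over the context-parallel devices.
  return jnp.take(x, load_balance_indices(x.shape[seq_axis], cp), axis=seq_axis)

# load_balance_indices(8, 2) -> [0, 1, 6, 7, 2, 3, 4, 5],
# i.e. shards [0, 1, 6, 7] and [2, 3, 4, 5] as in the example above.
```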
We use `jnp.take` in the load balancing, which hopefully can be improved and is being tracked in b/413770626.

FIXES: b/377904983
Tests
- Tested locally on v5p-8, and also on v6e-256 for Llama3.1-70b with a context length of 131072 tokens: `python3 -m benchmarks.maxtext_xpk_runner` with `llama3_1_70b_131072` from `maxtext/benchmarks/maxtext_trillium_model_configs.py`
- Also added a unit test for context parallelism in `MaxText/tests/attention_test.py`
- Updated `maxtext/end_to_end/tpu/llama2/7b/test_llama2_7b.sh` to double as an integration test for context parallelism in training in `train.py` via `ici_context_parallelism`
- Also used `ici_context_parallelism` for `forward_pass_logit_checker` to test the forward pass, since we can't use flash attention in `decode.py`
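As a hedged illustration (not the actual test added in `MaxText/tests/attention_test.py`), a unit test for the load-balancing reorder might check that it is an invertible permutation and that the causal-attention work per shard comes out equal:

```python
import jax.numpy as jnp

def test_load_balance_reorder(seq_len: int = 8, cp: int = 2) -> None:
  # Build the reorder indices: shard i gets chunk i and chunk 2*cp - 1 - i.
  chunk = seq_len // (2 * cp)
  chunks = jnp.arange(seq_len).reshape(2 * cp, chunk)
  idx = jnp.concatenate(
      [jnp.concatenate([chunks[i], chunks[2 * cp - 1 - i]]) for i in range(cp)])
  # Invertible: un-permuting with argsort restores the original token order.
  assert jnp.array_equal(idx[jnp.argsort(idx)], jnp.arange(seq_len))
  # Balanced: with a causal mask, token t attends to t + 1 positions, so the
  # per-shard sums of (t + 1) should be equal after reordering.
  work = (idx + 1).reshape(cp, seq_len // cp).sum(axis=1)
  assert jnp.all(work == work[0])

test_load_balance_reorder()  # e.g. idx = [0, 1, 6, 7, 2, 3, 4, 5] for cp=2
```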
Checklist
Before submitting this PR, please make sure (put X in square brackets):