1M+ context length (context parallel integration) #2668
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2668
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 93085e3 with merge base 0d90675.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -718,3 +743,108 @@ def prepare_mha_for_tp(
    if is_fusion_model:
        model.decoder = decoder
    return model


def _get_sdpa_context() -> (
Does this mean CP doesn't work with FlexAttention?
Yes, at least until pytorch/pytorch#151497 lands
But I also think this is somewhat orthogonal: flex does not have its own backend (see here). My assumption is that it should be using the flash attention backend (but I need to confirm).
curious if torchtune already has non-CP flex_attention? cc @ebsmothers
recipes/full_finetune_distributed.py
Outdated
# Define optional context manager for context parallelism
model_inputs = list(batch.values())
buffers = list(self._model.buffers())
optional_context_parallel_context_manager = (
Is this the naming we're using for other optional ctx managers? We have "activations_handling_ctx", though I'd prefer to consolidate on something like "context_parallel" or "maybe_context_parallel"; the "with" statement already says it's a context manager.
Yeah I'm good taking out the "optional" here and matching what we do for activation offloading
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #2668       +/-   ##
==========================================
+ Coverage    7.75%   60.02%   +52.27%
==========================================
  Files         376      437       +61
  Lines       23117    26765     +3648
==========================================
+ Hits         1792    16066    +14274
+ Misses      21325    10699    -10626

☔ View full report in Codecov by Sentry.
Thank you for enabling this! Left some feedback about how the manager is structured, and some questions on whether we're safely handling all of our edge cases.
torchtune/training/_distributed.py
Outdated
def get_context_parallel_context(
    *,
    cp_enabled: bool = False,
This should just be "enabled"; "cp" seems redundant within this function.
torchtune/training/_distributed.py
Outdated
# TODO: context parallel for multimodal models requires extra work
if cp_enabled and not isinstance(model, TransformerDecoder):
    raise ValueError(
        "Context parallel not supported for models other than TransformerDecoder"
I think this error message should just state that "Only text models are supported", as a user might not understand why they're getting this error.
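For instance, the message could be reworded along these lines (illustrative wording only, not the exact text in the PR):

```python
raise ValueError(
    "Context parallelism currently supports only text-only TransformerDecoder models; "
    "multimodal and fusion models are not yet supported."
)
```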
torchtune/training/_distributed.py
Outdated
# Create context parallel context if enabled
cp_context = None
if (
    cp_enabled
Why not just return a NullContext if not enabled?
We can do that too, but this gives us the option to explicitly order the SDPBackends as we prefer them, so personally I think it's useful.
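For context, a minimal sketch of restricting the SDPA backends with torch.nn.attention.sdpa_kernel (the helper name and backend list here are illustrative; the PR's _get_sdpa_context may additionally encode a preference order):

```python
from contextlib import contextmanager

from torch.nn.attention import SDPBackend, sdpa_kernel


@contextmanager
def sdpa_backends_sketch():
    # Allow only the fused SDPA backends while context parallel is active;
    # listing them explicitly is what lets us control which kernels ring
    # attention can dispatch to.
    with sdpa_kernel(
        [
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.CUDNN_ATTENTION,
        ]
    ):
        yield
```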
torchtune/training/_distributed.py
Outdated
cp_context = None
if (
    cp_enabled
    and world_mesh is not None
world_mesh and model_inputs are not optional inputs, do we need this check?
Good point, will remove
# Create and enter the train context with the optional cp_context
sdpa_context = _get_sdpa_context()

with sdpa_context(cp_context):
What happens with llama4? Or with packed data?
The latest updates should address this. For packing we don't know until forward, so unfortunately we have to check inside the context. For Llama4 we check straightaway whether any layer's mask_mod is not None (though we will also error anyway since it's an early fusion model).
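For illustration, that Llama4 check might look roughly like this (the mask_mod attribute location is an assumption for the sketch, not necessarily where the PR reads it):

```python
import torch.nn as nn


def _uses_flex_mask(model: nn.Module) -> bool:
    # If any attention module carries a flex-attention mask_mod, context
    # parallel cannot be applied yet, so we bail out before entering the
    # CP context.
    return any(
        getattr(module, "mask_mod", None) is not None
        for module in model.modules()
    )
```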
torchtune/training/__init__.py
Outdated
@@ -11,6 +11,7 @@
from torchtune.training._compile import compile_loss, compile_model
from torchtune.training._distributed import (
    gather_cpu_state_dict,
    get_context_parallel_context,
get_context_parallel_manager would read nicer
recipes/qat_distributed.py
Outdated
collate_fn=(
    partial(
        collate_fn,
        padding_idx=self._tokenizer.pad_id,
        ignore_idx=self._loss_fn.ignore_index,
-       pad_to_multiple_of=self.tp_degree,
+       pad_to_multiple_of=self.tp_degree * self.cp_degree * 2,
Maybe this value could come from parallel dims? Like a property?
Yeah that's a good idea
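For example, a property on a ParallelDims-style container could centralize the divisor (field and property names here are illustrative):

```python
from dataclasses import dataclass


@dataclass
class ParallelDims:
    tp: int = 1
    cp: int = 1

    @property
    def min_seq_len_divisor(self) -> int:
        # Ring attention shards each sequence into 2 * cp load-balanced chunks,
        # and TP sharding additionally requires divisibility by tp, so the
        # collate function can pad to this single value.
        return self.tp * self.cp * 2
```

The collate call would then read pad_to_multiple_of=parallel_dims.min_seq_len_divisor instead of hard-coding the product in each recipe.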
recipes/full_finetune_distributed.py
Outdated
@@ -911,6 +919,16 @@ def train(self) -> None:

    utils.batch_to_device(batch, self._device)

    # Define optional context manager for context parallelism
    context_parallel_context_manager = (
        training.get_context_parallel_context(
In the recipe setup, we should call self.cp_manager = training.get_context_parallel_manager(enabled, mesh, model). Then during training, we can just do with self.cp_manager(batch):
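Concretely, the suggested split between setup and the train loop might look like this (a sketch only; attribute names follow the later commits rather than the exact snippet above, and the surrounding loss call is illustrative):

```python
# In the recipe's setup (sketch):
self.context_parallel_manager = training.get_context_parallel_manager(
    enabled=self.cp_degree > 1,
    world_mesh=self.world_mesh,
    model=self._model,
)

# In the train loop (sketch):
with self.context_parallel_manager(list(batch.values())):
    loss = self._loss_step(batch)  # forward + loss run under the CP context
```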
torchtune/training/_distributed.py
Outdated
with sdpa_context(cp_context):
    yield

return context()
This should return an uninitialized context manager function that expects data as an input, so you can call with context(batch): in your train loop.
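A minimal sketch of that shape, assuming torch.distributed.tensor.experimental.context_parallel as the underlying API and a sequence dimension of 1 for every buffer (both are assumptions, not necessarily what the PR ships):

```python
from contextlib import contextmanager

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental import context_parallel


def get_context_parallel_manager(
    *,
    enabled: bool,
    world_mesh: DeviceMesh,
    model: nn.Module,
):
    """Return an uninitialized manager; the train loop calls it with the batch tensors."""

    @contextmanager
    def manager(model_inputs: list[torch.Tensor]):
        if not enabled:
            yield
            return
        # Shard both the batch tensors and the model buffers (e.g. RoPE caches)
        # along their sequence dimension across the CP mesh.
        buffers = model_inputs + list(model.buffers())
        with context_parallel(
            world_mesh["cp"],
            buffers=buffers,
            buffer_seq_dims=[1] * len(buffers),  # assumption: seq dim is 1 for every buffer
            no_restore_buffers=set(model_inputs),
        ):
            yield

    return manager
```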
torchtune/training/_distributed.py
Outdated
cp_enabled: bool = False,
world_mesh: torch.distributed.DeviceMesh,
model: TransformerDecoder,
model_inputs: list[torch.Tensor],
This should take a dict, since all our recipes use a nested dict for inputs. Also, you can't just call list(batch.values()) since there are nested dicts. Since we don't support multimodal for now, you can just explicitly check that tokens are inside the batch in this function.
Nested dicts only happen with multimodal, right? Actually doesn't TensorDict have some utilities for some stuff like this 👀
Discussed offline, I am gonna leave this as is for now since we are already checking that we are in a text-only regime when initializing the context manager.
Stamping to unblock, but please fix the comments.
@@ -158,6 +158,7 @@ def __init__(self, cfg: DictConfig) -> None:
        raise ValueError(
            "Tensor Parallel plan needs to be provided when tensor parallel is enabled."
        )
    self.cp_degree = cfg.get("context_parallel_dim", 1)
Can you put this as a default in one of the full finetune distributed recipes?
Honestly I'm not sure it makes sense. For Alpaca without packing there's really no reason to, because the sequences are so short.
@@ -603,6 +605,10 @@ def _setup_model(
        "FP8 training does not support tensor parallelism yet. "
        "This will be enabled in the near future."
    )
if self.cp_degree > 1:
nit: access this through parallel_dims
torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)
current_loss = current_loss * (self.dp_degree / num_tokens)
with self.context_parallel_manager(list(batch.values())):
Why do you need to do this? Couldn't you have the manager take in any iterable, including a generator?
I guess? But a list is the more natural datatype, no? And that's ultimately what core's context manager expects anyway, so I would prefer to stay close to that.
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
from typing import List
no. bad. use list.
@@ -15,12 +16,14 @@
the llama3_2_1b model builder uses the llama3_2 component builder to create the
Llama3.2 1B model.
"""

what is this?
idk, I think it's just lint adding missing blank lines between functions (not sure why they weren't there to begin with).
@@ -67,6 +72,8 @@ def llama3_2_3b(
    scale_factor=32,
    tie_word_embeddings=tie_word_embeddings,
)

nonsense
these clearly should have been here already, no?
torchtune/training/_distributed.py
Outdated
        return mesh

    @property
    def enabled(self):
shouldn't this be cp_enabled?
Blargh, I did a find-and-replace-all in the file, sorry. Will fix.
torchtune/training/_distributed.py
Outdated
enabled=parallel_dims.enabled,
cp_mesh=world_mesh["cp"] if parallel_dims.enabled else None,
model_inputs=list(batch.values()),
model_buffers=model.buffers(),
What is this example?
Agh I need to update this. Do you not find the example informative? If not I can take it out
Initial implementation of context parallelism in torchtune. We add a utility get_context_parallel_context to return a context manager whenever context parallelism is enabled (otherwise it's basically a nullcontext). We also update our collate function to pad sequence lengths to a multiple of 2 * context_parallel_dim (needed due to this).
Tested
Still needs work (i.e. not addressed in this PR)
🚧 Composability with flex attention (PyTorch PR, torchtitan PR shared by @XilunWu)
🚧 Composability with fp8
🚧 Multimodal models
Test plan
Sweep script to test a bunch of different configurations across Llama3 8B and Llama 3.2 1B models.
WandB project with all the runs
Also used this script to generate a synthetic dataset of concatenated Alpaca samples with most samples at ~1M context length.
Using that synthetic dataset, the screenshot above can be reproduced by running