1M+ context length (context parallel integration) #2668

Merged
19 commits merged into pytorch:main on May 30, 2025

Conversation

@ebsmothers (Contributor) commented May 2, 2025

[Screenshot from the ~1M-token context length training run, 2025-05-28]

Initial implementation of context parallelism in torchtune. We add a utility, get_context_parallel_context, that returns a context manager whenever context parallelism is enabled (otherwise it is effectively a nullcontext). We also update our collate function to pad sequences to a multiple of 2 * context_parallel_dim (needed due to this).
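To illustrate the padding requirement, here is a minimal sketch (not torchtune's actual padded_collate): the load-balanced CP schedule splits each sequence into 2 * cp_degree chunks, so the sequence length must be divisible by that product.

import torch
import torch.nn.functional as F

def pad_to_multiple(tokens: torch.Tensor, multiple: int, pad_id: int) -> torch.Tensor:
    # Right-pad the sequence dim so its length divides evenly by `multiple`
    remainder = tokens.size(-1) % multiple
    if remainder == 0:
        return tokens
    return F.pad(tokens, (0, multiple - remainder), value=pad_id)

# With context_parallel_dim=8, sequence lengths must be a multiple of 2 * 8 = 16
tokens = torch.randint(0, 128, (1, 1023))
padded = pad_to_multiple(tokens, multiple=16, pad_id=0)
assert padded.size(-1) == 1024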

Tested

  • CP only
  • CP + TP
  • CP + DP shard
  • CP + DP replicate
  • Composability with activation checkpointing + offloading
  • Composability with optimizer in backward

Still needs work (i.e. not addressed in this PR)

🚧 Composability with flex attention (PyTorch PR, torchtitan PR shared by @XilunWu)
🚧 Composability with fp8
🚧 Multimodal models

Test plan

Sweep script to test a range of configurations across Llama 3 8B and Llama 3.2 1B models.

WandB project with all the runs

Also used this script to generate a synthetic dataset of concatenated Alpaca samples with most samples at ~1M context length.
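The script itself is linked above and not reproduced here. Purely as a hypothetical sketch, such a dataset could be built by concatenating Alpaca rows until each sample reaches roughly 1M tokens and saving something text_completion_dataset can load; the path, the chars-per-token heuristic, and the JSON output format below are assumptions, not what the linked script does.

import json
import os
from datasets import load_dataset

alpaca = load_dataset("tatsu-lab/alpaca", split="train")
target_chars = 4_000_000  # very rough proxy for ~1M tokens

samples, buf, buf_len = [], [], 0
for row in alpaca:
    buf.append(row["text"])
    buf_len += len(row["text"])
    if buf_len >= target_chars:
        samples.append({"text": "\n\n".join(buf)})
        buf, buf_len = [], 0

os.makedirs("/tmp/long_alpaca_dataset", exist_ok=True)
with open("/tmp/long_alpaca_dataset/train.json", "w") as f:
    json.dump(samples, f)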

Using that synthetic dataset, the screenshot above can be reproduced by running

tune run --nproc_per_node 8 full_finetune_distributed --config llama3_2/1B_full \
context_parallel_dim=8 batch_size=1 enable_activation_checkpointing=True \
enable_activation_offloading=True loss.num_output_chunks=128 \
model._component_=torchtune.models.llama3_2.llama3_2 model.vocab_size=128_256 \
model.num_layers=16 model.num_heads=32 model.num_kv_heads=8 model.embed_dim=2048 \
model.intermediate_dim=8192 model.max_seq_len=10_000_000 gradient_accumulation_steps=1 \
dataset._component_=torchtune.datasets.text_completion_dataset \
dataset.source=/tmp/long_alpaca_dataset/train max_steps_per_epoch=50

pytorch-bot bot commented May 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2668

✅ No Failures

As of commit 93085e3 with merge base 0d90675:
💚 Looks good so far! There are no failures yet. 💚


@facebook-github-bot added the "CLA Signed" label on May 2, 2025. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@@ -718,3 +743,108 @@ def prepare_mha_for_tp(
if is_fusion_model:
model.decoder = decoder
return model


def _get_sdpa_context() -> (
Contributor:

Does this mean CP doesn't work with FlexAttention?

Contributor Author (@ebsmothers):

Yes, at least until pytorch/pytorch#151497 lands

Contributor Author (@ebsmothers):

But I also think this is somewhat orthogonal. Like flex does not have its own backend (see here). My assumption is that it should be using the flash attention backend (but need to confirm)

Comment:

curious if torchtune already has non-CP flex_attention? cc @ebsmothers

Contributor Author (@ebsmothers):

@XilunWu yes we do, see e.g. here for how we apply it for sample packing

# Define optional context manager for context parallelism
model_inputs = list(batch.values())
buffers = list(self._model.buffers())
optional_context_parallel_context_manager = (
Contributor:

Is this the naming we're using for other optional ctx managers? We have "activations_handling_ctx", though I'd prefer to consolidate on something like "context_parallel" or "maybe_context_parallel"; the "with" statement already says it's a context manager.

Contributor Author (@ebsmothers):

Yeah I'm good taking out the "optional" here and matching what we do for activation offloading

@joecummings mentioned this pull request on Mar 30, 2025 (4 tasks)
@ebsmothers changed the title from "[wip] context parallelism" to "1M+ context length (context parallel integration)" on May 28, 2025
@ebsmothers marked this pull request as ready for review on May 28, 2025 23:31

codecov-commenter commented May 28, 2025

Codecov Report

Attention: Patch coverage is 14.86486% with 63 lines in your changes missing coverage. Please review.

Project coverage is 60.02%. Comparing base (0d90675) to head (2c7ad46).
Report is 2 commits behind head on main.

Files with missing lines               Patch %    Missing lines
torchtune/training/_distributed.py     20.00%     40 ⚠️
recipes/full_finetune_distributed.py    0.00%     13 ⚠️
recipes/qat_distributed.py              0.00%     10 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            main    #2668       +/-   ##
==========================================
+ Coverage   7.75%   60.02%   +52.27%     
==========================================
  Files        376      437       +61     
  Lines      23117    26765     +3648     
==========================================
+ Hits        1792    16066    +14274     
+ Misses     21325    10699    -10626     

☔ View full report in Codecov by Sentry.

Contributor @pbontrager left a comment:

Thank you for enabling this! Left some feedback about how the manager is structured, and some questions on whether we're safely handling all of our edge cases.


def get_context_parallel_context(
*,
cp_enabled: bool = False,
Contributor:

Should just be "enabled", cp seems redundant in this function

# TODO: context parallel for multimodal models requires extra work
if cp_enabled and not isinstance(model, TransformerDecoder):
raise ValueError(
"Context parallel not supported for models other than TransformerDecoder"
Contributor:

I think this error message should just state that "Only text models are supported" as a user might not understand why they're getting this error.

# Create context parallel context if enabled
cp_context = None
if (
cp_enabled
Contributor:

Why not just return a NullContext if not enabled?

Contributor Author (@ebsmothers):

We can do that too, but this gives us the option to explicitly order the SDPBackends as we prefer them, so personally I think it's useful.
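For reference, a minimal sketch of what that looks like with torch.nn.attention.sdpa_kernel; the exact backend list and priority handling in the PR's _get_sdpa_context may differ.

from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict SDPA to backends that CP's ring attention can work with,
# listed in the order we'd prefer them (illustrative only)
with sdpa_kernel(
    [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.CUDNN_ATTENTION]
):
    ...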

cp_context = None
if (
cp_enabled
and world_mesh is not None
Contributor:

world_mesh and model_inputs are not optional inputs, do we need this check?

Contributor Author (@ebsmothers):

Good point, will remove

# Create and enter the train context with the optional cp_context
sdpa_context = _get_sdpa_context()

with sdpa_context(cp_context):
Contributor:

What happens with llama4? Or with packed data?

Contributor Author (@ebsmothers):

The latest updates should address this. For packing we don't know until forward, so unfortunately we have to check inside the context. For Llama4 we check straight away whether any layer's mask_mod is not None (though we will error anyway since it's an early fusion model).

@@ -11,6 +11,7 @@
from torchtune.training._compile import compile_loss, compile_model
from torchtune.training._distributed import (
gather_cpu_state_dict,
get_context_parallel_context,
Contributor:

get_context_parallel_manager would read nicer

collate_fn=(
partial(
collate_fn,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
pad_to_multiple_of=self.tp_degree,
pad_to_multiple_of=self.tp_degree * self.cp_degree * 2,
Contributor:

Maybe this value could come from parallel dims? Like a property?

Contributor Author (@ebsmothers):

Yeah that's a good idea
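A minimal sketch of such a property, with hypothetical names (not torchtune's actual ParallelDims API):

from dataclasses import dataclass

@dataclass
class ParallelDims:
    tp: int = 1  # tensor parallel degree
    cp: int = 1  # context parallel degree

    @property
    def min_seq_len_divisor(self) -> int:
        # Sequences must split evenly across TP ranks, and the load-balanced
        # CP schedule shards each sequence into 2 * cp chunks
        return self.tp * self.cp * 2

# the collate call above could then pass pad_to_multiple_of=parallel_dims.min_seq_len_divisor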

@@ -911,6 +919,16 @@ def train(self) -> None:

utils.batch_to_device(batch, self._device)

# Define optional context manager for context parallelism
context_parallel_context_manager = (
training.get_context_parallel_context(
Contributor:

In the recipe setup, we should call self.cp_manager = training.get_context_parallel_manager(enabled, mesh, model). Then during training, we can just do with self.cp_manager(batch):
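A sketch of the proposed pattern; the names follow the suggestion above and the merged recipe code, but the exact signature is illustrative:

# In the recipe's setup:
self.context_parallel_manager = training.get_context_parallel_manager(
    enabled=self.cp_degree > 1,
    world_mesh=self.world_mesh,
    model=self._model,
)

# In the train loop, wrap the forward/backward of each step:
with self.context_parallel_manager(list(batch.values())):
    current_loss = self._loss_step(batch)  # recipe-specific forward + loss
    current_loss.backward()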

with sdpa_context(cp_context):
yield

return context()
Contributor:

This should return an uninitialized context manager function that expects the data as an input, so you can write with context(batch): in your train loop.
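A minimal sketch of such a factory, assuming PyTorch's experimental context_parallel API plus SDPA backend selection; the argument names and the buffer/seq-dim handling are simplified relative to the PR:

from contextlib import contextmanager
from typing import Generator

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel


def get_context_parallel_manager(
    enabled: bool, world_mesh: DeviceMesh, model: torch.nn.Module
):
    # Returns an uninitialized context manager that is later called with the batch tensors
    @contextmanager
    def manager(model_inputs: list[torch.Tensor]) -> Generator[None, None, None]:
        if not enabled:
            yield
            return
        # Shard batch tensors (seq dim 1) and model buffers such as RoPE caches
        # (seq dim 0) across the CP mesh for the duration of the step
        model_buffers = list(model.buffers())
        buffers = model_inputs + model_buffers
        seq_dims = [1] * len(model_inputs) + [0] * len(model_buffers)
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
            with context_parallel(
                world_mesh["cp"],
                buffers=buffers,
                buffer_seq_dims=seq_dims,
                no_restore_buffers=set(model_inputs),
            ):
                yield

    return manager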

cp_enabled: bool = False,
world_mesh: torch.distributed.DeviceMesh,
model: TransformerDecoder,
model_inputs: list[torch.Tensor],
Contributor:

This should take a dict, since all our recipes use a nested dict for inputs. Also, you can't just call list(batch.values()) since there are nested dicts. Since we don't support multimodal for now, you can just explicitly check that tokens are inside the batch in this function.

Contributor Author (@ebsmothers):

Nested dicts only happen with multimodal, right? Actually doesn't TensorDict have some utilities for some stuff like this 👀

Contributor Author (@ebsmothers):

Discussed offline, I am gonna leave this as is for now since we are already checking that we are in a text-only regime when initializing the context manager.

Contributor @joecummings left a comment:

stamping to unblock but please fix the comments

@@ -158,6 +158,7 @@ def __init__(self, cfg: DictConfig) -> None:
raise ValueError(
"Tensor Parallel plan needs to be provided when tensor parallel is enabled."
)
self.cp_degree = cfg.get("context_parallel_dim", 1)
Contributor:

Can you put this as a default in one of the full finetune distributed recipes?

Contributor Author (@ebsmothers):

Honestly idk if it makes sense. Like for Alpaca without packing there's really no reason to, because the sequences are so short.

@@ -603,6 +605,10 @@ def _setup_model(
"FP8 training does not support tensor parallelism yet. "
"This will be enabled in the near future."
)
if self.cp_degree > 1:
Contributor:

nit: access this through parallel_dims

torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)
current_loss = current_loss * (self.dp_degree / num_tokens)
with self.context_parallel_manager(list(batch.values())):
Contributor:

Why do you need to do this? Couldn't you have the manager take in any iterable, including a generator?

Contributor Author (@ebsmothers):

I guess? But a list is a more natural datatype, no? And also that's ultimately what core's context manager expects anyway, so I would prefer to stay close to that.

@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
from typing import List
Contributor:

no. bad. use list.

@@ -15,12 +16,14 @@
the llama3_2_1b model builder uses the llama3_2 component builder to create the
Llama3.2 1B model.
"""

Contributor:

what is this?

Contributor Author (@ebsmothers):

idk i think it's just lint adding missing spaces between functions (idk why they weren't there to begin with)

@@ -67,6 +72,8 @@ def llama3_2_3b(
scale_factor=32,
tie_word_embeddings=tie_word_embeddings,
)

Contributor:

nonsense

Contributor Author (@ebsmothers):

these clearly should have been here already, no?

return mesh

@property
def enabled(self):
Contributor:

shouldn't this be cp_enabled?

Contributor Author (@ebsmothers):

blargh i did find and replace all in the file, sorry. will fix

enabled=parallel_dims.enabled,
cp_mesh=world_mesh["cp"] if parallel_dims.enabled else None,
model_inputs=list(batch.values()),
model_buffers=model.buffers(),
Contributor:

What is this example?

Contributor Author (@ebsmothers):

Agh I need to update this. Do you not find the example informative? If not I can take it out

@ebsmothers ebsmothers merged commit 4309419 into pytorch:main May 30, 2025
14 checks passed