
[WIP] Merging AutoSP into DeepSpeed #7860

Draft
neeldani wants to merge 6 commits into deepspeedai:master from neeldani:autosp

Conversation


@neeldani neeldani commented Feb 19, 2026

AutoSP: Unlocking Long-Context LLM Training Via Compiler-Based Sequence Parallelism

Overview

AutoSP is a compiler optimization pass that shards inputs along the sequence dimension and enables Ulysses-style sequence parallelism while preventing graph breaks during torch.compile(). All passes operate on the Torch IR of the forward graph.

API Design

User-Facing Entry Point: prepare_autosp_inputs()

Users must explicitly call this function to prepare inputs for AutoSP compilation:

def prepare_autosp_inputs(
    input_id: torch.Tensor,
    label_id: torch.Tensor,
    position_id: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    seq_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Purpose: Symbolize sequence dimension and annotate tensors for identification.

Operations:

  1. Mark sequence dimension as dynamic using torch._dynamo.decorators.mark_dynamic()
  2. Attach metadata tags so the auto-sharding passes can identify each tensor:
    • input_id.tag = constants.INPUT_ID_KEY
    • label_id.tag = constants.LABEL_ID_KEY
    • position_id.tag = constants.POSITION_ID_KEY (if provided)

Rationale: PyTorch's FX graph tracer requires explicit annotation of data-dependent dimensions. Marking the sequence dimension as dynamic prevents symbolic shape propagation from losing dimension information through reshape/view operations.

Compilation Passes

Pass 1: pass_shard_seq_dim()

Objective: Propagate sharded sequence dimension to all consumers.

Algorithm:

  1. Extract symbolic sequence dimension from input_id shape metadata
  2. Locate the symbolic dimension node in the FX graph
  3. Create a floor-divide node: seq_dim // world_size
  4. Perform a worklist-based graph traversal to find all direct and indirect consumers of the input, label, and position-id nodes
  5. Replace symbolic dimension references with the sharded dimension in consumer nodes

Rationale: Reshapes and views that consume the sequence dimension as an argument do not get updated during propagation of symbolic shapes. This pass explicitly rewires the computation graph to use sharded dimensions, enabling proper shape inference downstream.
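A stdlib-only model of steps 3 through 5, with the use-def graph reduced to a dict (node name to user names); the real pass walks torch.fx nodes, but the worklist logic is the same:

```python
def transitive_consumers(users, roots):
    """Worklist traversal over a use-def graph (node name -> user names),
    collecting every direct and indirect consumer of the root nodes."""
    seen, work = set(), list(roots)
    while work:
        node = work.pop()
        for user in users.get(node, []):
            if user not in seen:
                seen.add(user)
                work.append(user)
    return seen

def reshard_args(args, seq_sym, world_size):
    """Replace references to the full sequence dim with seq // world_size,
    as the floor-divide node created in step 3 would."""
    return [a // world_size if a == seq_sym else a for a in args]

# toy graph: input_ids feeds an embedding, which feeds a view op
users = {"input_ids": ["embed"], "embed": ["view"]}
consumers = transitive_consumers(users, ["input_ids"])
# a view that hard-codes the full sequence length 4096, resharded for P=8
new_view_args = reshard_args([4, 4096, 768], 4096, 8)
```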

Pass 2: pass_shard_input_ids() / pass_shard_label_ids() / pass_shard_position_ids()

Objective: Insert slicing operations after input tensors.

Implementation: Call shard_tensor_node() utility which inserts slice operations. Each rank retains only the portion of the tensor corresponding to its sequence partition and drops the remaining buffer.

Note on attention_mask: Not sharded because it applies to the full sequence length, not the partitioned dimension.
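A minimal model of the slice that shard_tensor_node() inserts, assuming (as the pass does) that the sequence length divides evenly by the world size:

```python
def shard_along_seq(seq, rank, world_size):
    """Keep only this rank's contiguous sequence slice (models the slice
    node inserted by shard_tensor_node(); assumes even divisibility)."""
    assert len(seq) % world_size == 0, "sequence length must divide evenly"
    chunk = len(seq) // world_size
    start = rank * chunk
    return seq[start:start + chunk]

# rank 1 of 4 keeps positions 2..3 of an 8-token sequence
shard = shard_along_seq(list(range(8)), rank=1, world_size=4)
```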

Pass 3: pass_insert_attention_all_to_all()

Objective: Insert all-to-all collectives around attention (Ulysses-style) to avoid graph breaks during compilation.

Algorithm:

  1. Identify all SDPA (Scaled Dot-Product Attention) nodes in the graph
  2. For each SDPA node with inputs Q, K, V: insert an A2A after each of Q, K, V that scatters heads (dim=1) and gathers sequence (dim=2)
  3. Insert an A2A after the attention output O: scatter sequence (dim=2), gather heads (dim=1)

Graph Rewrite Example:

Q [B, N, S/P, H] --A2A(scatter_heads,gather_seq)--> [B, N/P, S, H]
K [B, N, S/P, H] --A2A(scatter_heads,gather_seq)--> [B, N/P, S, H]
V [B, N, S/P, H] --A2A(scatter_heads,gather_seq)--> [B, N/P, S, H]
                     |
                    SDPA
                     |
O [B, N/P, S, H] --A2A(scatter_seq,gather_heads)--> [B, N, S/P, H]

Current support: only torch.nn.functional.scaled_dot_product_attention() is supported. Composite attention patterns require additional pattern-matching logic.
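The shape effect of each inserted A2A can be checked with a small helper (pure shape bookkeeping, not the collective itself):

```python
def a2a_shape(shape, scatter_dim, gather_dim, world_size):
    """Shape effect of an all-to-all that scatters one dim across ranks
    and gathers another: scatter_dim shrinks by P, gather_dim grows by P."""
    out = list(shape)
    assert out[scatter_dim] % world_size == 0
    out[scatter_dim] //= world_size
    out[gather_dim] *= world_size
    return out

# Q/K/V before SDPA: [B, N, S/P, H] -> [B, N/P, S, H]
pre = a2a_shape([2, 8, 4, 16], scatter_dim=1, gather_dim=2, world_size=4)
# O after SDPA: the inverse A2A restores [B, N, S/P, H]
post = a2a_shape(pre, scatter_dim=2, gather_dim=1, world_size=4)
```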

Pass 4: pass_propagate_shapes()

Objective: Compute static shapes for all nodes using fake tensor execution.

Implementation:

  1. Create ShapeEnv for symbolic dimension tracking
  2. Construct FakeTensorMode with the shape environment
  3. Execute FakeTensorProp.propagate() to compute shape metadata
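A self-contained sketch of the same three steps on a toy traced function, using the standard torch.fx and fake-tensor APIs (not AutoSP's actual graph):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

def f(x):
    return (x + 1).relu()

gm = symbolic_trace(f)

# 1-2) a ShapeEnv-backed FakeTensorMode tracks (possibly symbolic) dims
mode = FakeTensorMode(shape_env=ShapeEnv())

# 3) run fake tensors through the graph; fills node.meta["val"] with
#    fake tensors that carry the inferred shapes
FakeTensorProp(gm, mode).propagate(torch.randn(2, 3))

relu_node = next(n for n in gm.graph.nodes if n.op == "call_method")
relu_shape = tuple(relu_node.meta["val"].shape)
```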

Pass 5: pass_canonicalize()

Objective: Finalize graph representation.

Operations:

  1. eliminate_dead_code(): Remove unused operations
  2. lint(): Validate graph structure
  3. recompile(): Regenerate compiled representation
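These are the standard torch.fx calls; on a toy graph with one dead node:

```python
import torch
from torch import fx

def f(x):
    dead = x * 2       # has no users: removed by eliminate_dead_code()
    return x + 1

gm = fx.symbolic_trace(f)
gm.graph.eliminate_dead_code()   # 1) drop ops whose results are never used
gm.graph.lint()                  # 2) validate graph well-formedness
gm.recompile()                   # 3) regenerate Python code for gm.forward
remaining = [n.name for n in gm.graph.nodes]
```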

Execution Order

prepare_autosp_inputs()
    ↓
pass_shard_seq_dim
    ↓
pass_shard_input_ids
    ↓
pass_shard_label_ids
    ↓
pass_shard_position_ids
    ↓
pass_insert_attention_all_to_all
    ↓
pass_propagate_shapes
    ↓
pass_canonicalize
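A trivial driver modeling this fixed ordering; the pass names here are illustrative stand-ins, not the actual callables:

```python
def run_autosp_pipeline(graph, passes):
    """Apply the passes in the fixed order above; a pass may mutate the
    graph in place (returning None) or return a replacement graph."""
    for p in passes:
        out = p(graph)
        graph = graph if out is None else out
    return graph

# stand-in passes that just record their execution order
def make_pass(name):
    def _pass(graph):
        graph.setdefault("trace", []).append(name)
    return _pass

order = ["shard_seq_dim", "shard_input_ids", "shard_label_ids",
         "shard_position_ids", "insert_attention_all_to_all",
         "propagate_shapes", "canonicalize"]
result = run_autosp_pipeline({}, [make_pass(n) for n in order])
```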

Reducing gradients across ranks

AutoSP requires an all-reduce of parameter gradients across sequence-parallel ranks. This is invoked automatically by DeepSpeed's engine here
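Conceptually, each rank computes gradients from only its sequence shard, so the engine must combine them. A stdlib model of the averaging all-reduce (the real call uses torch.distributed; whether gradients are averaged or summed depends on the loss scaling):

```python
def allreduce_avg(per_rank_grads):
    """Model of the post-backward all-reduce: average each parameter's
    gradient across sequence-parallel ranks."""
    world_size = len(per_rank_grads)
    return [sum(vals) / world_size for vals in zip(*per_rank_grads)]

# two ranks, two parameters each
avg = allreduce_avg([[1.0, 2.0], [3.0, 4.0]])
```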

Known Limitations

  1. Attention Pattern Matching: Only torch.nn.functional.scaled_dot_product_attention() is supported. Fused attention implementations require pattern-specific handling.
  2. No Graph Break Requirement: AutoSP fails if the model produces graph breaks, because use-def chains are lost across the resulting graph modules and auto-sharding information cannot be propagated between them.

Example

DeepSpeedExamples PR: deepspeedai/DeepSpeedExamples#999

@neeldani neeldani changed the title [WIP] Merging AutoSP into Deepspeed [WIP] Merging AutoSP into DeepSpeed Feb 19, 2026

tohtana commented Feb 20, 2026

Hi @neeldani,
Thank you for opening this PR! This is truly exciting.

Since this is a large PR, let’s proceed step by step. Here are my suggestions:

  • Code Location: This PR contains a significant amount of client code in bench_dc_ulysses. Could we move that to DeepSpeedExamples instead? Feel free to open a separate PR there for it.
  • Documentation: The README in bench_dc_ulysses appears to be outdated. Could you update it with instructions so we can reproduce the results?
  • API Design: Could you share the current API design? As you mentioned, we should discuss this further. You can either add the details to this PR or start a new Discussion in this repo.

@neeldani (Author) commented:

@tohtana thank you for the feedback:

  1. Moved the scripts to DeepSpeedExamples and opened a new PR: Add AutoSP example DeepSpeedExamples#999
  2. Updated the README.md with the setup instructions
  3. Updated the description of this PR with the API design

Please let me know if you hit any hiccups running AutoSP or have any questions about the design - happy to discuss them on this PR.

@@ -0,0 +1,64 @@
import torch

Collaborator:

(Not about this file) Don't we need __init__.py in custom_ops?

#########################################
# AUTOSP
#########################################
INPUT_ID_KEY = "input_id"

Collaborator:

Can we make these AutoSP specific? like AUTOSP_*

if enable_deepcompile and self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \
and self.zero_optimization_stage() != ZeroStageEnum.weights \
and self.zero_optimization_stage() != ZeroStageEnum.gradients:
and self.zero_optimization_stage() != ZeroStageEnum.gradients \

Collaborator:

Can you elaborate the intention of this change? (The preexisting condition also seems weird, though.)
Assuming zero_optimization_stage returns 0, 1, 2, 3, we never enter this block?
Maybe we need the fallback only when sp is disabled and zero stage is 0?

"DeepCompile with ZeRO stage 3 is not currently supported on PyTorch >= 2.9. "
"Please use ZeRO stage 1 or 2 with DeepCompile, or disable DeepCompile for ZeRO stage 3.")
backend = init_z3(self, backend, compile_config, compile_kwargs, schedule)
elif self.zero_optimization_stage() == ZeroStageEnum.disabled:

Collaborator:

Currently, do we enable AutoSP by these?

  • Set zero stage to 0
  • Enable deepcompile

If so, I think we should make it more explicit. For AutoTP (see example), we do
"tensor_parallel": {
   "autotp_size": args.tp_size,
...
}

For AutoEP proposal,

"expert_parallel": {
   "autoep_size": args.ep_size,
...
}

AutoEP is currently just a proposal, but how about making the config

"sequence_parallel": {
   "autosp_size": args.sp_size,
...
}

You may want to require DeepCompile to be enabled too. As we don't have eager AutoSP now, it might be good to enable DeepCompile automatically when sequence_parallel is enabled.


tohtana commented Feb 25, 2026

Thank you @neeldani for the update! As we don't have a lot of changes in existing code, I don't think we have much risk.
The key discussion is the interface to enable AutoSP; see this comment. I also want to get thoughts from @sfc-gh-truwase @minjiazhang

We should also have clear assertions to terminate early when we hit these limitations: Attention Pattern Matching and No Graph Break Requirement.
