[WIP] Merging AutoSP into DeepSpeed #7860
neeldani wants to merge 6 commits into deepspeedai:master from
Conversation
Hi @neeldani, since this is a large PR, let's proceed step by step. Here are my suggestions:
@tohtana thank you for the feedback:
Please let me know if you run into any hiccups running AutoSP or have any questions about the design - happy to discuss them on this PR.
@@ -0,0 +1,64 @@
import torch
(Not about this file) Don't we need __init__.py in custom_ops?
#########################################
# AUTOSP
#########################################
INPUT_ID_KEY = "input_id"
Can we make these AutoSP-specific, like AUTOSP_*?
if enable_deepcompile and self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \
and self.zero_optimization_stage() != ZeroStageEnum.weights \
and self.zero_optimization_stage() != ZeroStageEnum.gradients:
and self.zero_optimization_stage() != ZeroStageEnum.gradients \
Can you elaborate the intention of this change? (The preexisting condition also seems weird, though.)
Assuming zero_optimization_stage returns 0, 1, 2, 3, we never enter this block?
Maybe we need the fallback only when SP is disabled and ZeRO stage is 0?
| "DeepCompile with ZeRO stage 3 is not currently supported on PyTorch >= 2.9. " | ||
| "Please use ZeRO stage 1 or 2 with DeepCompile, or disable DeepCompile for ZeRO stage 3.") | ||
| backend = init_z3(self, backend, compile_config, compile_kwargs, schedule) | ||
| elif self.zero_optimization_stage() == ZeroStageEnum.disabled: |
Currently, do we enable AutoSP with these?
- Set ZeRO stage to 0
- Enable DeepCompile
If so, I think we should make it more explicit.
For AutoTP (see example), we do

    "tensor_parallel": {
        "autotp_size": args.tp_size,
        ...
    }

For the AutoEP proposal,

    "expert_parallel": {
        "autoep_size": args.ep_size,
        ...
    }

AutoEP is currently just a proposal, but how about making the config

    "sequence_parallel": {
        "autosp_size": args.sp_size,
        ...
    }
You may want to require DeepCompile to be enabled too. As we don't have eager AutoSP now, it might be good to enable DeepCompile automatically when sequence_parallel is enabled.
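To make that concrete, here is a rough sketch of what a combined config could look like under this proposal. The `sequence_parallel` block mirrors the AutoTP/AutoEP style above; the exact keys and the automatic-DeepCompile behavior are assumptions for discussion, not an agreed design.

```python
# Hypothetical DeepSpeed config sketch for the proposed "sequence_parallel" section.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {"stage": 0},   # AutoSP currently targets ZeRO stage 0
    "compile": {"deepcompile": True},    # could be switched on automatically when SP is enabled
    "sequence_parallel": {
        "autosp_size": 4,                # proposed knob, mirroring autotp_size
    },
}
```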
Thank you @neeldani for the update! As we don't have a lot of changes in existing code, I don't think we have much risk. We should also have clear assertions to terminate early when we hit these limitations: Attention Pattern Matching and No Graph Break Requirement.
AutoSP: Unlocking Long-Context LLM Training Via Compiler-Based Sequence Parallelism
Overview
AutoSP is a compiler optimization pass that shards inputs along the sequence dimension and enables Ulysses-style sequence parallelism while preventing graph breaks during `torch.compile()`. All the passes operate on the Torch IR of the forward graph.

API Design
User-Facing Entry Point: `prepare_autosp_inputs()`

Users must explicitly call this function to prepare inputs for AutoSP compilation.
Purpose: Symbolize the sequence dimension and annotate tensors for identification.
Operations:

- `torch._dynamo.decorators.mark_dynamic()` on the sequence dimension
- `input_id.tag = constants.INPUT_ID_KEY`
- `label_id.tag = constants.LABEL_ID_KEY`
- `position_id.tag = constants.POSITION_ID_KEY` (if provided)

Rationale: PyTorch's FX graph tracer requires explicit annotation of data-dependent dimensions. Marking the sequence dimension as dynamic prevents symbolic shape propagation from losing dimension information through reshape/view operations.
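A minimal sketch of what the call site might look like, assuming `prepare_autosp_inputs` takes the three tensors plus the sequence dimension; the import path and signature here are assumptions, not the confirmed API in this PR.

```python
import torch
# Import path is an assumption for illustration; the function is introduced by this PR.
from deepspeed.compile import prepare_autosp_inputs

input_ids = torch.randint(0, 32000, (1, 8192))
labels = input_ids.clone()
position_ids = torch.arange(8192).unsqueeze(0)

# Marks the sequence dimension as dynamic and tags each tensor so the AutoSP
# compilation passes can identify it in the traced forward graph.
input_ids, labels, position_ids = prepare_autosp_inputs(
    input_ids, labels, position_ids, seq_dim=1
)

# The prepared tensors are then fed to the torch.compile()-wrapped model as usual.
```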
Compilation Passes
Pass 1: `pass_shard_seq_dim()`

Objective: Propagate sharded sequence dimension to all consumers.
Algorithm: Starting from `input_id`, propagate updated shape metadata to all consumers, replacing the sequence dimension with `seq_dim / world_size`.

Rationale: Reshapes and views that consume the sequence dimension as an argument do not get updated during propagation of symbolic shapes. This pass explicitly rewires the computation graph to use sharded dimensions, enabling proper shape inference downstream.
Pass 2: `pass_shard_input_ids()` / `pass_shard_label_ids()` / `pass_shard_position_ids()`

Objective: Insert slicing operations after input tensors.
Implementation: Call the `shard_tensor_node()` utility, which inserts slice operations. Each rank retains only the portion of the tensor corresponding to its sequence partition and drops the remaining buffer.

Note on `attention_mask`: Not sharded because it applies to the full sequence length, not the partitioned dimension.
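As a rough illustration of the slice each rank keeps after this pass (an eager-mode sketch with assumed names; `shard_tensor_node()` itself operates on FX graph nodes rather than eager tensors):

```python
import torch

def shard_along_seq(t: torch.Tensor, rank: int, world_size: int, seq_dim: int = 1):
    """Eager-mode equivalent of the per-rank slice inserted by Pass 2 (sketch)."""
    seq_len = t.size(seq_dim)
    assert seq_len % world_size == 0, "sequence length must divide evenly across ranks"
    shard = seq_len // world_size
    return t.narrow(seq_dim, rank * shard, shard).contiguous()

# Example: with 4 ranks, rank 2 keeps tokens [4096, 6144) of an 8192-token sequence.
x = torch.randint(0, 32000, (1, 8192))
x_local = shard_along_seq(x, rank=2, world_size=4)
print(x_local.shape)  # torch.Size([1, 2048])
```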
Pass 3: `pass_insert_attention_all_to_all()`

Objective: Insert all-to-all collectives around attention (Ulysses-style) to avoid graph breaks during compilation.
Algorithm:
Graph Rewrite Example:
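As a sketch of the rewrite, here is an eager-mode equivalent of what the modified graph computes, following the usual Ulysses pattern: scatter heads and gather the full sequence before SDPA, then reverse the exchange afterwards. The tensor layouts and helper names are assumptions, not the actual graph nodes inserted by the pass.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def _seq_to_head(x: torch.Tensor, group) -> torch.Tensor:
    # [b, h, s/P, d] -> [b, h/P, s, d]: gather the full sequence, scatter heads.
    P = dist.get_world_size(group)
    b, h, s_l, d = x.shape
    x = x.reshape(b, P, h // P, s_l, d).permute(1, 3, 0, 2, 4).contiguous()  # [P, s/P, b, h/P, d]
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)
    return out.reshape(P * s_l, b, h // P, d).permute(1, 2, 0, 3)            # [b, h/P, s, d]

def _head_to_seq(x: torch.Tensor, group) -> torch.Tensor:
    # [b, h/P, s, d] -> [b, h, s/P, d]: gather heads back, re-scatter the sequence.
    P = dist.get_world_size(group)
    b, h_l, s, d = x.shape
    x = x.reshape(b, h_l, P, s // P, d).permute(2, 1, 0, 3, 4).contiguous()  # [P, h/P, b, s/P, d]
    out = torch.empty_like(x)
    dist.all_to_all_single(out, x, group=group)
    return out.reshape(P * h_l, b, s // P, d).permute(1, 0, 2, 3)            # [b, h, s/P, d]

def ulysses_sdpa(q, k, v, group):
    # Inputs are sequence-sharded [b, h, s/P, d]; SDPA runs over the full sequence
    # with a local subset of heads, then the output returns to the sharded layout.
    q, k, v = (_seq_to_head(t, group) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)
    return _head_to_seq(out, group)
```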
Current support: Only `torch.nn.functional.scaled_dot_product_attention()` is supported. Composite attention patterns require additional pattern matching logic.
Pass 4: `pass_propagate_shapes()`

Objective: Compute static shapes for all nodes using fake tensor execution.
Implementation:

- `ShapeEnv` for symbolic dimension tracking
- `FakeTensorMode` with the shape environment
- `FakeTensorProp.propagate()` to compute shape metadata
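A minimal sketch of that sequence, assuming the pass receives the traced `GraphModule` and example inputs (the PyTorch APIs are standard; the surrounding function is an assumption):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

def propagate_shapes(gm: torch.fx.GraphModule, example_inputs):
    # Track symbolic dimensions (e.g. the sharded sequence length) in a single ShapeEnv.
    shape_env = ShapeEnv()
    fake_mode = FakeTensorMode(shape_env=shape_env)
    # Re-execute the graph with fake tensors to populate per-node shape metadata.
    FakeTensorProp(gm, mode=fake_mode).propagate(*example_inputs)
    return gm
```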
Pass 5: `pass_canonicalize()`

Objective: Finalize graph representation.
Operations:

- `eliminate_dead_code()`: Remove unused operations
- `lint()`: Validate graph structure
- `recompile()`: Regenerate the compiled representation

Execution Order

The passes are applied in the order listed above (Pass 1 through Pass 5) on the forward graph.
Reducing gradients across ranks
AutoSP requires an all-reduce to reduce the gradients across ranks. This is automatically called by DeepSpeed's engine here
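For intuition, an eager-mode sketch of the equivalent reduction; the actual call happens inside the engine, and the group handle and the choice of averaging here are assumptions.

```python
import torch
import torch.distributed as dist

def allreduce_sp_grads(model: torch.nn.Module, sp_group=None):
    # Each sequence-parallel rank computes gradients from its own sequence shard,
    # so the shards' contributions are combined with an all-reduce after backward.
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, group=sp_group)
```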
Known Limitations
Only `torch.nn.functional.scaled_dot_product_attention()` is supported. Fused attention implementations require pattern-specific handling.

Example
DeepSpeedExamples PR: deepspeedai/DeepSpeedExamples#999