Commit 922085c

Warn that SAC + Compile for MoE models is not yet supported
stack-info: PR: #2052, branch: xmfan/stack/4
1 parent 22e959a · commit 922085c
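In effect, when selective activation checkpointing (SAC) is configured together with compile for one of these MoE models, apply_compile now logs a warning and returns without compiling anything; switching the AC mode to "full" keeps per-block compile enabled. Below is a minimal self-contained sketch of that guard behavior, using a stub ACConfig and a toy model rather than the torchtitan classes:

# Illustrative stand-in for the new guard: selective AC skips compilation
# with a warning instead of failing later. Not the torchtitan implementation.
from dataclasses import dataclass
import logging

import torch
import torch.nn as nn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class ACConfig:
    mode: str = "selective"  # e.g. "none" | "selective" | "full"


def apply_compile(model: nn.Module, ac_config: ACConfig) -> None:
    if ac_config.mode == "selective":
        logger.warning(
            "Selective Activation Checkpointing is not yet supported for MoE models, "
            "please use Full Activation Checkpointing instead. Turning off compile."
        )
        return
    # Compile each child module; the real code compiles each TransformerBlock.
    for name, block in model.named_children():
        model.register_module(name, torch.compile(block))


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
apply_compile(model, ACConfig(mode="selective"))  # warns, model stays eager
apply_compile(model, ACConfig(mode="full"))       # compiles each child module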

File tree: 3 files changed, +16 −5 lines

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 1 addition & 1 deletion
@@ -118,7 +118,7 @@ def parallelize_deepseekv3(
         )
 
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 14 additions & 3 deletions
@@ -21,7 +21,10 @@
     SequenceParallel,
 )
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.config.job_config import Compile as CompileConfig
+from torchtitan.config.job_config import (
+    ActivationCheckpoint as ACConfig,
+    Compile as CompileConfig,
+)
 from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac

@@ -129,7 +132,7 @@ def parallelize_llama(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:

@@ -506,11 +509,19 @@ def apply_moe_ep_tp(
         )
 
 
-def apply_compile(model: nn.Module, compile_config: CompileConfig):
+def apply_compile(model: nn.Module, compile_config: CompileConfig, ac_config: ACConfig):
     """
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
     """
+
+    if ac_config.mode == "selective":
+        logger.warning(
+            "Selective Activation Checkpointing is not yet supported for MoE models, "
+            "please use Full Activation Checkpointing instead. Turning off compile."
+        )
+        return
+
     # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE
     # but it is experimental.
     torch._dynamo.config.capture_scalar_outputs = True
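For context on what the guard short-circuits: as the docstring above states, apply_compile compiles each TransformerBlock separately (efficient because the blocks repeat one structure) and sets the experimental capture_scalar_outputs flag for token-choice MoE. A rough sketch of that per-block pattern follows, using a toy Block module that is only a stand-in for the real TransformerBlock:

# Per-block compilation sketch; Block stands in for the repeated
# TransformerBlock structure that makes compiling block-by-block cheap.
import torch
import torch.nn as nn

# Experimental flag referenced in the diff: avoids graph breaks on
# dynamic shapes in token-choice MoE routing.
torch._dynamo.config.capture_scalar_outputs = True


class Block(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(x)


blocks = nn.ModuleList([Block() for _ in range(4)])
# Compile each block in place; the identical structure lets compiled code be
# reused across blocks instead of compiling the whole model as one graph.
for i, block in enumerate(blocks):
    blocks[i] = torch.compile(block)

x = torch.randn(2, 16)
for block in blocks:
    x = block(x)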

torchtitan/models/qwen3/infra/parallelize.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def parallelize_qwen3(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     if parallel_dims.fsdp_enabled:
         # apply FSDP or HSDP, potentially with Context Parallel
