     SequenceParallel,
 )
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
-from torchtitan.config.job_config import Compile as CompileConfig
+from torchtitan.config.job_config import (
+    ActivationCheckpoint as ACConfig,
+    Compile as CompileConfig,
+)
 from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
 
@@ -129,7 +132,7 @@ def parallelize_llama(
 
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
-        apply_compile(model, job_config.compile)
+        apply_compile(model, job_config.compile, job_config.activation_checkpoint)
 
     dp_mesh: DeviceMesh | None = None
     if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled:
@@ -506,11 +509,19 @@ def apply_moe_ep_tp(
     )
 
 
-def apply_compile(model: nn.Module, compile_config: CompileConfig):
+def apply_compile(model: nn.Module, compile_config: CompileConfig, ac_config: ACConfig):
     """
     Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
     repeated structure. Alternatively one can compile the whole model (after applying DP).
    """
+
+    if ac_config.mode == "selective":
+        logger.warning(
+            "Compile + Selective Activation Checkpointing is not yet supported for MoE models, "
+            "please use Full Activation Checkpointing instead. Turning off Compile."
+        )
+        return
+
     # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE
     # but it is experimental.
     torch._dynamo.config.capture_scalar_outputs = True
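For context, the body of apply_compile below this hunk is not shown, but the docstring describes the pattern: each TransformerBlock is presumably wrapped with torch.compile individually so the repeated block structure compiles once and is reused, rather than compiling the whole model as a single graph. The following is a minimal sketch of that per-block pattern, not code from this PR; the compile_each_block name, the model.layers container, and the backend argument are illustrative assumptions.

import torch
import torch.nn as nn


def compile_each_block(model: nn.Module, backend: str = "inductor") -> None:
    # Sketch only: wrap every block under `model.layers` with torch.compile so
    # identical blocks reuse one compiled graph instead of compiling the full model.
    for name, block in model.layers.named_children():
        compiled_block = torch.compile(block, backend=backend)
        # Swap the eager block for its compiled wrapper in place.
        model.layers.register_module(name, compiled_block)


# Toy usage with a stand-in "transformer": two blocks held in a ModuleDict.
toy = nn.Module()
toy.layers = nn.ModuleDict({"0": nn.Linear(8, 8), "1": nn.Linear(8, 8)})
compile_each_block(toy)
print(toy.layers["0"](torch.randn(2, 8)).shape)  # torch.Size([2, 8])

Note that the new guard returns early instead of raising, so existing job configs with selective activation checkpointing keep running; per the warning, compile is simply turned off for them.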