AC + compile interaction bug with MXFP8 MoE Training #1971

@danielvegamyhre

Description

Bug description

This is a new error reported by @mreso, and I've reproduced it. It does NOT happen with AC=none (activation checkpointing disabled).

Error

      out = out.scatter_add(
          dim=0, index=token_indices_experts_sorted, src=routed_output
      )
    File "/home/danvm/.conda/envs/torch/lib/python3.13/site-packages/torch/utils/checkpoint.py", line 1090, in pack_hook
      raise CheckpointError(
      ...<3 lines>...
      )
  torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: trying to save more tensors during recomputation than during the original forward pass.
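
For context on what this error means, here is a minimal, self-contained sketch (deliberately unrelated to the torchtitan/MoE code path) of how non-reentrant activation checkpointing raises this exact `CheckpointError`: the checkpointed function saves more tensors during the backward-time recomputation than it did during the original forward. The `unstable_fn`/`call_count` names are purely illustrative, and disabling early stopping via `set_checkpoint_early_stop(False)` is only there to surface the mismatch in plain eager mode; it is not a claim about what compile + AC is doing internally.

```python
import torch
from torch.utils.checkpoint import CheckpointError, checkpoint, set_checkpoint_early_stop

call_count = 0

def unstable_fn(x):
    # Runs once in the original forward and a second time when backward
    # recomputes the checkpointed region. The second run takes a longer path,
    # so recomputation tries to save more activations than the forward did.
    global call_count
    call_count += 1
    y = x.sin()
    if call_count > 1:
        y = y * x.cos()  # extra autograd saves happen only on the recompute pass
    return y

x = torch.randn(4, requires_grad=True)

# Early stopping normally halts recomputation once the expected tensors have
# been restored; disable it so the forward/recompute mismatch surfaces as the
# same error seen in the traceback above.
with set_checkpoint_early_stop(False):
    out = checkpoint(unstable_fn, x, use_reentrant=False)
    try:
        out.sum().backward()
    except CheckpointError as e:
        print(e)  # "... trying to save more tensors during recomputation ..."
```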
  

Repro command

Change the number of layers to 2 here (to avoid OOM):

Then run the command:

NGPU=4 CONFIG_FILE=/home/danvm/torchtitan/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml ./run_train.sh \
--metrics.log_freq=10 \
--training.steps=100  \
--parallelism.data_parallel_shard_degree=4 \
--parallelism.expert_parallel_degree=4 \
--parallelism.tensor_parallel_degree=1 \
--parallelism.expert_tensor_parallel_degree=1 \
--training.seq_len=8192 \
--training.local_batch_size=4 \
--model.print_after_conversion \
--model.converters="quantize.grouped_mm.mx" \
--quantize.grouped_mm.mx.fqns="experts" \
--compile.enable \
--activation_checkpoint.mode="full" 

cc @soulitzer @xmfan

Versions

  • torch latest nightly
  • torchao latest main
  • torchtitan latest main
