Bug description
This is a new error reported by @mreso, and I've reproduced it. It does NOT happen with AC=none (i.e., with activation checkpointing disabled).
Error
out = out.scatter_add(
dim=0, index=token_indices_experts_sorted, src=routed_output
)
File "/home/danvm/.conda/envs/torch/lib/python3.13/site-packages/torch/utils/checkpoint.py", line 1090, in pack_hook
raise CheckpointError(
...<3 lines>...
)
torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: trying to save more tensors during recomputation than during the original forward pass.
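For context, here is a minimal sketch (my construction, not the torchtitan code path) of how non-reentrant checkpointing raises this exact error: the recomputed forward must save the same tensors the original forward saved, and if recomputation takes a path that saves extra ones (forced here with a call counter; in the real run presumably compile + quantization changing the traced path), the pack hook raises:

```python
import torch
import torch.utils.checkpoint as cp

state = {"calls": 0}

def f(x):
    state["calls"] += 1
    if state["calls"] == 1:
        return x.sin()        # original forward: saves x for backward
    return x.sin() * x.cos()  # recompute: cos/mul save additional tensors

x = torch.randn(4, requires_grad=True)
# Early stopping normally halts recomputation once all originally saved
# tensors have been replayed; disable it so the mismatch is detected.
with cp.set_checkpoint_early_stop(False):
    out = cp.checkpoint(f, x, use_reentrant=False)
out.sum().backward()
# torch.utils.checkpoint.CheckpointError: trying to save more tensors
# during recomputation than during the original forward pass.
```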
Repro command
Change the number of layers to 2 here (to avoid OOM):
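(The "here" link didn't survive; as a stand-in, a hypothetical sketch of the kind of change meant. The class and field names below are my invention, not torchtitan's actual code.)

```python
from dataclasses import dataclass, replace

# Hypothetical stand-in for the model-args definition the link pointed
# at; torchtitan's real class and field names may differ.
@dataclass
class TransformerModelArgs:
    dim: int = 5120
    n_layers: int = 48
    n_heads: int = 40

configs = {"17bx16e": TransformerModelArgs()}
# Shrink to 2 layers so the repro fits on 4 GPUs without OOM.
configs["17bx16e"] = replace(configs["17bx16e"], n_layers=2)
assert configs["17bx16e"].n_layers == 2
```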
Then run the command:
NGPU=4 CONFIG_FILE=/home/danvm/torchtitan/torchtitan/models/llama4/train_configs/llama4_17bx16e.toml ./run_train.sh \
--metrics.log_freq=10 \
--training.steps=100 \
--parallelism.data_parallel_shard_degree=4 \
--parallelism.expert_parallel_degree=4 \
--parallelism.tensor_parallel_degree=1 \
--parallelism.expert_tensor_parallel_degree=1 \
--training.seq_len=8192 \
--training.local_batch_size=4 \
--model.print_after_conversion \
--model.converters="quantize.grouped_mm.mx" \
--quantize.grouped_mm.mx.fqns="experts" \
--compile.enable \
--activation_checkpoint.mode="full"

cc @soulitzer @xmfan
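For comparison (per the note above), the same command with --activation_checkpoint.mode="none" should run without hitting this error.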
Versions
- torch latest nightly
- torchao latest main
- torchtitan latest main