[float8 moe training] fix bug affecting mixed precision training #2451
Fixes pytorch/torchtitan#1332
Summary
Previously, when manually testing the MoE training prototype with Llama4 in torchtitan, I had to add

model = model.to(torch.bfloat16)

in torchtitan's `train.py` to work around an error. I initially thought this was a bug in torchtitan not applying the bf16 mixed precision policy for FSDP2 to MoE layers properly. However, after further debugging I determined the bug was actually in the torchao `ScaledGroupedMMTensor` subclass not setting its dtype properly, the way `WeightWithDynamicFloat8CastTensor` does: for ops that preserve the subclass, dispatch should return a `ScaledGroupedMMTensor` using the dtype pulled while unwrapping the args/kwargs.
Reasoning: as a general rule in PyTorch, the dtype of an op's output matches the dtype of one of its inputs (with the exception of a handful of ops that take an `out_dtype` argument). So we enforce a simple version of this rule to ensure weights are cast to the mixed precision dtype properly under FSDP.
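To illustrate the pattern, here is a minimal, hypothetical sketch (not torchao's actual implementation; `DTypeCarryingTensor` and its internals are illustrative names) of a wrapper tensor subclass whose `__torch_dispatch__` unwraps the inner tensors, runs the op, and rewraps outputs so the wrapper advertises the inner tensor's dtype, which is what lets a mixed-precision cast like `.to(torch.bfloat16)` propagate correctly:

```python
import torch
from torch.utils import _pytree as pytree


class DTypeCarryingTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        # Advertise the inner tensor's dtype/device rather than defaults,
        # so callers (e.g. FSDP2 mixed precision) see the correct dtype.
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape,
            dtype=inner.dtype, device=inner.device,
            requires_grad=inner.requires_grad,
        )

    def __init__(self, inner: torch.Tensor):
        self._inner = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(t):
            # Pull out the plain inner tensor for the actual op.
            return t._inner if isinstance(t, cls) else t

        out = func(*pytree.tree_map(unwrap, args),
                   **pytree.tree_map(unwrap, kwargs))

        def wrap(t):
            # Rewrap plain tensor outputs; the new wrapper picks up the
            # output's dtype (e.g. bf16 after a .to(torch.bfloat16) cast).
            if isinstance(t, torch.Tensor) and not isinstance(t, cls):
                return cls(t)
            return t

        return pytree.tree_map(wrap, out)
```

With this, casting a wrapped fp32 weight via `.to(torch.bfloat16)` returns a wrapper whose `.dtype` is `torch.bfloat16`, consistent with the rule above that an op's output dtype follows its inputs.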
Test plan
```shell
./test/prototype/moe_training/test_fsdp.sh

NGPU=2 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=100 --model.converters="float8" --float8.moe_fqns_prototype="experts"
```