Skip to content

[float8 moe training] fix bug affecting mixed precision training #2451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 27, 2025

Conversation

danielvegamyhre
Copy link
Contributor

@danielvegamyhre danielvegamyhre commented Jun 26, 2025

Fixes pytorch/torchtitan#1332

Summary

Previously when testing the MoE training prototype manually with Llama4 in torchtitan, I had to add model = model.to(torch.bfloat16) in torchtitan's train.py to workaround an error resulting from what I thought was a bug in torchtitan not applying bf16 mixed precision policy for FDSP2 to MoE layers properly:

    File "/home/danvm/torchtitan/torchtitan/experiments/llama4/model/moe.py", line 83, in forward
      x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16
  AssertionError: torch._grouped_mm only supports bf16 dtypes

However, after debugging further I determined the bug was actually in the torchao ScaledGroupedMMTensor subclass not setting the dtype properly, the way WeightWithDynamicFloat8CastTensor does, where for ops that preserve the subclass, dispatch returns a ScaledGroupedMMTensor using dtype pulled while unwrapping the args/kwargs.

Reasoning: as a general rule for pytorch, the dtype of the output will be the same as one of its inputs (with the exception of a handful of ops with the 'out_dtype' arg). So we are enforcing a simple version of this to ensure weights are cast to the mixed precision dtype properly in FSDP.

Other

  • Add some logging
  • Add helper script for running FSDP tests which require torchrun w/ some args

Test plan

  • torchao tests: ./test/prototype/moe_training/test_fsdp.sh
  • torchtitan manual testing: NGPU=2 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=100 --model.converters="float8" --float8.moe_fqns_prototype="experts"

@danielvegamyhre danielvegamyhre added float8 topic: not user facing Use this tag if you don't want this PR to show up in release notes labels Jun 26, 2025
Copy link

pytorch-bot bot commented Jun 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2451

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

⏳ No Failures, 6 Pending

As of commit 4675cda with merge base 994a4ba (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 26, 2025
@danielvegamyhre
Copy link
Contributor Author

cc @drisspg @vkuzo for review

@danielvegamyhre danielvegamyhre requested review from vkuzo and drisspg June 26, 2025 23:52
@danielvegamyhre danielvegamyhre changed the title [float8 moe training] fix float8 moe training dtype bug [float8 moe training] fix float8 moe training dtype bug for mixed precision training Jun 26, 2025
@danielvegamyhre danielvegamyhre changed the title [float8 moe training] fix float8 moe training dtype bug for mixed precision training [float8 moe training] fix bug affecting mixed precision training Jun 26, 2025
@danielvegamyhre danielvegamyhre merged commit ac14d92 into main Jun 27, 2025
19 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. float8 topic: not user facing Use this tag if you don't want this PR to show up in release notes
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Llama4 training does not automatically use bfloat16 when FSDP2 is enabled
3 participants