Description
Summary
- When applying low-precision MoE training (fp8 rowwise in this case), there appears to be a bug with the ScaledGroupedMMTensor subclass. Specifically, FSDP2 complains about the all-gather input shapes in the error below:
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 314, in unshard
self._all_gather_result = foreach_all_gather(
~~~~~~~~~~~~~~~~~~^
self.fsdp_params,
^^^^^^^^^^^^^^^^^
...<4 lines>...
self._all_gather_comm,
^^^^^^^^^^^^^^^^^^^^^^
)
^
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 248, in foreach_all_gather
param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params)
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
return func(*args, **kwargs)
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 327, in _get_param_all_gather_inputs
param_all_gather_inputs[i] = fsdp_param.all_gather_inputs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 727, in all_gather_inputs
raise AssertionError(
...<4 lines>...
)
AssertionError: When a parameter is unevenly sharded by FSDP (orig size=torch.Size([1, 8192, 5120]), FSDP world size=8), fsdp_pre_all_gather must return all-gather inputs with the padded sharded size torch.Size([1, 8192, 5120]) but got [torch.Size([0, 8192, 5120])]
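To make the mismatch concrete, here is a minimal sketch (not torchao or FSDP code) of the dim-0 sharding arithmetic implied by the assertion, assuming standard FSDP2 even-chunk sharding with padding: with a dim-0 size of 1 and a world size of 8, every rank's padded shard has dim-0 length 1, but the unpadded local shard is empty on every rank except rank 0, which would explain the reported [0, 8192, 5120] input shape.
```python
# Minimal sketch of the dim-0 sharding arithmetic implied by the assertion above.
# This is an illustration only, not the FSDP2 implementation.
import math

import torch

orig_size = torch.Size([1, 8192, 5120])  # parameter's original (unsharded) size
world_size = 8

# FSDP2 pads dim 0 up to a multiple of world_size, so every rank's *padded*
# shard has the same dim-0 length.
padded_dim0 = math.ceil(orig_size[0] / world_size)               # -> 1
padded_sharded_size = torch.Size([padded_dim0, *orig_size[1:]])  # -> [1, 8192, 5120]

# The *unpadded* local shard is empty on most ranks, because dim 0 only has a
# single row to distribute across 8 ranks.
for rank in range(world_size):
    start = min(rank * padded_dim0, orig_size[0])
    end = min(start + padded_dim0, orig_size[0])
    local_dim0 = end - start  # 1 on rank 0, 0 on ranks 1-7
    print(rank, torch.Size([local_dim0, *orig_size[1:]]), "padded:", padded_sharded_size)

# One reading of the error: fsdp_pre_all_gather returns tensors with the
# unpadded shape [0, 8192, 5120] on a rank whose shard is empty, while FSDP
# expects the padded shape [1, 8192, 5120].
```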
However, I added some debug logging to the tensor subclass, and it shows that the all-gather inputs and outputs both have the correct shapes, so I am not yet clear where the disconnect is.
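For context, a hypothetical sketch of the kind of logging that would produce the lines below (the exact FSDP2 hook signatures differ across PyTorch versions, and this is not the actual ScaledGroupedMMTensor code):
```python
# Hypothetical logging helpers, called from the subclass's fsdp_pre_all_gather /
# fsdp_post_all_gather hooks. Illustration only; not the torchao implementation.
import logging

import torch

logger = logging.getLogger("torchao.prototype.moe_training.tensor")

def log_all_gather_inputs(all_gather_inputs: tuple[torch.Tensor, ...]) -> None:
    # Called just before fsdp_pre_all_gather returns its all-gather inputs.
    for t in all_gather_inputs:
        logger.info(f"fsdp all_gather_inputs.shape = {t.shape}")

def log_post_all_gather(output: torch.Tensor, inner_tensors: tuple[torch.Tensor, ...]) -> None:
    # Called after fsdp_post_all_gather reconstructs the subclass from the
    # all-gathered inner tensors.
    for inner in inner_tensors:
        logger.info(
            f"fsdp_post_all_gather: output.shape={output.shape}, "
            f"inner_tensors.shape={inner.shape}"
        )
```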
Training logs:
[rank0]:[titan] 2025-08-02 17:59:08,752 - root - INFO - Training starts at step 1
[rank0]:[titan] 2025-08-02 17:59:09,148 - torchao.prototype.moe_training.tensor - INFO - fsdp all_gather_inputs.shape = torch.Size([1, 8192, 5120])
[rank0]:[titan] 2025-08-02 17:59:09,148 - torchao.prototype.moe_training.tensor - INFO - fsdp all_gather_inputs.shape = torch.Size([1, 5120, 8192])
[rank0]:[titan] 2025-08-02 17:59:09,148 - torchao.prototype.moe_training.tensor - INFO - fsdp all_gather_inputs.shape = torch.Size([1, 8192, 5120])
[rank0]:[titan] 2025-08-02 17:59:09,385 - torchao.prototype.moe_training.tensor - INFO - fsdp_post_all_gather: output.shape=torch.Size([8, 8192, 5120]), inner_tensors.shape=torch.Size([8, 8192, 5120])
[rank0]:[titan] 2025-08-02 17:59:09,385 - torchao.prototype.moe_training.tensor - INFO - fsdp_post_all_gather: output.shape=torch.Size([8, 5120, 8192]), inner_tensors.shape=torch.Size([8, 5120, 8192])
[rank0]:[titan] 2025-08-02 17:59:09,385 - torchao.prototype.moe_training.tensor - INFO - fsdp_post_all_gather: output.shape=torch.Size([8, 8192, 5120]), inner_tensors.shape=torch.Size([8, 8192, 5120])
Repro command
- From torchtitan root:
NGPU=8 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=50 --model.converters="float8" --float8.recipe_name="rowwise" --float8.moe_fqns_prototype="experts,shared_expert"
Optionally, you could just target "shared_expert" instead of both "experts,shared_expert" to simplify the repro further.
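For example, the simplified variant would presumably just swap the FQN filter in the command above (untested):
NGPU=8 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=50 --model.converters="float8" --float8.recipe_name="rowwise" --float8.moe_fqns_prototype="shared_expert"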