[moe training] fsdp2 bug for llama4 shared experts where num_experts=1 #2673

@danielvegamyhre

Description

Summary

  • When applying low-precision MoE training (fp8 rowwise in this case) to the Llama 4 shared expert, which has num_experts=1, there appears to be a bug involving the ScaledGroupedMMTensor subclass. Specifically, fsdp2 complains about the shape of the all-gather inputs:
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 314, in unshard
      self._all_gather_result = foreach_all_gather(
                                ~~~~~~~~~~~~~~~~~~^
          self.fsdp_params,
          ^^^^^^^^^^^^^^^^^
      ...<4 lines>...
          self._all_gather_comm,
          ^^^^^^^^^^^^^^^^^^^^^^
      )
      ^
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
      return func(*args, **kwargs)
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 248, in foreach_all_gather
      param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params)
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
      return func(*args, **kwargs)
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py", line 327, in _get_param_all_gather_inputs
      param_all_gather_inputs[i] = fsdp_param.all_gather_inputs
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 727, in all_gather_inputs
      raise AssertionError(
      ...<4 lines>...
      )
  AssertionError: When a parameter is unevenly sharded by FSDP (orig size=torch.Size([1, 8192, 5120]), FSDP world size=8), fsdp_pre_all_gather must return all-gather inputs with the padded sharded size torch.Size([1, 8192, 5120]) but got [torch.Size([0, 8192, 5120])]

However, I added some debug logging to the tensor subclass, and it shows that the all-gather inputs and outputs both have the correct shapes, so I am not yet clear on where the disconnect is.
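
For context on the shapes in that assertion, here is a small standalone sketch in plain torch (not torchao/torchtitan code; the padding arithmetic is my approximation of FSDP2's dim-0 sharding) showing how a [1, 8192, 5120] parameter shards across 8 ranks when num_experts=1: rank 0 keeps the only chunk, ranks 1-7 end up with empty [0, 8192, 5120] local shards, yet the padded sharded size that fsdp_pre_all_gather is expected to return is [1, 8192, 5120] on every rank.

    import torch

    world_size = 8
    orig = torch.empty(1, 8192, 5120, device="meta")  # shared expert weight, num_experts=1

    # FSDP2 pads dim 0 up to a multiple of world_size before sharding, so the
    # padded sharded size is the same on every rank: [1, 8192, 5120] here.
    padded_shard_dim0 = -(-orig.shape[0] // world_size)  # ceil(1 / 8) == 1
    padded_shard_size = (padded_shard_dim0, *orig.shape[1:])

    # Chunking the unpadded dim 0 of size 1 yields a single chunk, so only
    # rank 0 holds data; the remaining ranks hold empty [0, 8192, 5120] shards.
    shards = list(torch.chunk(orig, world_size, dim=0))
    shards += [orig.new_empty(0, *orig.shape[1:])] * (world_size - len(shards))

    for rank, shard in enumerate(shards):
        print(f"rank {rank}: local shard {tuple(shard.shape)}, expected padded shard {padded_shard_size}")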

Training logs:

[rank0]:[titan] 2025-08-02 17:59:08,752 - root - INFO - Training starts at step 1
[rank0]:[titan] 2025-08-02 17:59:09,148 - torchao.prototype.moe_training.tensor - INFO - fsdp all_gather_inputs.shape = torch.Size([1, 8192, 5120])
[rank0]:[titan] 2025-08-02 17:59:09,148 - torchao.prototype.moe_training.tensor - INFO - fsdp all_gather_inputs.shape = torch.Size([1, 5120, 8192])
[rank0]:[titan] 2025-08-02 17:59:09,148 - torchao.prototype.moe_training.tensor - INFO - fsdp all_gather_inputs.shape = torch.Size([1, 8192, 5120])
[rank0]:[titan] 2025-08-02 17:59:09,385 - torchao.prototype.moe_training.tensor - INFO - fsdp_post_all_gather: output.shape=torch.Size([8, 8192, 5120]), inner_tensors.shape=torch.Size([8, 8192, 5120])
[rank0]:[titan] 2025-08-02 17:59:09,385 - torchao.prototype.moe_training.tensor - INFO - fsdp_post_all_gather: output.shape=torch.Size([8, 5120, 8192]), inner_tensors.shape=torch.Size([8, 5120, 8192])
[rank0]:[titan] 2025-08-02 17:59:09,385 - torchao.prototype.moe_training.tensor - INFO - fsdp_post_all_gather: output.shape=torch.Size([8, 8192, 5120]), inner_tensors.shape=torch.Size([8, 8192, 5120])
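
(Note these logs are all from rank 0; with uneven dim-0 sharding, the empty shards live on the other ranks.) A minimal sketch of rank-tagged logging that could be dropped into the subclass debug statements; the helper name and call site are hypothetical, not the actual torchao code:

    import logging

    import torch
    import torch.distributed as dist

    logger = logging.getLogger("torchao.prototype.moe_training.tensor")

    def log_shape(label: str, t: torch.Tensor) -> None:
        # Hypothetical helper: tag the subclass debug messages with the rank, so the
        # shapes returned on ranks 1..7 (which hold the empty shards under uneven
        # dim-0 sharding) show up alongside rank 0's.
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
        logger.info("rank %d: %s = %s", rank, label, tuple(t.shape))

For example, log_shape("fsdp all_gather_inputs.shape", ...) at the point where the current all-gather input logging happens (exact variable names depend on the subclass code).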

Repro command

  • From torchtitan root: NGPU=8 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=50 --model.converters="float8" --float8.recipe_name="rowwise" --float8.moe_fqns_prototype="experts,shared_expert"

Optionally, you can target only "shared_expert" instead of "experts,shared_expert" to simplify the repro further; the resulting command is spelled out below.
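
  • Simplified repro (same command, targeting only the shared expert): NGPU=8 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=50 --model.converters="float8" --float8.recipe_name="rowwise" --float8.moe_fqns_prototype="shared_expert"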
