
[roadmap/tracker] Low precision training for MoEs #2147

Open
@danielvegamyhre

Description


Creating this issue as a roadmap/tracker for enabling float8 training for MoEs with token-choice routing. It covers both core requirements and ideas for additional performance optimizations.

Compute

Communication

I looked at traces and validated that "all to all dispatch -> grouped gemm -> all to all combine" are sequentially dependent, so in theory faster, lower-precision comms should improve performance. There is some overlap with the shared expert computation, but it is not 100%, so there is room for optimization. This will be especially important if/when the all-to-all spans multiple nodes, where inter-node network bandwidth is lower than intra-node NVLink bandwidth.
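To make the sequential dependency concrete, here is a minimal single-rank sketch of the data flow above. All names are hypothetical and the "comms" are plain Python regrouping; real code would use torch.distributed all_to_all and a real grouped GEMM. The point is only that each stage consumes the previous stage's output, so none of them can be fully hidden behind the others.

```python
def a2a_dispatch(tokens, expert_ids, n_experts):
    # Group tokens by destination expert (stand-in for the network dispatch).
    groups = [[] for _ in range(n_experts)]
    for t, e in zip(tokens, expert_ids):
        groups[e].append(t)
    return groups

def grouped_gemm(groups, expert_weights):
    # Each expert applies its own weight to its token group (toy scalar "gemm").
    return [[t * w for t in g] for g, w in zip(groups, expert_weights)]

def a2a_combine(groups, expert_ids, n_tokens):
    # Scatter expert outputs back to original token order (stand-in for combine).
    out = [None] * n_tokens
    iters = [iter(g) for g in groups]
    for i, e in enumerate(expert_ids):
        out[i] = next(iters[e])
    return out

tokens = [1.0, 2.0, 3.0]
expert_ids = [0, 1, 0]
groups = a2a_dispatch(tokens, expert_ids, 2)   # [[1.0, 3.0], [2.0]]
outs = grouped_gemm(groups, [10.0, 100.0])     # [[10.0, 30.0], [200.0]]
result = a2a_combine(outs, expert_ids, 3)      # [10.0, 200.0, 30.0]
```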

This is also inspired by the DeepSeekV3 paper where, if I understand correctly, they do the a2a dispatch in fp8 but keep the a2a combine in bf16, as they found the combine step to be more sensitive to low precision during training.
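For reference, fp8 dispatch implies quantizing activations before the comm and carrying scales alongside the payload so the receiver can dequantize. A minimal tensorwise-scaling sketch (names hypothetical; a real kernel would cast to an actual float8 dtype, which adds rounding error this toy version omits):

```python
FP8_E4M3_MAX = 448.0  # largest normal value representable in float8 e4m3

def quantize_tensorwise(x):
    # One scale for the whole tensor: map the max magnitude onto the fp8 range.
    amax = max(abs(v) for v in x) or 1.0
    scale = FP8_E4M3_MAX / amax
    q = [v * scale for v in x]  # stand-in for the cast to float8
    return q, scale

def dequantize(q, scale):
    return [v / scale for v in q]

x = [0.5, -2.0, 1.0]
q, s = quantize_tensorwise(x)  # send q (as fp8) plus one scalar scale over a2a
out = dequantize(q, s)
```

Tensorwise scaling only needs one extra scalar per tensor on the wire, which is why it is listed as the easier case below; rowwise scaling needs a scale per token row, and those scales must travel (and later be permuted) with their rows.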

  • Add on-device all_to_all_v kernels compatible with:
    • float8 tensors with tensorwise scales (easiest).
    • float8 tensors with rowwise scales (harder).
  • When permuting token groups into the same order as the experts (prior to the scaled grouped mm), reorder the scales accordingly.
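The scale-reordering requirement can be sketched as follows: whatever permutation puts tokens into contiguous per-expert groups must be applied to the rowwise scales too, so each scale stays paired with its token. All names here are illustrative, not the actual torchao API.

```python
def permute_tokens_and_scales(tokens, scales, expert_ids):
    """Permute tokens into contiguous per-expert groups (as the grouped mm
    expects) and apply the same permutation to the per-row float8 scales."""
    # Stable sort by expert id preserves token order within each expert group.
    order = sorted(range(len(tokens)), key=lambda i: expert_ids[i])
    permuted_tokens = [tokens[i] for i in order]
    permuted_scales = [scales[i] for i in order]  # scale i travels with token i
    return permuted_tokens, permuted_scales, order

tokens = ["t0", "t1", "t2", "t3"]
scales = [0.5, 0.25, 1.0, 2.0]
expert_ids = [1, 0, 1, 0]
toks, scs, order = permute_tokens_and_scales(tokens, scales, expert_ids)
# toks == ["t1", "t3", "t0", "t2"], scs == [0.25, 2.0, 0.5, 1.0]
```

The `order` index is also what the backward pass (or the a2a combine) needs in order to invert the permutation.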

Torchao UX

  • JaggedFloat8Tensor (name TBD) with an op override for torch.ops.aten._grouped_mm => runs a differentiable scaled grouped mm
  • One-line conversion API, either integrated into convert_to_float8_training or a standalone one. TBD. Swaps the nn.Parameter data tensors of the expert weights with JaggedFloat8Tensors.

Compile support

  • Differentiable _scaled_grouped_mm can compile with fullgraph=True
  • E2E compilation of each TransformerBlock in torchtitan after MoE conversion via the tensor subclass approach

Distributed support

  • Composability with FSDP2 (will likely need something like this for the new tensor subclass)
  • Composability with tp (will likely need something like these sharding primitives for the new tensor subclass)
  • Composability with tp2ep
  • Composability with dp2ep
