
[roadmap/tracker] Low precision MoE training #2147

@danielvegamyhre

Creating this issue as a roadmap/tracker for enabling float8 training for MoEs with token-choice routing. It covers both the core requirements and ideas for additional performance optimizations.

Compute

Communication

I looked at traces and validated that the "all-to-all dispatch -> grouped GEMM -> all-to-all combine" stages are sequentially dependent, so in theory faster, lower-precision comms should improve performance. There is some overlap with the shared expert computation, but it is not 100% overlap, so there is room for optimization. This will be especially important if/when the all-to-all spans multiple nodes, where inter-node network bandwidth is lower than intra-node NVLink bandwidth.
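To make the dependency chain concrete, here is a minimal sketch of the hot path in plain PyTorch. It assumes a single expert-parallel process group, received tokens already permuted so each local expert's tokens are contiguous, and a per-expert loop standing in for the grouped GEMM; the split-size and token-count arguments are hypothetical router outputs, not an existing torchao API.

```python
import torch
import torch.distributed as dist

def moe_hot_path(
    x: torch.Tensor,                 # (num_local_tokens, dim), bf16
    experts: list[torch.nn.Module],  # local experts on this rank
    input_splits: list[int],         # tokens this rank sends to each rank
    output_splits: list[int],        # tokens this rank receives from each rank
    tokens_per_expert: list[int],    # received tokens routed to each local expert
    group: dist.ProcessGroup | None = None,
) -> torch.Tensor:
    # (1) a2a dispatch: expert compute cannot start until tokens arrive.
    recv = x.new_empty((sum(output_splits), x.shape[-1]))
    dist.all_to_all_single(recv, x, output_splits, input_splits, group=group)

    # (2) expert compute (per-expert loop standing in for the grouped GEMM):
    #     consumes the dispatched tokens, so it depends on (1).
    outs, start = [], 0
    for expert, n in zip(experts, tokens_per_expert):
        outs.append(expert(recv[start : start + n]))
        start += n
    y = torch.cat(outs)

    # (3) a2a combine: returns expert outputs to the owning ranks, so it
    #     depends on (2). Nothing inside this chain can overlap with itself;
    #     any overlap must come from independent work (e.g. the shared expert).
    combined = y.new_empty((sum(input_splits), y.shape[-1]))
    dist.all_to_all_single(combined, y, input_splits, output_splits, group=group)
    return combined
```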

This is also inspired by the DeepSeekV3 paper where, if I understand correctly, they do the a2a dispatch in fp8 but keep the a2a combine in bf16, as they found the combine was more sensitive to low precision during training.
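A hedged sketch of that asymmetry, assuming rowwise float8 scaling: quantize tokens and send the fp8 payload plus per-token scales through the dispatch all-to-all, while the combine stays in bf16. The helper names here are hypothetical, and the fp8 payload is viewed as uint8 for the collective since NCCL dtype support for float8 may vary.

```python
import torch
import torch.distributed as dist

F8 = torch.float8_e4m3fn
F8_MAX = torch.finfo(F8).max

def quantize_rowwise(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One scale per token (row): scale = rowwise amax / fp8 max.
    scale = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / F8_MAX
    q = (x.float() / scale).clamp(-F8_MAX, F8_MAX).to(F8)
    return q, scale

def fp8_dispatch(x, input_splits, output_splits, group=None):
    # Dispatch in fp8: half the bytes on the wire vs. bf16 tokens,
    # plus a small extra all-to-all for the rowwise scales.
    q, s = quantize_rowwise(x)
    q_recv = torch.empty((sum(output_splits), x.shape[-1]), dtype=torch.uint8, device=x.device)
    s_recv = torch.empty((sum(output_splits), 1), dtype=torch.float32, device=x.device)
    dist.all_to_all_single(q_recv, q.view(torch.uint8), output_splits, input_splits, group=group)
    dist.all_to_all_single(s_recv, s, output_splits, input_splits, group=group)
    # Dequantize here, or hand (fp8 data, scales) straight to an fp8 grouped GEMM.
    return q_recv.view(F8).to(torch.bfloat16) * s_recv

def bf16_combine(y, input_splits, output_splits, group=None):
    # Combine stays in bf16: per the DeepSeekV3 observation, this direction
    # is more sensitive to quantization error during training.
    out = y.new_empty((sum(input_splits), y.shape[-1]))
    dist.all_to_all_single(out, y, input_splits, output_splits, group=group)
    return out
```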

  • Add on-device all_to_all_v kernels compatible with:
    • mxfp8 (P0)
    • float8 rowwise (P1)
  • Token permutation kernel supports low-precision dtypes by permuting scales into the proper order for the permuted tokens (link); see the scale-permutation sketch after this list.
    • mxfp8 (P0)
    • float8 rowwise (P1)
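For the scale-permutation item above, a minimal sketch assuming rowwise scales (one scale per token; for mxfp8 the per-token scale dimension would instead hold a block of e8m0 scales): whatever permutation puts tokens into expert-contiguous order must also be applied to the scales so that scale i keeps describing token i.

```python
import torch

def expert_order(expert_ids: torch.Tensor) -> torch.Tensor:
    # Stable argsort groups tokens by destination expert while
    # preserving arrival order within each expert.
    return torch.argsort(expert_ids, stable=True)

def permute_tokens_and_scales(
    tokens_q: torch.Tensor,    # (T, D) low-precision token payload
    scales: torch.Tensor,      # (T, S) scale(s) per token
    expert_ids: torch.Tensor,  # (T,) expert id per token
):
    perm = expert_order(expert_ids)
    # Gather tokens and scales with the same indices so they stay paired.
    return tokens_q[perm], scales[perm], perm
```

The inverse permutation (e.g. `perm.argsort()`) restores the original token order before the combine.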

Torchao UX

Compile support

Distributed support
