Description
Creating this issue as a roadmap/tracker for enabling float8 training for MoEs with token-choice routing. Both core requirements and ideas for additional performance optimizations are included.
Compute
- fp8 rowwise
  - Add torch._scaled_grouped_mm kernel in core
  - Add differentiable scaled grouped mm with dynamic float8 rowwise quant in torchao (see the sketch after this list)
  - Add custom kernels in torchao for performing per-group scaling on device, to avoid a host-device sync
  - Fuse padding of group sizes up to the nearest multiple of 16 into the dynamic quant kernel
- mxfp8
  - mxfp8 scaled grouped gemm: Add MXFP8 Support to scaled_grouped_gemm (pytorch#153502)
  - torchao differentiable _scaled_grouped_mm support for the mxfp8 recipe, with dynamic quant before the grouped GEMMs
  - Triton kernels to perform per-group scaling without a d2h sync
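To make the fp8 rowwise item concrete, here is a minimal sketch of the forward path of a differentiable scaled grouped mm with dynamic rowwise quantization. The exact torch._scaled_grouped_mm signature and its scale/layout requirements are assumptions here, and quant_fp8_rowwise / fp8_rowwise_grouped_mm are illustrative names rather than torchao APIs; the backward would run two more grouped GEMMs (for the token and weight grads) inside an autograd.Function.

```python
import torch

FP8 = torch.float8_e4m3fn

def quant_fp8_rowwise(x: torch.Tensor):
    """Dynamically quantize x to float8 with one scale per row (last dim)."""
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12).float()
    scale = torch.finfo(FP8).max / amax
    x_fp8 = (x.float() * scale).clamp(torch.finfo(FP8).min, torch.finfo(FP8).max).to(FP8)
    return x_fp8, scale.reciprocal()  # reciprocal scale is applied to the GEMM output

def fp8_rowwise_grouped_mm(tokens, expert_weights, offs):
    # tokens: (total_tokens, K), grouped by expert along dim 0 with cumulative
    # group sizes in `offs`; expert_weights: (num_experts, N, K).
    a_fp8, a_scale = quant_fp8_rowwise(tokens)
    b_fp8, b_scale = quant_fp8_rowwise(expert_weights)
    return torch._scaled_grouped_mm(
        a_fp8,
        b_fp8.transpose(-2, -1),   # assumed (num_experts, K, N) layout for mat2
        a_scale.squeeze(-1),       # one scale per token row
        b_scale.squeeze(-1),       # one scale per expert output row
        offs=offs,
        out_dtype=torch.bfloat16,
    )
```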
Communication
I looked at traces and validated that "all to all dispatch -> grouped gemm -> all to all combine" is a sequentially dependent chain, so in theory faster/lower-precision comms should improve performance. There is some overlap with the shared expert computation, but it is not 100% overlap, so there is room for optimization. This will be especially important if/when the all to all spans multiple nodes, where inter-node network bandwidth is lower than intra-node NVLink bandwidth.
This is also inspired by the DeepSeekV3 paper where, if I understand correctly, they do the a2a dispatch in fp8 but keep the a2a combine in bf16, as they found the combine to be more sensitive to low precision during training. A rough sketch of the fp8 dispatch idea follows the list below.
- Add on-device all_to_all_v kernels compatible with:
  - mxfp8 (P0)
  - float8 rowwise (P1)
- Token permutation kernel supports low precision dtypes by permuting scales to be in proper order for the permuted tokens (link)
  - mxfp8 (P0)
  - float8 rowwise (P1)
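For illustration, a rough sketch of fp8 a2a dispatch using the existing torch.distributed.all_to_all_single collective (not the proposed on-device all_to_all_v kernel). quant_fp8_rowwise is the illustrative helper from the Compute sketch; the fp8 payload is viewed as uint8 for the collective, the rowwise scales travel in a second a2a, and the combine path stays in bf16.

```python
import torch
import torch.distributed as dist

def fp8_a2a_dispatch(tokens, input_splits, output_splits, group=None):
    # Quantize locally routed tokens before dispatch so the payload is ~2x
    # smaller than bf16; scales are sent alongside in a second collective.
    tokens_fp8, scales = quant_fp8_rowwise(tokens)   # (T, K) fp8, (T, 1) fp32

    recv_payload = torch.empty(
        (sum(output_splits), tokens.shape[-1]), dtype=torch.uint8, device=tokens.device
    )
    recv_scales = torch.empty(
        (sum(output_splits), 1), dtype=scales.dtype, device=scales.device
    )

    # View fp8 as uint8 for the collective; an on-device all_to_all_v kernel
    # would avoid this workaround and could fuse/overlap the two collectives.
    dist.all_to_all_single(recv_payload, tokens_fp8.view(torch.uint8).contiguous(),
                           output_splits, input_splits, group=group)
    dist.all_to_all_single(recv_scales, scales.contiguous(),
                           output_splits, input_splits, group=group)

    # Received tokens stay in fp8 + scales and feed the fp8 grouped GEMM;
    # the combine a2a on the return path stays in bf16 (DeepSeekV3-style).
    return recv_payload.view(torch.float8_e4m3fn), recv_scales
```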
Torchao UX
- Add tensor subclass (ScaledGroupedMMTensor) with an op override for torch.ops.aten._grouped_mm that runs the differentiable scaled grouped mm (sketched below)
- Add a one-line model conversion API that recursively swaps the nn.Parameter data tensors of the expert weights with ScaledGroupedMMTensor
- Support configurable recipes (fp8 rowwise, mxfp8)
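A highly simplified sketch of this UX. It assumes a __torch_function__-based override (the real ScaledGroupedMMTensor may use __torch_dispatch__ instead), and the class, conversion helper, and filter names are illustrative rather than torchao's actual API; fp8_rowwise_grouped_mm is the forward sketch from the Compute section.

```python
import torch
import torch.nn as nn

class ScaledGroupedMMTensorSketch(torch.Tensor):
    """Illustrative subclass: reroutes grouped mms on expert weights to fp8."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch._grouped_mm:
            tokens, weights, *_ = args
            # Route to the differentiable scaled grouped mm (earlier sketch).
            return fp8_rowwise_grouped_mm(tokens, weights, kwargs.get("offs"))
        return super().__torch_function__(func, types, args, kwargs)

def convert_moe_to_float8_training(model: nn.Module, filter_fn=lambda name: "experts" in name):
    """One-line conversion: swap expert weight data tensors with the subclass."""
    for mod_name, module in model.named_modules():
        if not filter_fn(mod_name):
            continue
        for p_name, param in list(module.named_parameters(recurse=False)):
            swapped = nn.Parameter(
                param.data.as_subclass(ScaledGroupedMMTensorSketch),
                requires_grad=param.requires_grad,
            )
            setattr(module, p_name, swapped)
    return model
```

Usage would then be a single call before training, e.g. `model = convert_moe_to_float8_training(model)`.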
Compile support
- Compile support for torch._grouped_mm
- Differentiable _scaled_grouped_mm compiles with fullgraph=True (see the sketch below)
- E2E compilation of each TransformerBlock in torchtitan after MoE conversion via the tensor subclass approach
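A quick check for the second item, assuming the fp8_rowwise_grouped_mm sketch from the Compute section; fullgraph=True asserts there are no graph breaks in the quant + grouped GEMM path. Shapes and the offs layout are made up for illustration.

```python
import torch

compiled = torch.compile(fp8_rowwise_grouped_mm, fullgraph=True)

tokens = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)        # routed tokens
weights = torch.randn(8, 1024, 512, device="cuda", dtype=torch.bfloat16)   # 8 experts
offs = torch.arange(32, 257, 32, device="cuda", dtype=torch.int32)         # cumulative group sizes

out = compiled(tokens, weights, offs)   # (256, 1024), bf16
```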
Distributed support
- Composability with FSDP2 (will likely need something like this for the new tensor subclass; see the sketch after this list)
  - mxfp8 (P0)
  - float8 rowwise (P1) ([float8 moe training] FSDP support #2413)
- Composability with TP
  - mxfp8 (P0)
  - float8 rowwise (P1) ([float8 moe training] Add TP support #2425)
- Composability with FSDP + TP 2D parallel
  - mxfp8 (P0)
  - float8 rowwise (P1) ([float8 moe training] Add TP support #2425)
- Composability with tp2ep
  - mxfp8 (P0)
  - float8 rowwise (P1)
- Composability with dp2ep
  - mxfp8 (P0)
  - float8 rowwise (P1)
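For the FSDP2 item, a minimal sketch of the intended composition, assuming the fully_shard FSDP2 entry point and the conversion helper sketched in the Torchao UX section; the model/layer attribute names are placeholders for the torchtitan wiring.

```python
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point (assumed import path)

def apply_float8_moe_fsdp2(model):
    # Convert expert weights first, then shard, so the new tensor subclass
    # participates in FSDP2 sharding/all-gather like a regular parameter.
    model = convert_moe_to_float8_training(model)   # UX sketch above
    for block in model.layers:                      # per-TransformerBlock sharding
        fully_shard(block)
    fully_shard(model)                              # root wrap
    return model
```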