## Description
Creating this issue as a roadmap/tracker for enabling float8 training for MoEs with token-choice routing. It covers both the core requirements and ideas for additional performance optimizations.
## Compute
- Add `torch._scaled_grouped_mm` kernel in core (already done by Natalia)
- Add differentiable scaled grouped mm with dynamic float8 rowwise quant in torchao (done in Initial prototype of differentiable _scaled_grouped_mm function #1969)
- Add custom kernels in torchao for performing per-group scaling on device, to avoid host-device sync (done in [scaled grouped mm] add triton kernels for float8 rowwise quantization with per-group/jagged scales #2064 and [scaled grouped mm] integrate triton kernels into differentiable scaled grouped mm #2077)
- Fuse padding of group sizes up to the nearest multiple of 16 into the dynamic quant kernel - this will improve perf and usability, since callers can pass in raw token groups without doing the padding logic themselves, which will make migration easier (a sketch of the padding a caller currently has to do follows this list).
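For context, here is a minimal sketch of the host-side per-group padding a caller currently has to do before the grouped GEMM. The function name and the offsets format (cumulative end offset per expert group) are assumptions for illustration, not the torchao API:

```python
import torch

def pad_token_groups_to_multiple_of_16(tokens: torch.Tensor, group_offsets: torch.Tensor):
    # tokens: (total_tokens, dim), already permuted into expert order.
    # group_offsets: cumulative end offset of each expert's token group.
    # Pads each group with zero rows so its size is a multiple of 16. Note that
    # .tolist() forces a host-device sync, which is part of what fusing this
    # logic into the dynamic quant kernel would avoid.
    padded_groups, padded_sizes = [], []
    start = 0
    for end in group_offsets.tolist():
        group = tokens[start:end]
        pad_rows = (-group.shape[0]) % 16
        if pad_rows > 0:
            group = torch.cat([group, group.new_zeros(pad_rows, group.shape[1])])
        padded_groups.append(group)
        padded_sizes.append(group.shape[0])
        start = end
    new_offsets = torch.tensor(padded_sizes, device=tokens.device).cumsum(0)
    return torch.cat(padded_groups), new_offsets
```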
## Communication
I looked at traces and validated that "all to all dispatch -> grouped gemm -> all to all combine" are sequentially dependent, so in theory faster/lower-precision comms should improve performance. There is some overlap with the shared expert computation, but the overlap is not 100%, so there is room for optimization. This will be especially important if/when the all-to-all spans multiple nodes, where inter-node network bandwidth is lower than intra-node NVLink bandwidth.
This is also inspired by the DeepSeekV3 paper where, if I understand correctly, they do the a2a dispatch in fp8 but keep the a2a combine in bf16, as they found the combine step was more sensitive to low precision during training. A rough sketch of an fp8 dispatch at the collective level is included after the list below.
- Add on-device `all_to_all_v` kernels compatible with:
  - float8 tensors with tensorwise scales (easiest).
  - float8 tensors with rowwise scales (harder).
- When permuting token groups to be in the same order as experts (prior to the scaled grouped mm), reorder scales accordingly.
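As a strawman, this is roughly what a low-precision dispatch could look like at the PyTorch collective level (before any custom on-device kernels): exchange the float8 payload and its rowwise scales with two `all_to_all_single` calls. Viewing the float8 tensor as uint8 for the collective, and the function/argument names, are assumptions for illustration only:

```python
import torch
import torch.distributed as dist

def fp8_a2a_dispatch(tokens_fp8: torch.Tensor, scales: torch.Tensor,
                     input_splits: list[int], output_splits: list[int], group=None):
    # tokens_fp8: (num_local_tokens, dim) in torch.float8_e4m3fn
    # scales:     (num_local_tokens, 1) rowwise dequant scales in fp32
    # Payload is viewed as uint8 in case the backend does not accept float8 dtypes.
    out_tokens = tokens_fp8.new_empty(sum(output_splits), tokens_fp8.shape[1])
    dist.all_to_all_single(out_tokens.view(torch.uint8), tokens_fp8.view(torch.uint8),
                           output_splits, input_splits, group=group)
    # Rowwise scales must travel with (and stay aligned to) their token rows.
    out_scales = scales.new_empty(sum(output_splits), scales.shape[1])
    dist.all_to_all_single(out_scales, scales, output_splits, input_splits, group=group)
    return out_tokens, out_scales
```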
## Torchao UX
- `JaggedFloat8Tensor` (name TBD) with an op override for `torch.aten._grouped_mm` => runs the differentiable scaled grouped mm (a rough sketch follows this list).
- One-line conversion API, either integrated into `convert_to_float8_training` or a standalone one (TBD). Swaps the `nn.Parameter` data tensors of the expert weights with `JaggedFloat8Tensor`s.
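A rough sketch of the subclass + conversion UX, assuming a wrapper-subclass design similar to torchao's existing float8 tensors. `_scaled_grouped_mm` below is a stub standing in for the differentiable scaled grouped mm from #1969, the op name follows the issue text, and `convert_moe_to_jagged_float8` is a hypothetical name for the conversion API:

```python
import torch
from torch.utils._pytree import tree_map

def _scaled_grouped_mm(*args, **kwargs):
    raise NotImplementedError("placeholder for torchao's differentiable scaled grouped mm")

class JaggedFloat8Tensor(torch.Tensor):
    # Wraps the high-precision expert weight; grouped GEMMs get rerouted to the
    # differentiable scaled grouped mm, everything else falls back to the hp tensor.

    @staticmethod
    def __new__(cls, hp_weight: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, hp_weight.shape, dtype=hp_weight.dtype,
            device=hp_weight.device, requires_grad=hp_weight.requires_grad)

    def __init__(self, hp_weight: torch.Tensor):
        self._hp_weight = hp_weight

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda t: t._hp_weight if isinstance(t, cls) else t
        if func is torch.ops.aten._grouped_mm.default:
            # Route grouped GEMMs through the differentiable scaled grouped mm.
            return _scaled_grouped_mm(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

def convert_moe_to_jagged_float8(model: torch.nn.Module,
                                 filter_fn=lambda fqn: "experts" in fqn):
    # Hypothetical one-line conversion: swap matching expert weight params in place.
    for module_fqn, module in model.named_modules():
        for name, param in list(module.named_parameters(recurse=False)):
            if filter_fn(f"{module_fqn}.{name}"):
                setattr(module, name, torch.nn.Parameter(
                    JaggedFloat8Tensor(param.data), requires_grad=param.requires_grad))
    return model
```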
## Compile support
- Differentiable `_scaled_grouped_mm` can compile with `fullgraph=True`
- E2E compilation of each TransformerBlock in torchtitan after MoE conversion via the tensor subclass approach (sketch below)
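A sketch of the intended compile flow, mirroring how torchtitan compiles each TransformerBlock individually; the `layers` container name is an assumption here:

```python
import torch
import torch.nn as nn

def compile_transformer_blocks(model: nn.Module) -> None:
    # Compile each TransformerBlock individually (torchtitan-style). With the MoE
    # layers converted to the tensor subclass, each block should trace without
    # graph breaks, i.e. fullgraph=True should hold.
    for name, block in model.layers.named_children():
        model.layers.register_module(name, torch.compile(block, fullgraph=True))
```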
## Distributed support
- Composability with FSDP2 (will likely need something like this for the new tensor subclass; a rough sketch follows this list)
- Composability with TP (will likely need something like these sharding primitives for the new tensor subclass)
- Composability with tp2ep
- Composability with dp2ep
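A rough sketch of the FSDP2 composability goal, assuming torchtitan-style per-block sharding of the converted model; the `fully_shard` import path varies across PyTorch versions and the `layers` attribute is assumed:

```python
import torch
from torch.distributed.fsdp import fully_shard  # FSDP2; path may differ by PyTorch version

def apply_fsdp2_after_conversion(model: torch.nn.Module) -> torch.nn.Module:
    # Expert weights are already swapped for the new subclass; FSDP2 then needs the
    # subclass to implement the all-gather extension hooks (as float8 tensors do
    # today) so sharding/unsharding round-trips correctly.
    for block in model.layers.children():
        fully_shard(block)
    fully_shard(model)
    return model
```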