Proposal to improve performance
This issue tracks work on unwrapping custom ops and enabling additional fusions inside the model forward pass.
Motivation
While profiling Deepseek, GPT-OSS, and LLaMa (in progress), we noticed a lot of overhead from elementwise & unfused ops. In most cases, those ops could be fused or eliminated by `torch.compile`, but they happen inside a custom op, hidden from `torch.compile`. There are also some fusion opportunities across current custom-op boundaries, and some under-utilized fusion opportunities completely unrelated to custom ops.
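As a minimal illustration of the problem (toy op names for illustration only, not vLLM code): the same elementwise math is opaque to `torch.compile` when wrapped in a custom op, but fusible by Inductor when written as plain torch ops.

```python
import torch

# Hypothetical toy op (not a vLLM op): an elementwise epilogue hidden
# behind a custom-op boundary is opaque to torch.compile.
@torch.library.custom_op("demo::scale_residual", mutates_args=())
def scale_residual(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    return x * 0.5 + residual

@scale_residual.register_fake
def _(x, residual):
    return torch.empty_like(x)

def hidden(x, residual):
    # torch.compile sees only an opaque call here, so the mul/add inside
    # the custom op cannot be fused with the surrounding relu.
    return torch.relu(scale_residual(x, residual))

def unwrapped(x, residual):
    # Same math written as plain torch ops: Inductor can fuse the mul, add,
    # and relu into a single Triton kernel.
    return torch.relu(x * 0.5 + residual)
```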
Example trace (Deepseek):
 
How to read this trace:
- `nvjet_...` are unquantized `torch.mm` calls
- `native::elementwise` are elementwise ops not visible to `torch.compile`
- `triton_red_fused_...` are (fused) elementwise ops visible to `torch.compile` and lowered into fused Triton kernels
- all other kernels are custom kernels (attn, GEMM, MoE, allreduce, quant, etc.)
The fusion opportunities are below:
 
Action items
- RMSNorm + quant (fusions 1 & 2):
  - Unwrap `apply_w8a8_block_fp8_linear`
  - Use `QuantFP8` for block quantization and rely on Inductor fusion (see the plain-torch sketch after this list)
  - (bonus) Add support for group quantization to the fused `rms_norm_quant` kernel (TODO: good first issue)
 
- Allreduce + RMSNorm + quant (fusion 1):
  - Unwrap `apply_w8a8_block_fp8_linear`
  - Try to use the existing FlashInfer kernel (needs group quant support)
 
- RoPE + quant (+ cache) (fusion 3):
  - Produce a single Triton kernel for RoPE:
    - Should we write a Helion kernel instead of relying on Inductor heuristics?
  - Add input quant de-fusion: pull quant out of attention using a pass
  - Refactor the attention layer to pull the cache out of the custom op
 
- attn point ops (fusions 3 & 4):
  - Remove unnecessary copies from Cutlass MLA
  - Add `torch.compile` inside the attention op (likely needs to handle the prefill-decode split)
  - Unwrap from the custom op after the MLA layer refactor
    - Could we extract the prefill-decode split from the op?
      - What happens to piecewise cudagraphs?
 
- attn + quant (fusion 4):
  - Kernel support for group output quant for flash attn (MLA prefill, also MQA)
  - Kernel support for group output quant for v-proj bmm (MLA decode)
 
- GEMM + quant (fusion 6):
  - Unwrap `apply_w8a8_block_fp8_linear`
  - Unwrap `fused_moe` to get quant out
  - New fused output quant GEMM kernel (unquantized)
 
- point ops + quant (fusions 7 & 8):
  - Unwrap `apply_w8a8_block_fp8_linear`
  - Unwrap `fused_moe`, or use `torch.compile` inside the custom op
 
- router/normalize/finalize + other (fusions 8 & 9):
  - Need to understand if any of these are fusable
  - Custom kernels to fuse
 
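To make the "use `QuantFP8` for block quantization and rely on Inductor fusion" item above concrete, here is a rough plain-torch sketch of RMSNorm followed by per-token-group FP8 quantization. The function name, shapes, and group size are assumptions for illustration, not vLLM's actual `QuantFP8` API; the point is that, written as plain ops outside a custom op, Inductor can fuse the normalization and quantization into a single kernel.

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def rmsnorm_group_quant(x: torch.Tensor, weight: torch.Tensor,
                        group_size: int = 128, eps: float = 1e-6):
    # RMSNorm in fp32 for numerical stability.
    xf = x.float()
    normed = xf * torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = normed * weight.float()

    # Per-token-group quantization: one scale per `group_size` contiguous
    # elements of the hidden dimension (group_size=128 is an assumption).
    groups = normed.unflatten(-1, (-1, group_size))
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    quant = (groups / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return quant.flatten(-2), scale.squeeze(-1)
```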
The best course of action is to start with fusions that are easier to enable and proceed down the list:
- Remove unnecessary copies from Cutlass MLA ([Performance] Remove redundant clone() calls in cutlass_mla #24891 and [Performance] Remove input pads in cutlass_mla and optimize v_proj output handling #25184)
- Unwrap `apply_w8a8_block_fp8_linear` (use a similar abstraction to `Fp8LinearOp` for dispatching): [Perf] Fix and reapply move apply w8a8 block fp8 linear to class #25696
- Fuse padding onto GEMM by making the GEMM out-of-place
- Use QuantFP8 for group quant ([FP8] Extend per-token-group quantization support to QuantFP8 #24342)
- FlashInfer all_reduce+rms_norm+group_quant kernel for Blackwell
- RoPE+quant: pull quant out of attention with pass
- Refactor attn layer for MLA and MHA and extract ops from the `unified_attention` custom op: [Refactor]: Make an common MLAAttention Layer and custom OP #24620
- Fuse RoPE, quant, cache: [Performance]: ROPE + KV-Cache-Write + pre-attn prepare-ops fusion #24678
- Unwrap fused_moe
Later, larger, custom kernel undertakings:
- Kernel support for group output quant for flash attn (could use the existing FlashInfer kernel instead of FlashAttention for MLA prefill; still needs dynamic group quant)
- Fuse custom MoE kernels (routing/normalize/finalize)
- New fused output quant BMM kernel (unquantized)
- New fused output quant GEMM kernel (unquantized)
If we're not satisfied with the performance of Inductor-generated Triton, we could investigate custom Helion kernels. They are almost as easy to write as eager torch, are always fused, and can automatically be fused with any other torch ops before and after them. Two good examples would be RoPE and fp8 group quant.
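For instance, here is a hedged sketch of GPT-NeoX-style RoPE in plain torch (shapes and naming are assumptions, not vLLM's `rotary_embedding` custom op): this is the kind of eager-style code that Inductor, or a Helion kernel, could turn into a single fused kernel together with the neighboring quant and cache write.

```python
import torch

def rope_neox(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # GPT-NeoX-style rotary embedding in plain torch.
    # Assumed shapes (illustrative): q is [..., head_dim],
    # cos/sin are broadcastable to [..., head_dim // 2].
    half = q.shape[-1] // 2
    q1, q2 = q[..., :half], q[..., half:]
    cos = torch.cat((cos, cos), dim=-1)
    sin = torch.cat((sin, sin), dim=-1)
    rotated = torch.cat((-q2, q1), dim=-1)
    return q * cos + rotated * sin
```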