
[torch.compile][Performance]: Unwrap custom ops and improve fusion (Inductor and custom) #24629

@ProExpertProg


Proposal to improve performance

This issue tracks work on unwrapping custom ops and enabling additional fusions inside the model forward pass.

Motivation

While profiling Deepseek, GPT-OSS, and LLaMa (in progress), we noticed a lot of overhead from elementwise and unfused ops. In most cases, those ops could be fused or eliminated by torch.compile, but they happen inside a custom op and are therefore hidden from torch.compile. Some fusion opportunities span a current custom-op boundary, and some are under-utilized fusion opportunities entirely unrelated to custom ops.
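
To illustrate the core issue, here is a minimal sketch (not vLLM code; the `demo::rms_norm` op and function names are made up for illustration): the same RMSNorm sequence is written once inside a registered custom op and once in plain torch ops. The custom-op version shows up as a single opaque node to torch.compile, so neighboring elementwise ops cannot fuse into it; the plain version is visible to Inductor and fuses with its neighbors.

```python
import torch

@torch.library.custom_op("demo::rms_norm", mutates_args=())
def rms_norm_custom(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Opaque to the compiler: everything inside runs as one eager-mode call.
    var = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(var + eps) * weight

@rms_norm_custom.register_fake
def _(x, weight, eps):
    # Fake impl so torch.compile can trace the surrounding graph.
    return torch.empty_like(x)

def rms_norm_plain(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Visible to the compiler: Inductor can fuse this with neighboring elementwise ops.
    var = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(var + eps) * weight

@torch.compile
def fwd_opaque(x, w):
    # The trailing scale-and-shift cannot be fused into the opaque custom op.
    return rms_norm_custom(x, w, 1e-6) * 2.0 + 1.0

@torch.compile
def fwd_fused(x, w):
    # The norm and the trailing elementwise work lower into one fused Triton kernel.
    return rms_norm_plain(x, w, 1e-6) * 2.0 + 1.0
```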

Example trace (Deepseek):

[Image: Deepseek profile trace]

How to read this trace:

  • nvjet_... are unquantized torch.mm calls
  • native::elementwise are elementwise ops not visible to torch.compile
  • triton_red_fused_... are (fused) elementwise ops visible to torch.compile and lowered into fused Triton kernels
  • all other kernels are custom kernels (attn, GEMM, MoE, allreduce, quant, etc.)

The fusion opportunities are below:

[Image: fusion opportunities annotated on the trace]

Action items

  • RMSNorm + quant (fusions 1 & 2):
    1. Unwrap apply_w8a8_block_fp8_linear
    2. Use QuantFP8 for block quantization and rely on Inductor fusion (see the group-quant sketch after this list)
    3. (bonus) Add support for group quantization to the fused rms_norm_quant kernel (TODO: good first issue)
  • Allreduce + RMSNorm + quant (fusion 1):
    1. Unwrap apply_w8a8_block_fp8_linear
    2. Try to use existing flashinfer kernel (needs group quant support)
  • RoPE + quant (+ cache) (fusion 3):
    1. Produce a single Triton kernel for RoPE
    • should we write a Helion kernel instead of relying on Inductor heuristics?
    2. Add input-quant de-fusion: pull quant out of attention with a fusion pass
    3. Refactor the attention layer to pull the KV-cache write out of the custom op
  • attn point ops (fusions 3 & 4):
    1. Remove unnecessary copies from Cutlass MLA
    2. Add torch.compile inside attention op (likely needs to handle prefill-decode split)
    3. Unwrap from the custom op after the MLA layer refactor
    4. Could the prefill-decode split be extracted from the op?
    • what happens to piecewise CUDA graphs?
  • attn + quant (fusion 4):
    1. Kernel support for group output quant for flash attn (MLA prefill, also MQA)
    2. Kernel support for group output quant for v-proj bmm (MLA decode)
  • GEMM + quant (fusion 6):
    1. Unwrap apply_w8a8_block_fp8_linear
    2. Unwrap fused_moe to get quant out
    3. New fused output quant GEMM kernel (unquantized)
  • point ops + quant (fusions 7 & 8):
    1. Unwrap apply_w8a8_block_fp8_linear
    2. Unwrap fused_moe or torch.compile inside custom op
  • router/normalize/finalize + other (fusions 8 & 9):
    1. Determine whether any of these are fusable
    2. If so, write custom kernels to fuse them
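
As referenced in the RMSNorm + quant item above, here is an illustrative sketch (not the actual QuantFP8 implementation; the function names and `group_size` default are assumptions) of per-token-group FP8 quantization written in plain torch ops. Expressed this way inside the compiled region, the quant is visible to Inductor and can fuse with the preceding RMSNorm instead of launching a separate elementwise kernel.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    var = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight

def group_quant_fp8(x: torch.Tensor, group_size: int = 128):
    # Reshape the hidden dim into groups and compute one scale per group.
    *lead, hidden = x.shape
    xg = x.reshape(*lead, hidden // group_size, group_size).float()
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / FP8_MAX
    q = (xg / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.reshape(*lead, hidden), scale.squeeze(-1)

@torch.compile
def norm_then_quant(x, w):
    # Both the norm and the quant are visible to Inductor, so they can be
    # fused into a single Triton kernel rather than two elementwise launches.
    return group_quant_fp8(rms_norm(x, w))
```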

The best course of action is to start with fusions that are easier to enable and proceed down the list:

  1. Remove unnecessary copies from Cutlass MLA ([Performance] Remove redundant clone() calls in cutlass_mla #24891 and [Performance] Remove input pads in cutlass_mla and optimize v_proj output handling #25184)
  2. Unwrap apply_w8a8_block_fp8_linear (use similar abstraction to Fp8LinearOp for dispatching): [Perf] Fix and reapply move apply w8a8 block fp8 linear to class #25696
  3. Fuse padding onto GEMM by making the GEMM out-of-place
  4. Use QuantFP8 for group quant ([FP8] Extend per-token-group quantization support to QuantFP8 #24342)
  5. FlashInfer all_reduce+rms_norm+group_quant kernel for Blackwell
  6. RoPE+quant: pull quant out of attention with a fusion pass (a plain-torch RoPE sketch follows this list)
  7. Refactor attn layer for MLA and MHA and extract ops from unified_attention custom op: [Refactor]: Make an common MLAAttention Layer and custom OP #24620
  8. Fuse RoPE, quant, cache: [Performance]: ROPE + KV-Cache-Write + pre-attn prepare-ops fusion #24678
  9. Unwrap fused_moe
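
For item 6, here is a sketch of RoPE written as plain torch ops, using the standard rotate-half formulation (this is not vLLM's rotary_embedding kernel, and the tensor layout and pre-expanded cos/sin tables are assumptions). Once expressed this way inside the compiled region, Inductor can lower it to a single Triton kernel and potentially fuse a following input quant onto it.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

@torch.compile
def apply_rope(q: torch.Tensor, k: torch.Tensor,
               cos: torch.Tensor, sin: torch.Tensor):
    # Assumed shapes: q, k are [num_tokens, num_heads, head_dim];
    # cos, sin are [num_tokens, 1, head_dim], already expanded to full head_dim.
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot
```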

Later, larger, custom kernel undertakings:

  1. Kernel support for group output quant for flash attn (could use the existing FlashInfer kernel instead of FlashAttention for MLA prefill, but it still needs dynamic group quant)
  2. Fuse custom MoE kernels (routing/normalize/finalize)
  3. New fused output quant BMM kernel (unquantized)
  4. New fused output quant GEMM kernel (unquantized)

If we're not satisfied with the performance of Inductor-generated Triton kernels, we could investigate custom Helion kernels. They are almost as easy to write as eager PyTorch, are always fused, and can automatically be fused with any other torch ops before and after them. Two good candidates would be RoPE and FP8 group quant.
