Description
🚀 Motivation
Currently, we do not have a consistent plan for "light" custom ops (subclasses of `CustomOp` with both torch-native and GPU implementations). As we work on improved performance on NVIDIA Blackwell and AMD, we should be more intentional with the `CompilationConfig` defaults that control custom op dispatching. This is a parent issue that tracks smaller PRs addressing `CustomOp`s.
In vLLM, there are two kinds of custom kernels/ops:
- "heavy" ops like GEMMs, MoE, and attention, which will mostly use tuned custom kernels for maximum performance.
- "light" ops like
RMSNorm
,SiluAndMul
, andRoPE
, which have both torch-native and custom GPU implementations.
This issue only refers to "light" ops, which are (or should be) all subclasses of `CustomOp`.
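For concreteness, here is a minimal sketch of that pattern (not the actual `CustomOp` API: the `forward_native`/`forward_cuda` names mirror the real class, but registration, platform dispatch, and the kernel call are placeholders):

```python
import torch
from torch import nn

# Hypothetical stand-in for the CustomOp pattern: each "light" op carries a
# torch-native reference (forward_native) and a GPU path (forward_cuda) that
# would call a hand-written kernel in the real repo.
class SiluAndMulSketch(nn.Module):
    def __init__(self, use_custom_kernel: bool = False):
        super().__init__()
        self.use_custom_kernel = use_custom_kernel

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch path: traceable by torch.compile and fusable by Inductor.
        d = x.shape[-1] // 2
        return torch.nn.functional.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Placeholder for a call into a hand-written CUDA/HIP kernel.
        return self.forward_native(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_custom_kernel and x.is_cuda:
            return self.forward_cuda(x)
        return self.forward_native(x)

layer = SiluAndMulSketch()
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```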
When we enabled `torch.compile` by default in V1, the plan was to reduce our reliance on custom kernels to cut maintenance costs and code complexity, even at a minor performance cost. Recent versions of `torch` actually produce Triton kernels that are faster than our custom op implementations anyway.
However, given startup-time concerns (#19824), we also want good performance with Inductor disabled (more discussion on startup times to come in a follow-up issue). Additionally, custom ops have been reported to outperform torch.compile-generated Triton kernels on AMD.
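For reference, toggling individual light ops is already exposed through `CompilationConfig.custom_ops` (the `"+op"`/`"-op"` syntax below reflects current behavior as I understand it; the model name is a placeholder, and the exact plumbing may differ by version):

```python
from vllm import LLM

# Hedged example: force the hand-written RMSNorm kernel on and the fused
# silu-and-mul kernel off via CompilationConfig.custom_ops; unlisted ops keep
# the global default.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    compilation_config={"custom_ops": ["+rms_norm", "-silu_and_mul"]},
)
```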
❗ Issues
This is a list of current issues with custom ops. The following section tracks proposed and WIP solutions. Larger line items might have their own issue or get one in the future.
- [Perf] FP8 quantization is not fused with rms_norm/silu_mul, because FP8 quant doesn't have a torch-native implementation (see the sketch after this list).
- [Perf] AMD uses custom ops, but fusion for them is hardcoded off.
- [Perf] We don't have good visibility into performance differences between GPU and torch native implementations of custom ops across different models and hardware platforms.
- [Perf][Code Quality] Fused and unfused custom ops currently reimplement the same code, and only some are vectorized.
- [Testing] Custom op tests either don't exist or reimplement their own naive torch reference implementations instead of reusing the ops' torch-native code.
- [Compilation] Current custom passes rely on pattern-matching custom ops, which involves auto-functionalization ([WIP][RFC]: Use auto-functionalization V2 in PyTorch 2.7+ #14703) and requires the custom ops to be enabled. If custom ops are slower than the native implementations, that means e.g. attention+quant fusion has to take a hit on the other quant ops just to fuse the o-proj one. I'll create an RFC for this.
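On the FP8 quant point above, a torch-native per-tensor quant could look roughly like the sketch below (the function name and scale handling are illustrative, not the existing kernel's signature); having such an implementation is what would let the compiler keep the quant in the traced graph and fuse it with the producing op:

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max

def fp8_quant_native(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Scale, clamp to the representable range, then narrow to FP8; written
    # entirely with torch ops so the quant stays visible to the compiler
    # instead of hiding behind an opaque custom kernel.
    x_scaled = x.float() / scale
    return x_scaled.clamp(min=-FP8_MAX, max=FP8_MAX).to(FP8_DTYPE)

x = torch.randn(4, 16)
scale = x.abs().amax() / FP8_MAX
print(fp8_quant_native(x, scale).dtype)  # torch.float8_e4m3fn
```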
✅ Solutions
Once we have benchmarking numbers, we can set sensible defaults. Improving startup time will likely result in more explicit "profiles" with better config defaults.
Detailed solution tracking:
- 🚧 WIP @ProExpertProg: integrating FP8 `CustomOp`s from [Perf] Replace per-tensor/token FP8 quant CUDA kernels with torch.compile #18965 into the forward pass.
- 🚧 WIP [Feature] Support sequence parallelism for static fp8 quantization #19181 will remove the hardcoded `enable_fusion=False`.
- 🚧 WIP @gshtras and @SageMoore are collecting benchmarking numbers for custom op performance. More benchmarking can be done as needed; automating some of it would be great as well.
- 🕐 TODO @yewentao256 is going to work on consolidating, vectorizing and cleaning up CUDA/HIP implementations of custom ops.
- ❗ TODO add testing utilities for custom ops (see the sketch after this list).
- ❗ TODO @ProExpertProg will write RFC for pattern matching.
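For the testing-utilities item, a possible shape for such a helper (assuming the `forward_native`/`forward_cuda` convention; the function name is made up):

```python
import torch

def assert_custom_op_matches_native(op, *args, rtol: float = 1e-3, atol: float = 1e-3):
    # Compare the GPU path against the op's own torch-native path instead of
    # re-implementing a reference in every test.
    expected = op.forward_native(*args)
    actual = op.forward_cuda(*args)
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
```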
➕ Other potential improvements
- Simplify custom op enablement
- Improve custom op documentation
- Custom op microbenchmarks (see the sketch below)
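A rough idea of what such a microbenchmark could look like (assumes a CUDA/HIP device; names are placeholders, and a real harness would sweep ops, shapes, dtypes, and platforms):

```python
import torch

def bench_ms(fn, x, iters: int = 100, warmup: int = 10) -> float:
    # CUDA events time the GPU work between the two markers, excluding
    # host-side Python overhead before the first launch.
    for _ in range(warmup):
        fn(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

# Example usage (names are placeholders):
#   op = SiluAndMulSketch().cuda()
#   x = torch.randn(8192, 11008, device="cuda", dtype=torch.bfloat16)
#   print(bench_ms(op.forward_native, x), bench_ms(op.forward_cuda, x))
```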