[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer #20059
Conversation
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl <2410591650@qq.com>
Summary of Changes
Hello @fhl2000, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a refined approach to full CUDA graph integration within vLLM, moving from a flattened FX graph to a wrapper-based strategy that preserves the piecewise graph structure. This new method facilitates broader full CUDA graph support for attention backends like FlashAttention 2 and FlashInfer, leading to measurable performance gains. Additionally, it includes a fix for a specific kernel compilation issue, enhancing overall system stability and compatibility.
Highlights
- Enhanced Full CUDA Graph Implementation: Introduces a new strategy for full CUDA graph capture that wraps the piecewise compiled FX graph, rather than flattening it. This aims to reduce CPU overhead for non-captured batch sizes and offers greater flexibility, allowing dispatch to different CUDA graph sets for prefill-decode and pure decode stages.
- FA2 and FlashInfer Support: Extends full CUDA graph support to FlashAttention 2 (FA2) and FlashInfer backends. This includes specific adaptations for their distinct prefill-decode and pure decode routines, enabling performance benefits for these attention backends.
- Performance Improvements: Benchmarking results indicate a 5% improvement in output token throughput for FA2 and a 2% improvement for FlashInfer, with corresponding reductions in Time Per Output Token (TPOT) by 2.9% and 3.1% respectively.
- Marlin Kernel Compilation Bug Fix: Addresses a minor bug where Marlin kernels were incorrectly compiled for unsupported GPU architectures (e.g., 8.7 for RTX 4090, which is 8.9), resolving 'RuntimeError: CUDA error: no kernel image is available for execution on the device' errors.
- Separate Attention Routine Configuration: Adds a new `separate_attention_routine` flag to `CompilationConfig`, allowing distinct CUDA graph capturing for the prefill-decode and pure decode stages within attention backends that implement different branches for these cases.
Code Review
This pull request introduces a new implementation for full cuda graph, adds support for FA2 and FlashInfer, and fixes a bug for Marlin kernels on Ada architecture. The core idea is to keep the piecewise graph structure and wrap it to capture the full CUDA graph, which avoids the CPU overhead of a large flattened graph. The changes are well-motivated, and the performance improvements are clearly demonstrated.
This pull request has merge conflicts that must be resolved before it can be merged.
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
I have incorporated some checks for the new flag `separate_attention_routine`.
Here is the workflow. At the initialization of torch.compile, the vllm_backend wraps the split_gm into a full cudagraph wrapper class if compilation_config.full_cuda_graph is on. This wrapper class then takes responsibility for dispatching to the cudagraph entries of the separate attention routines. At runtime, this dispatching is based on two key flags in the global forward_context.
Please let me know if you have any questions or suggestions. I am currently planning to add some unit tests.
Signed-off-by: fhl <2410591650@qq.com>
ProExpertProg
left a comment
I think this is a good approach overall!
My initial feedback:
- I think we should try to consolidate CUDAGraph logic into a single class.
- CUDAGraph logic is complex on `main` already, and this PR increases complexity significantly. We should add significantly more documentation. I also think we should consolidate various config flags and states.
- There are benefits to compilation without splitting the graph (e.g. attention+quant fusion). We should add a new flag that maintains that ability (and assert the attention backend supports full cudagraph only). CUDAGraph logic can stay in the wrapper class.
- This is a large PR, so it might help to split it. e.g. FlashInfer cg support can be added in a follow-up. But I'll let others chime in here.
Okay, this is plenty for now :D - thanks for the PR!
That's right, we only support
…ogonal to compilation, add support for FA2 and FlashInfer (vllm-project#20059) Signed-off-by: fhl <2410591650@qq.com> Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Xiao Yu <xiao.yu@amd.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Xiao Yu <xiao.yu@amd.com>
…ogonal to compilation, add support for FA2 and FlashInfer (vllm-project#20059) Signed-off-by: fhl <2410591650@qq.com> Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Hi @fhl2000, sorry to disturb. We are confused about the FULL mode. When setting FULL mode, the dummy input for attention is created to capture the mixed graph (uniform_decode=False). I find the code here assumes that a pure prefill is computed for the mixed graph. For example, according to the following code, num_tokens in the batch_descriptor is 512 and uniform_decode is False, so max_query_len is 512 and all 512 tokens are computed as prefill and 0 tokens as decode in this batch. Is my understanding right? 😄 Previously I thought the dummy input for the mixed batch would be 70% prefill tokens and 30% decode tokens. Thank you!
That's right. We capture a cudagraph of this edge case (pure prefill) for the general mixed batch, because it is always compatible with a mixed batch, no matter how it is mixed, and it even works for a pure decode batch.
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
…-project#23046) Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
…-project#23046) Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
…-project#23046) Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
…-project#23046) Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
…-project#23046) Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
What's changed in this PR:
- `cudagraph_mode` is introduced in `CompilationConfig`, supporting the following five modes: `NONE`, `PIECEWISE`, `FULL`, `FULL_DECODE_ONLY`, and `FULL_AND_PIECEWISE`, which will replace/deprecate the original two flags `use_cudagraph` and `full_cuda_graph`. The most powerful mode is `FULL_AND_PIECEWISE`, which reduces decode-phase ITL via full cudagraphs, retains the small TTFT of piecewise cudagraphs for the prefill/mixed phase, and keeps generality (`FULL_AND_PIECEWISE` mode vs. piecewise cudagraph of the main branch: see comment for details).
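For illustration, here is a minimal offline-inference sketch of selecting a mode. This is an assumption-laden example: the model name is arbitrary, and passing a plain dict for `compilation_config` to the `LLM` entrypoint is assumed to be supported.

```python
from vllm import LLM, SamplingParams

# Sketch: pick the dual-mode cudagraph strategy through the compilation config.
# "FULL_AND_PIECEWISE" uses full cudagraphs for uniform-decode batches and
# piecewise cudagraphs for everything else.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4",  # arbitrary example model
    compilation_config={"cudagraph_mode": "FULL_AND_PIECEWISE"},
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```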
Motivation
See the original PR description below. Also see @ProExpertProg's long proposal and the RFC #20283.
Overall design
The new CUDA graph logic is built on top of piecewise compilation and supports switching between dual cudagraph runtime modes. To make the system work, this PR introduces three core components (`CUDAGraphMode`, `CUDAGraphWrapper`, and `CUDAGraphDispatcher`) and two auxiliary concepts (`BatchDescriptor` and `AttentionCGSupport`). Let's walk through them one by one.
Note: for convenience, hereafter we refer to pure decode (max_query_len = 1) or speculative decode (max_query_len = 1 + num_spec_tokens) batches as uniform-decode batches; the opposite are non-uniform batches (i.e., prefill or mixed prefill-decode batches).
CUDAGraphMode
`cudagraph_mode` is the new flag introduced in `CompilationConfig`, taking the enum `CUDAGraphMode` as its value (the prototype is sketched below). Here:
- `NONE`: no cudagraph.
- `PIECEWISE`: only piecewise cudagraphs (v1 default).
- `FULL`: a single-mode strategy that captures full cudagraphs only for non-uniform batches, where uniform-decode batches are just treated as a special case of non-uniform batches.
- `FULL_DECODE_ONLY`: dual mode; captures one set of cudagraphs for uniform-decode batches and no cudagraph for the rest.
- `FULL_AND_PIECEWISE`: dual mode; explicitly keeps two sets of cudagraphs, with full cudagraphs for uniform-decode batches and piecewise cudagraphs for non-uniform batches or anything else. This way, cascade attention can also be supported in the `FULL_DECODE_ONLY` and `FULL_AND_PIECEWISE` modes (to be addressed in follow-up PRs).
Notably, the subset modes `NONE`, `PIECEWISE`, and `FULL` also serve as the concrete runtime modes for cudagraph dispatching, i.e., they are what `decode_mode()` or `mixed_mode()` resolve to at runtime.
BatchDescriptor
`BatchDescriptor` is a component of the `ForwardContext` that, together with the cudagraph runtime mode, serves as the core dispatching key at runtime (the prototype is sketched below). Here, `num_tokens` can be the padded token length, and `uniform_decode` is true when `max_query_len` equals `uniform_decode_query_len` and the number of scheduled tokens is divisible by `uniform_decode_query_len`. It is designed to uniquely identify a (padded) batch with the minimal set of fields corresponding to a cudagraph item. We can safely exclude items like `max_query_len` or the number of requests, because `uniform_decode_query_len` is a constant at runtime and the number of requests is unimportant from the viewpoint of the attention kernels under padding semantics.
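A minimal sketch of the prototype, assuming a `NamedTuple`-style key (the `non_uniform` helper is illustrative):

```python
from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    num_tokens: int               # padded number of tokens in the batch
    uniform_decode: bool = False  # True only for uniform-decode batches

    @property
    def non_uniform(self) -> "BatchDescriptor":
        # Relax this key to its non-uniform counterpart with the same num_tokens.
        return BatchDescriptor(self.num_tokens, uniform_decode=False)
```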
CUDAGraphDispatcher
The dispatcher is responsible for creating/storing the two sets of valid dispatching keys, one set for the `FULL` runtime mode and one for the `PIECEWISE` runtime mode, and for dispatching the correct keys when executing model forwards. It decides the runtime mode and the final batch descriptor (the keys) from the rough input descriptor (and, internally, the `cudagraph_mode` in `compilation_config`), and then tells the `CUDAGraphWrapper` instances (the workers) its decision through the forward context. Note that `CUDAGraphDispatcher` is the only source of truth for available cudagraph keys, so the `CUDAGraphWrapper` instances can carry less logic and unquestioningly trust the forward context on which cudagraph to dispatch to.
The dispatching code looks roughly like this:
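(The following is a hedged sketch, reusing the `CUDAGraphMode` and `BatchDescriptor` sketches above and assuming the dispatcher stores one key set per runtime mode; names are illustrative.)

```python
def dispatch(
    cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]],
    batch_descriptor: BatchDescriptor,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
    # Search existing keys with priority FULL > PIECEWISE.
    if batch_descriptor in cudagraph_keys.get(CUDAGraphMode.FULL, set()):
        return CUDAGraphMode.FULL, batch_descriptor
    # A uniform-decode batch may still be served by the non-uniform full graph.
    relaxed = batch_descriptor.non_uniform
    if relaxed in cudagraph_keys.get(CUDAGraphMode.FULL, set()):
        return CUDAGraphMode.FULL, relaxed
    if relaxed in cudagraph_keys.get(CUDAGraphMode.PIECEWISE, set()):
        return CUDAGraphMode.PIECEWISE, relaxed
    # No matching key exists: default to NONE for eager execution.
    return CUDAGraphMode.NONE, batch_descriptor
```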
Inside the `dispatch()` method, the dispatcher searches for the proper cudagraph runtime mode and an existing dispatching key to return. We basically search the existing keys following the priority `FULL` > `PIECEWISE`. If no matching dispatching key exists, it defaults to returning the `NONE` mode for eager execution.
CUDAGraphWrapper
Each `CUDAGraphWrapper` instance wraps a runnable and is bound to a specific `runtime_mode`, which is restricted to the `PIECEWISE` and `FULL` modes. It is responsible for capturing/replaying cudagraphs and for passing calls through to the runnable. At runtime, each wrapper inspects the dispatching keys in the forward context: it only needs to check whether it is activated (via mode matching); if not, it passes through, otherwise it replays an available cudagraph for the key or captures one if it does not yet exist. This way, there is no implicit contract between the dispatcher and the wrapper; instead, the wrapper directly trusts what is in the forward context.
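For illustration, a simplified sketch of this capture/replay behavior, assuming a `get_forward_context()` accessor that exposes the runtime mode and batch descriptor (the real class additionally manages memory pools, output weak references, and warm-up bookkeeping):

```python
import torch


class CUDAGraphWrapper:
    """Simplified sketch; not the actual implementation."""

    def __init__(self, runnable, runtime_mode: CUDAGraphMode):
        assert runtime_mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
        self.runnable = runnable
        self.runtime_mode = runtime_mode
        self.graphs: dict[BatchDescriptor, torch.cuda.CUDAGraph] = {}
        self.outputs: dict[BatchDescriptor, object] = {}

    def __call__(self, *args, **kwargs):
        ctx = get_forward_context()  # assumed accessor for the global forward context
        if ctx.cudagraph_runtime_mode != self.runtime_mode:
            # Not activated for this step: simply pass through to the runnable.
            return self.runnable(*args, **kwargs)
        key = ctx.batch_descriptor
        if key not in self.graphs:
            # First time this key is seen under our mode: capture a cudagraph.
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                self.outputs[key] = self.runnable(*args, **kwargs)
            self.graphs[key] = graph
        else:
            self.graphs[key].replay()
        return self.outputs[key]
```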
Nested Wrapper design
The core mechanism that lets a full cudagraph and piecewise cudagraphs coexist and stay compatible is the nested cudagraph wrapper design, built on top of piecewise compilation with only a single piecewise fx graph. We wrap a `FULL`-mode wrapper around the entire model for the full cudagraph functionality; meanwhile, each piecewise backend is wrapped in a `PIECEWISE`-mode wrapper inside the compilation.
Below is a cropped image of the flow chart from @ProExpertProg, which should clearly describe how it works.
Therefore, for the `FULL` runtime mode, it is safe to capture/replay a full cudagraph since the piecewise wrappers are not activated. The same holds for the `PIECEWISE` mode, as there are no conflicts between the `FULL`-mode wrapper and the `PIECEWISE`-mode wrappers. For the `NONE` runtime mode, neither the `FULL` nor the `PIECEWISE` wrappers are activated, so execution simply falls through to eager mode.
About the warm-up
The cudagraph wrapper is no longer aware of the warm-up logic. The warm-up process is controlled directly by the GPU model runner, which assigns the `NONE` runtime mode so that an eager execution is performed.
CLI reference
The CLI now directly takes the uppercase string of `cudagraph_mode` in the compilation config: `--compilation-config '{"cudagraph_mode": "..."}'`, where `...` is one of `NONE`, `PIECEWISE`, `FULL`, `FULL_DECODE_ONLY`, and `FULL_AND_PIECEWISE`. Note that all `PIECEWISE`-related modes require piecewise compilation, and all `FULL`-related modes require cudagraph support from the attention backend, which is marked by a new enum type `AttentionCGSupport` inside the metadata builders (a sketch is shown below).
The `AttentionCGSupport` of an attention backend is one of `ALWAYS`, `UNIFORM_BATCH`, `UNIFORM_SINGLE_TOKEN_DECODE`, and `NEVER` (default). An attention backend with `ALWAYS` cudagraph support is reachable by all modes, while a backend with `UNIFORM_BATCH` or `UNIFORM_SINGLE_TOKEN_DECODE` support only allows the `FULL_DECODE_ONLY` and `FULL_AND_PIECEWISE` modes.
For user-facing concerns, we also enable a fallback behavior for the `FULL` mode: when using an attention backend whose cudagraph support is `UNIFORM_BATCH` or `UNIFORM_SINGLE_TOKEN_DECODE`, the `FULL` mode is translated to `FULL_AND_PIECEWISE` if piecewise compilation is enabled, and otherwise to `FULL_DECODE_ONLY`.
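For reference, a minimal sketch of the support-level enum; the ordering and integer values here are assumptions for illustration:

```python
import enum


class AttentionCGSupport(enum.IntEnum):
    NEVER = 0                        # no full-cudagraph support (default)
    UNIFORM_SINGLE_TOKEN_DECODE = 1  # only uniform decode with max_query_len == 1
    UNIFORM_BATCH = 2                # uniform batches (e.g. speculative decode)
    ALWAYS = 3                       # any batch, including mixed prefill-decode
```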
NOTE for attention ops fusion:
Currently, the default behavior for any `cudagraph_mode != NONE` is to always keep the attention ops in `splitting_ops` to get a piecewise fx graph. If you need attention ops fusion, or want to mimic the previous behavior of `full_cuda_graph=True`, just manually pass `splitting_ops=[]` to retain the flattened fx graph and use `cudagraph_mode = "FULL"` or `"FULL_DECODE_ONLY"` (just avoid any `PIECEWISE`-related mode, even when using -O3).
Origin PR description at 2025/6/25
Purpose
1. This PR introduces a new implementation for full cuda graph, and adds support for FA2 and FlashInfer.
Previous limitations
The original design in PR #16072 sets `compilation_config.splitting_ops` to an empty list and captures the full cudagraph inside the flattened fx graph, which supports FA3 only. The later PR #18581 adds full cudagraph support for FlashMLA that only captures the pure decode stage and bypasses the mixed prefill-decode stages, i.e., it runs the eager code of the compiled flattened fx graph in that stage. However, from the profiling results (see below), I found this flattened graph has performance issues when called eagerly: it is about 2x slower on the CPU side than running the compiled piecewise fx graph (possibly a Python-level issue). This can lead to performance degradation when the prefill stage has a small batch size.
Solution of this PR.
So the new trick is: we keep the piecewise compiled fx graph structure overall, but capture the full cudagraph outside the fx graph via a wrapper. With this at hand, we can dispatch to two sets of cudagraphs: for the pure decode stage, we directly use full cudagraphs, since it is compatible with most attention backends; for mixed prefill-decode stages, we can either fall back to piecewise cudagraphs for incompatible routines in backends like FlashMLA and FlashInfer, or use another set of full cudagraphs for compatible backends (varlen support in FA2).
Note that keeping the piecewise compiled fx graph is at least better than a full but flattened one from the viewpoint of reducing CPU overhead, even if we do not capture the mixed prefill-decode stage. It also stays flexible to switch between full cudagraphs and piecewise cudagraphs for future extensions, for example, a seamless fallback to piecewise cudagraphs if cascade attention is needed.
The limitation is the increased startup time and more gpu memory required for the additional cudagraph capturing. Maybe we can optimize this by shrinking the list of batch sizes to be captured for the prefill-decode stage.
Profile of the compiled flattened fx graph under eager execution, mixed prefill-decode stage:
It takes roughly 56 ms to fully launch the model, with an additional 5 ms of latency spent on safety checks before launching the first kernel. It seems Python is slow at executing the large, flattened module without submodules.

Note: ~~the only way to use a flattened fx graph in this PR is to hardcode `splitting_ops = []` in `set_splitting_ops_for_v1`~~ Manually passing `splitting_ops = []` to the compilation config should lead to a flattened fx graph. (updated at 2025/8/10)
Profile of the compiled piecewise fx graph under eager execution, mixed prefill-decode stage:
It takes about 28 ms to fully launch, and the latency described above almost disappears; in fact, it is hidden inside each submodule.

The patterns above were verified on two different machines (ignoring the GPU difference here, as this is only related to the CPU), tested on Qwen2.5-7B-Instruct-GPTQ-Int4 while profiling benchmark_serving (ShareGPT, unlimited request rate).
So, if a prefill batch size is a bit larger than the max capturing size (say 512) but not too big, the lower bound of the model forward time is likely set by the CPU side: around 56 ms when running the flattened graph versus 28 ms for the piecewise one.
Details for supporting FA2:
The previous code did not recognize the two routines underneath FA2: it launches a standard varlen fwd kernel for mixed prefill-decode batches, or launches another routine for pure decode batches that includes an optimization for GQA/MQA and potential flash-decode kernels (split_kv > 1). By setting max_query_len = 1 or > 1 at the cudagraph capturing phase, we can activate the desired attention routine so that it is captured correctly. (To be precise, the prefill-decode kernel is of course compatible with pure decode, but it is not fully optimized for the decode phase. The actual reason PR #16072 did not support FA2 is a bug where seq_lens was a zero tensor in dummy_run in the earlier code, which bypassed launching any attention kernel during capture and led to zero-tensor outputs.)
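A hedged illustration of the capture-time intent described above; the helper and field names below are hypothetical, not the actual metadata-builder API:

```python
def dummy_capture_params(num_tokens: int, uniform_decode: bool) -> dict:
    """Sketch: choose dummy metadata so the intended FA2 routine gets captured."""
    if uniform_decode:
        # max_query_len == 1 activates FA2's pure-decode routine
        # (GQA/MQA optimization, potentially flash-decode with split_kv > 1).
        return {"max_query_len": 1, "seq_lens": [1] * num_tokens}
    # max_query_len > 1 activates the general varlen prefill-decode kernel;
    # seq_lens must be non-zero so that a real kernel launch is captured
    # (a zero seq_lens tensor is what previously broke FA2 capture).
    return {"max_query_len": num_tokens, "seq_lens": [num_tokens]}
```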
Details for supporting FlashInfer:
Test Plan
benchmark serving, lm_eval performance of FA2 and FlashInfer
I have no plan to test FlashMLA and FA3, as I have no Hopper GPU at hand, but they should be fine since the current design is compatible with them. However, it would be very nice if somebody could help test them.
Test Result
🟠 NOTE: results below were tested on an initial version of this PR, just kept for reference, as the CLI and code structures have undergone huge refactorings since the initial version (updated at 2025/8/10)
Summary of results
Output token throughput is improved by 5% for FA2 and 2% for FlashInfer on Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4. TPOT is reduced by 2.9% and 3.1%, respectively. The lm_eval results are unchanged for both.
Details
Machine: A100 40G, torch 2.6, CUDA 12.4
Benchmark serving command:
FA2 benchmark serving:
piecewise cudagraph before this PR
python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9
full cudagraph + piecewise fx graph in this PR
python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9 --compilation-config '{"full_cuda_graph": true,"separate_attention_routine": true}'
FA2 lm_eval
piecewise cudagraph before this PR
vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
full cudagraph + piecewise fx graph after this PR
vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9, 'compilation_config': {'full_cuda_graph': True, 'separate_attention_routine': True}}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
FlashInfer benchmark serving
piecewise cudagraph before this PR
VLLM_ATTENTION_BACKEND=FLASHINFER python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9
full cudagraph + piecewise fx graph after this PR
VLLM_ATTENTION_BACKEND=FLASHINFER python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9 --compilation-config '{"full_cuda_graph": true,"separate_attention_routine": true}'
FlashInfer lm_eval
piecewise cudagraph before this PR
vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1
full cudagraph + piecewise fx graph after this PR
vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9, 'compilation_config': {'full_cuda_graph': True, 'separate_attention_routine': True}}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1