@fhl2000 (Contributor) commented Jun 25, 2025

What's changed in this PR:

  1. Cudagraph logic is now orthogonal to vLLM compilation in v1; in other words, full cudagraph without compilation is now supported in v1;
  2. Full cudagraph support for FlashAttention v2 and FlashInfer (moved to #21367 with further optimizations);
  3. A new CLI flag cudagraph_mode is introduced in CompilationConfig, supporting five modes: NONE, PIECEWISE, FULL, FULL_DECODE_ONLY, and FULL_AND_PIECEWISE; it will replace/deprecate the original two flags use_cudagraph and full_cuda_graph. The most powerful mode is FULL_AND_PIECEWISE, which reduces decode-phase ITL via full cudagraphs while retaining the small TTFT of piecewise cudagraphs for the prefill/mixed phase, and also keeps generality;
  4. The full-cudagraph-related modes are compatible with the validation phase of speculative decoding (currently FA2/FA3 and Triton attention only; FlashMLA in a follow-up PR);
  5. Ideally, the new infrastructure speeds up most attention backends that support full cudagraph. ITL for Triton attention was reduced by 38% for small LLMs (FULL_AND_PIECEWISE mode vs. the piecewise cudagraph of the main branch; see the comment for details).

Motivation

See the original PR description below. Also see @ProExpertProg's long proposal and the RFC #20283.

Overall design:

The new CUDA Graph logic is built on top of piecewise compilation and supports switching between dual cudagraph runtime modes. To make the system work, this PR introduces three core components (CUDAGraphMode, CUDAGraphWrapper, and CUDAGraphDispatcher) and two auxiliary concepts (BatchDescriptor and AttentionCGSupport). Let's walk through them one by one:

Note: For convenience, hereafter we refer to pure decode (max_query_len=1) or speculative decode (max_query_len=1+num_spec_tokens) batches as uniform decode batches; the opposite is non-uniform batches (i.e., prefill or mixed prefill-decode batches).

CUDAGraphMode

cudagraph_mode is the new flag introduced in CompilationConfig, taking the enum CUDAGraphMode as its value. The prototype:

class CUDAGraphMode(enum.Enum):
    """Constants for the cudagraph mode in CompilationConfig.
    The subset enums `NONE`, `PIECEWISE`, and `FULL` are also treated as
    concrete runtime modes for cudagraph runtime dispatching.
    """
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_DECODE_ONLY = (FULL, NONE)
    FULL_AND_PIECEWISE = (FULL, PIECEWISE)

Here, NONE means no cudagraph. PIECEWISE uses only piecewise cudagraphs (the v1 default). FULL is a single-mode strategy that captures full cudagraphs only for non-uniform batches; uniform-decode batches are treated as just a special case of non-uniform batches. FULL_DECODE_ONLY is a dual-mode strategy that captures one set of full cudagraphs for uniform-decode batches and uses no cudagraph for the rest. FULL_AND_PIECEWISE is also dual-mode, explicitly keeping two sets of cudagraphs: full cudagraphs for uniform-decode batches and piecewise cudagraphs for non-uniform (and any remaining) batches. This way, cascade attention can also be supported in the FULL_DECODE_ONLY and FULL_AND_PIECEWISE modes (to be addressed in follow-up PRs).

Notably, the subset modes NONE, PIECEWISE, and FULL also serve as the concrete runtime modes for cudagraph dispatching, so at runtime each of them can act as the decode_mode() or mixed_mode() of a configured (possibly dual) mode.
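
For intuition, here is a minimal sketch, not the exact vLLM implementation, of how a dual mode decomposes into the two concrete runtime modes; it assumes the tuple value of a dual member encodes (decode mode, mixed mode) and uses free functions instead of the enum's own decode_mode()/mixed_mode() methods:

def decode_mode(mode: CUDAGraphMode) -> CUDAGraphMode:
    # Dual members (FULL_DECODE_ONLY, FULL_AND_PIECEWISE) carry a (decode, mixed) tuple.
    return CUDAGraphMode(mode.value[0]) if isinstance(mode.value, tuple) else mode

def mixed_mode(mode: CUDAGraphMode) -> CUDAGraphMode:
    return CUDAGraphMode(mode.value[1]) if isinstance(mode.value, tuple) else mode

assert decode_mode(CUDAGraphMode.FULL_AND_PIECEWISE) == CUDAGraphMode.FULL
assert mixed_mode(CUDAGraphMode.FULL_AND_PIECEWISE) == CUDAGraphMode.PIECEWISE
assert mixed_mode(CUDAGraphMode.FULL_DECODE_ONLY) == CUDAGraphMode.NONE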

BatchDescriptor

BatchDescriptor is a component of the ForwardContext that, together with the cudagraph runtime mode, serves as the core dispatching key at runtime. The prototype:

class BatchDescriptor(NamedTuple):
    num_tokens: int
    uniform_decode: bool = False

where num_tokens can be the padded token length, and uniform_decode is determined by whether max_query_len equals uniform_decode_query_len and num_scheduled_tokens is divisible by uniform_decode_query_len. This is designed to uniquely identify a (padded) batch with the minimal set of items corresponding to a cudagraph entry. It is safe to exclude items like max_query_len or num_requests, because uniform_decode_query_len is a constant at runtime, and num_requests is unimportant from the viewpoint of the attention kernels under padded semantics.
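
A rough sketch of how the key can be built per step (the variable and helper names here are illustrative, not the exact model-runner code):

def build_batch_descriptor(num_input_tokens: int, max_query_len: int,
                           num_scheduled_tokens: int,
                           uniform_decode_query_len: int) -> BatchDescriptor:
    # uniform_decode_query_len is 1 for ordinary decode, or 1 + num_spec_tokens
    # when speculative decoding is enabled.
    uniform_decode = (max_query_len == uniform_decode_query_len
                      and num_scheduled_tokens % uniform_decode_query_len == 0)
    return BatchDescriptor(num_tokens=num_input_tokens,
                           uniform_decode=uniform_decode)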

CUDAGraphDispatcher

The dispatcher is responsible for creating/storing two sets of valid dispatching keys, one for the FULL runtime mode and one for the PIECEWISE runtime mode, and for dispatching the correct keys when executing the model forward. It decides the runtime mode and the final batch descriptor used as keys based on the rough input descriptor (and, internally, the cudagraph_mode in compilation_config), then communicates its decision to the CUDAGraphWrapper instances (the workers) through the forward context. Note that the CUDAGraphDispatcher is the single source of truth for available cudagraph keys, so the CUDAGraphWrapper instances can carry less logic and unquestioningly trust the forward context about which cudagraph to dispatch to.

The dispatching code looks like:

batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=...)
runtime_mode, batch_descriptor = cudagraph_dispatcher.dispatch(batch_descriptor)
# execution
with set_forward_context(..., 
            cudagraph_runtime_mode=runtime_mode, 
            batch_descriptor=batch_descriptor):
     output = self.model(...)

Inside the dispatch() method, the dispatcher searches for the proper cudagraph runtime mode and an existing dispatching key to return. Existing keys are searched in the priority order FULL > PIECEWISE. If no dispatching key exists, it defaults to returning the NONE mode for eager execution.
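
A hedged sketch of that priority logic (simplified; attribute names like full_cudagraph_keys and piecewise_cudagraph_keys are assumptions, and the real dispatcher handles more cases):

def dispatch(self, batch: BatchDescriptor) -> tuple[CUDAGraphMode, BatchDescriptor]:
    # Prefer a FULL-mode key: first the exact (possibly uniform-decode) key,
    # then its non-uniform relaxation.
    non_uniform = batch._replace(uniform_decode=False)
    if batch in self.full_cudagraph_keys:
        return CUDAGraphMode.FULL, batch
    if non_uniform in self.full_cudagraph_keys:
        return CUDAGraphMode.FULL, non_uniform
    # Otherwise fall back to a PIECEWISE key.
    if non_uniform in self.piecewise_cudagraph_keys:
        return CUDAGraphMode.PIECEWISE, non_uniform
    # No key available: run eagerly.
    return CUDAGraphMode.NONE, batch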

CUDAGraphWrapper

Each CUDAGraphWrapper instance wraps a runnable and is bound to a specific runtime_mode, which is restricted to PIECEWISE or FULL. It is responsible for capturing/replaying cudagraphs and for passing through to the runnable. At runtime, each wrapper inspects the dispatching keys in the forward context; it only needs to check whether it is activated (via mode matching). If not, it passes through; otherwise it replays the cudagraph available for the key, or captures one if it does not yet exist. This way, there is no implicit contract between the dispatcher and the wrapper; the wrapper simply trusts what is in the forward context.
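
A simplified sketch of this behavior, reusing the CUDAGraphMode enum above; the capture pools, weak references, and debug checks of the real implementation are omitted, and get_forward_context is vLLM's existing helper for reading the current forward context:

import torch
from vllm.forward_context import get_forward_context

class SimpleCUDAGraphWrapper:
    def __init__(self, runnable, runtime_mode: CUDAGraphMode):
        assert runtime_mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
        self.runnable = runnable
        self.runtime_mode = runtime_mode
        self.entries = {}  # BatchDescriptor -> (captured graph, static outputs)

    def __call__(self, *args, **kwargs):
        ctx = get_forward_context()
        if ctx.cudagraph_runtime_mode != self.runtime_mode:
            # Not activated for this runtime mode: just pass through.
            return self.runnable(*args, **kwargs)
        key = ctx.batch_descriptor
        if key not in self.entries:
            # First time this key is seen: capture a new cudagraph.
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                outputs = self.runnable(*args, **kwargs)
            self.entries[key] = (graph, outputs)
            return outputs
        graph, outputs = self.entries[key]
        graph.replay()
        return outputs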

Nested Wrapper design

The core mechanism that lets full cudagraphs and piecewise cudagraphs coexist is the nested cudagraph wrapper design, built on top of piecewise compilation with a single piecewise fx graph. A FULL-mode wrapper wraps the entire model for the full cudagraph functionality; meanwhile, each piecewise backend is wrapped in a PIECEWISE-mode wrapper inside the compilation.

Below is a cropped image of the flow chart from @ProExpertProg, which describes how this works.

[Image: flow chart of the nested cudagraph wrapper dispatching]

Therefore, for the FULL runtime mode, it is safe to capture/replay a full cudagraph since the piecewise wrappers are not activated. The situation is the same for the PIECEWISE mode, so there are no conflicts between the FULL-mode wrapper and the PIECEWISE-mode wrappers. For the NONE runtime mode, neither the FULL nor the PIECEWISE wrappers are activated, so execution falls through to eager.
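
Conceptually, the nesting looks like this (compile_piecewise_model is a hypothetical name standing in for vLLM's piecewise compilation, which already wraps each piece in a PIECEWISE-mode wrapper; SimpleCUDAGraphWrapper is the sketch from above):

# Inner: every compiled piece carries a PIECEWISE-mode wrapper.
piecewise_model = compile_piecewise_model(model)  # hypothetical helper name
# Outer: the whole model is wrapped with a FULL-mode wrapper.
full_model = SimpleCUDAGraphWrapper(piecewise_model, CUDAGraphMode.FULL)

# FULL runtime mode      -> outer wrapper captures/replays; inner wrappers pass through
# PIECEWISE runtime mode -> outer wrapper passes through; inner wrappers capture/replay
# NONE runtime mode      -> both pass through, i.e. eager execution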

About the warm-up

The cudagraph wrapper is no longer aware of the warm-up logic. The warm-up process is controlled directly by the GPU model runner, which assigns the NONE runtime mode so that an eager execution is performed.
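
In other words, warm-up is just a forward pass dispatched with the NONE runtime mode, roughly (mirroring the dispatching snippet above):

# Warm-up in the model runner: force eager execution so no wrapper captures.
with set_forward_context(...,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None):
    self.model(...)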

CLI reference

The CLI now directly takes the uppercase string of cudagraph_mode in compilation_config: --compilation-config '{"cudagraph_mode": "..."}', where ... is one of NONE, PIECEWISE, FULL, FULL_DECODE_ONLY, or FULL_AND_PIECEWISE. Note that all PIECEWISE-related modes require piecewise compilation, and all FULL-related modes require cudagraph support from the attention backend, which is marked by a new enum type AttentionCGSupport inside the metadata builders.
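
For example, mirroring the server commands in the benchmarks below, FULL_AND_PIECEWISE can be enabled with:

python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}'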

The AttentionCGSupport of an attention backend is one of ALWAYS, UNIFORM_BATCH, UNIFORM_SINGLE_TOKEN_DECODE, or NEVER (the default). A backend with ALWAYS cudagraph support is reachable from all modes, while a backend with UNIFORM_BATCH or UNIFORM_SINGLE_TOKEN_DECODE only supports the FULL_DECODE_ONLY and FULL_AND_PIECEWISE modes.
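
A sketch of these support levels as an enum (the member names are from this PR; the numeric values and comments here are only illustrative):

import enum

class AttentionCGSupport(enum.Enum):
    NEVER = 0                        # default: no full-cudagraph support
    UNIFORM_SINGLE_TOKEN_DECODE = 1  # only pure decode batches (max_query_len == 1)
    UNIFORM_BATCH = 2                # uniform-decode batches, incl. spec-decode validation
    ALWAYS = 3                       # any batch composition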

For user convenience, we also enable a fallback behavior for the FULL mode: when the attention backend's cudagraph support is UNIFORM_BATCH or UNIFORM_SINGLE_TOKEN_DECODE, the FULL mode is translated to FULL_AND_PIECEWISE if piecewise compilation is enabled, otherwise to FULL_DECODE_ONLY.

NOTE for attention ops fusion:
Currently, the default behavior for cudagraph_mode != NONE always keeps the attention ops in splitting_ops to obtain a piecewise fx graph. If you need attention-op fusion, or want to mimic the previous behavior of full_cuda_graph=True, manually pass splitting_ops=[] to retain the flattened fx graph and use cudagraph_mode = "FULL" or "FULL_DECODE_ONLY" (just avoid any PIECEWISE-containing mode, even when using -O3).
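
For example, the flattened-graph setup from the note above would be requested like this:

--compilation-config '{"cudagraph_mode": "FULL", "splitting_ops": []}'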


Original PR description (2025/6/25)

Purpose

1. This PR introduces a new implementation for full cuda graph, and adds support for FA2 and FlashInfer.

Previous limitations

The original design in PR #16072 sets compilation_config.splitting_ops to an empty list and captures the full cudagraph inside the flattened fx graph, which supports FA3 only. In the later PR #18581, full cudagraph support for FlashMLA only captures the pure decode stage and bypasses the mixed prefill-decode stage, i.e., it runs the eager code of the compiled flattened fx graph in that stage. However, from the profiling results (see below), I found this flattened graph has performance issues when called eagerly: it is about 2x slower on the CPU side than running the compiled piecewise fx graph (possibly a Python-level issue). This can lead to performance degradation when the prefill stage has a small batch size.

Also, attention backends like FA2, FlashInfer, and FlashMLA have two distinct attention routines for the prefill-decode stage and the pure decode stage, which makes it difficult to contain both in a unified graph while keeping only one set of captured cudagraphs.

Solution of this PR.

So the new trick is: we keep the piecewise compiled fx graph structure overall, but capture the full cudagraph outside the fx graph via a wrapper. With this at hand, we can dispatch to two sets of cudagraphs. For the pure decode stage, we directly use full cudagraphs, since this stage is compatible with most attention backends. For the mixed prefill-decode stage, we can either fall back to piecewise cudagraphs for incompatible routines in backends like FlashMLA and FlashInfer, or use another set of full cudagraphs for compatible backends (varlen support in FA2).

Note that keeping the piecewise compiled fx graph is at least better than a full but flattened one from the viewpoint of reducing CPU overhead, even if we do not capture the mixed prefill-decode stage. It also stays flexible to switch between full and piecewise cudagraphs for future extensions, for example a seamless fallback to piecewise cudagraph when cascade attention is needed.

The limitation is the increased startup time and the additional GPU memory required for the extra cudagraph capturing. We could optimize this by shrinking the list of batch sizes captured for the prefill-decode stage.

Profile of the compiled flattened fx graph under eager execution, mixed prefill-decode stage:

It takes roughly 56 ms to fully launch the model, with an additional 5 ms of latency spent on safety checks before launching the first kernel. It seems Python is slow at executing the flattened, large module that has no submodules.
[Image: profiler trace of the flattened fx graph, eager execution]

Note (updated 2025/8/10): manually passing splitting_ops=[] to the compilation config leads to a flattened fx graph. (Originally, the only way to use the flattened fx graph in this PR was to hardcode splitting_ops=[] in set_splitting_ops_for_v1.)

Profile of the compiled piecewise fx graph under eager execution, mixed prefill-decode stage:

It takes 28 ms to fully launch, and the latency above almost disappears; in fact, it is hidden inside each submodule.
[Image: profiler trace of the piecewise fx graph, eager execution]

The patterns above were verified on two different machines (ignoring the GPU difference here, as this is only related to the CPU), tested on Qwen2.5-7B-Instruct-GPTQ-Int4 with a benchmark_serving profile (sharegpt, unlimited request rate).

So, if a prefill batch size is a bit larger than the max capture size (say 512) but not too large, the lower bound of the model forward time is likely set by the CPU side: around 56 ms running the flattened graph versus 28 ms for the piecewise one.

Details for supporting FA2:

The previous code did not recognize the two routines inside FA2. It launches a standard varlen fwd kernel for mixed prefill-decode batches, or another routine for pure decode batches that includes an optimization for GQA/MQA and potential flash-decode kernels (split_kv > 1). By setting max_query_len = 1 or > 1 during the cudagraph capturing phase, we can activate the desired attention routine so it is captured correctly. (Strictly speaking, the prefill-decode kernel is of course compatible with pure decode, but it is not fully optimized for the decode phase. The actual reason PR #16072 did not support FA2 is a bug: seq_lens was a zero tensor in the dummy_run in the early code, which bypassed launching any attention kernel during capture and led to zero-tensor outputs.)

  • FA2 runs both mixed prefill-decode and pure decode batches with full cudagraph, but on two separate sets of cudagraphs.

Details for supporting FlashInfer:

  • Use the persistent buffer trick.
  • Create many decode_wrappers, one per cudagraph batch size, as required by the FlashInfer API (see the sketch after this list).
  • Run pure decode batches with full cudagraph, and fall back to piecewise cudagraph for mixed prefill-decode batches.
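
A rough sketch of those per-size wrappers; build_flashinfer_decode_wrapper and cudagraph_capture_sizes are hypothetical names standing in for the actual construction against the FlashInfer API:

# One CUDA-graph-enabled decode wrapper per captured batch size.
decode_wrappers = {
    size: build_flashinfer_decode_wrapper(size)  # hypothetical helper
    for size in cudagraph_capture_sizes
}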

Test Plan

Benchmark serving and lm_eval performance for FA2 and FlashInfer.

I have no plan to test FlashMLA and FA3 since I have no Hopper GPU at hand, but they should be fine as the current design is compatible with them. However, it would be very nice if somebody could help test them.

Test Result

🟠 NOTE: the results below were measured on an initial version of this PR and are kept only for reference, as the CLI and code structure have undergone huge refactorings since then (updated 2025/8/10).

Summary of results

Output token throughput is improved by 5% for FA2 and 2% for FlashInfer on Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4. TPOT is reduced by 2.9% and 3.1%, respectively. The lm_eval results are unchanged for both.

Details

machine: A100 40G, torch 2.6, CUDA 12.4

Benchmark serving command:

python benchmarks/benchmark_serving.py --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 100 --request-rate 20

FA2 benchmark serving:

piecewise cudagraph before this PR

python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 11.41
Total input tokens: 23260
Total generated tokens: 21657
Request throughput (req/s): 8.77
Output token throughput (tok/s): 1898.67
Total Token throughput (tok/s): 3937.88
---------------Time to First Token----------------
Mean TTFT (ms): 76.37
Median TTFT (ms): 71.08
P99 TTFT (ms): 191.53
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 17.08
Median TPOT (ms): 15.22
P99 TPOT (ms): 67.68
---------------Inter-token Latency----------------
Mean ITL (ms): 13.45
Median ITL (ms): 11.05
P99 ITL (ms): 72.61
==================================================

full cudagraph + piecewise fx graph in this PR

python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9 --compilation-config '{"full_cuda_graph": true,"separate_attention_routine": true}'

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 10.87
Total input tokens: 23260
Total generated tokens: 21657
Request throughput (req/s): 9.20
Output token throughput (tok/s): 1992.27
Total Token throughput (tok/s): 4132.01
---------------Time to First Token----------------
Mean TTFT (ms): 78.69
Median TTFT (ms): 75.10
P99 TTFT (ms): 195.90
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 16.57
Median TPOT (ms): 14.78
P99 TPOT (ms): 78.21
---------------Inter-token Latency----------------
Mean ITL (ms): 12.83
Median ITL (ms): 10.34
P99 ITL (ms): 72.37
==================================================

FA2 lm_eval

piecewise cudagraph before this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

| Tasks | Version | Filter           | n-shot | Metric      | Value  |   | Stderr |
|-------|---------|------------------|--------|-------------|--------|---|--------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.8074 | ± | 0.0109 |
|       |         | strict-match     | 5      | exact_match | 0.7619 | ± | 0.0117 |

full cudagraph + piecewise fx graph after this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9, 'compilation_config': {'full_cuda_graph': True, 'separate_attention_routine': True}}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

| Tasks | Version | Filter           | n-shot | Metric      | Value  |   | Stderr |
|-------|---------|------------------|--------|-------------|--------|---|--------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.8074 | ± | 0.0109 |
|       |         | strict-match     | 5      | exact_match | 0.7619 | ± | 0.0117 |

FlashInfer benchmark serving

piecewise cudagraph before this PR

VLLM_ATTENTION_BACKEND=FLASHINFER python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 11.36
Total input tokens: 23260
Total generated tokens: 21660
Request throughput (req/s): 8.81
Output token throughput (tok/s): 1907.38
Total Token throughput (tok/s): 3955.65
---------------Time to First Token----------------
Mean TTFT (ms): 73.61
Median TTFT (ms): 69.59
P99 TTFT (ms): 184.62
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 16.85
Median TPOT (ms): 15.13
P99 TPOT (ms): 65.75
---------------Inter-token Latency----------------
Mean ITL (ms): 13.34
Median ITL (ms): 11.09
P99 ITL (ms): 71.82
==================================================

full cudagraph + piecewise fx graph after this PR

VLLM_ATTENTION_BACKEND=FLASHINFER python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9 --compilation-config '{"full_cuda_graph": true,"separate_attention_routine": true}'

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 11.13
Total input tokens: 23260
Total generated tokens: 21660
Request throughput (req/s): 8.99
Output token throughput (tok/s): 1946.35
Total Token throughput (tok/s): 4036.48
---------------Time to First Token----------------
Mean TTFT (ms): 76.03
Median TTFT (ms): 67.04
P99 TTFT (ms): 192.56
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 16.34
Median TPOT (ms): 14.96
P99 TPOT (ms): 58.86
---------------Inter-token Latency----------------
Mean ITL (ms): 13.11
Median ITL (ms): 10.71
P99 ITL (ms): 71.69
==================================================

FlashInfer lm_eval

piecewise cudagraph before this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

| Tasks | Version | Filter           | n-shot | Metric      | Value  |   | Stderr |
|-------|---------|------------------|--------|-------------|--------|---|--------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.8105 | ± | 0.0108 |
|       |         | strict-match     | 5      | exact_match | 0.7635 | ± | 0.0117 |

full cudagraph + piecewise fx graph after this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9, 'compilation_config': {'full_cuda_graph': True, 'separate_attention_routine': True}}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

| Tasks | Version | Filter           | n-shot | Metric      | Value  |   | Stderr |
|-------|---------|------------------|--------|-------------|--------|---|--------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.8105 | ± | 0.0108 |
|       |         | strict-match     | 5      | exact_match | 0.7635 | ± | 0.0117 |

fhl2000 added 2 commits June 25, 2025 13:36
@github-actions bot commented:

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot left a comment:

Summary of Changes

Hello @fhl2000, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a refined approach to full CUDA graph integration within vLLM, moving from a flattened FX graph to a wrapper-based strategy that preserves the piecewise graph structure. This new method facilitates broader full CUDA graph support for attention backends like FlashAttention 2 and FlashInfer, leading to measurable performance gains. Additionally, it includes a fix for a specific kernel compilation issue, enhancing overall system stability and compatibility.

Highlights

  • Enhanced Full CUDA Graph Implementation: Introduces a new strategy for full CUDA graph capture that wraps the piecewise compiled FX graph, rather than flattening it. This aims to reduce CPU overhead for non-captured batch sizes and offers greater flexibility, allowing dispatch to different CUDA graph sets for prefill-decode and pure decode stages.
  • FA2 and FlashInfer Support: Extends full CUDA graph support to FlashAttention 2 (FA2) and FlashInfer backends. This includes specific adaptations for their distinct prefill-decode and pure decode routines, enabling performance benefits for these attention backends.
  • Performance Improvements: Benchmarking results indicate a 5% improvement in output token throughput for FA2 and a 2% improvement for FlashInfer, with corresponding reductions in Time Per Output Token (TPOT) by 2.9% and 3.1% respectively.
  • Marlin Kernel Compilation Bug Fix: Addresses a minor bug where Marlin kernels were compiled for an unsupported GPU architecture (e.g., 8.7 when the RTX 4090 is 8.9), resolving 'RuntimeError: CUDA error: no kernel image is available for execution on the device' errors.
  • Separate Attention Routine Configuration: Adds a new separate_attention_routine flag to CompilationConfig, allowing for distinct CUDA graph capturing for prefill-decode and pure decode stages within attention backends that implement different branches for these cases.

@gemini-code-assist bot left a comment:

Code Review

This pull request introduces a new implementation for full cuda graph, adds support for FA2 and FlashInfer, and fixes a bug for Marlin kernels on Ada architecture. The core idea is to keep the piecewise graph structure and wrap it to capture the full CUDA graph, which avoids the CPU overhead of a large flattened graph. The changes are well-motivated, and the performance improvements are clearly demonstrated.

@fhl2000 force-pushed the full_cudagraph_FA2_FlashInfer branch from bcf7cb9 to c2c5fea on June 25, 2025 08:33

mergify bot commented Jun 25, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Jun 25, 2025
fhl2000 and others added 2 commits June 25, 2025 16:52
mergify bot removed the needs-rebase label Jun 25, 2025
fhl2000 added 2 commits June 25, 2025 10:03
@fhl2000 (Contributor, Author) commented Jun 25, 2025

I have incorporated some checks for the new flag separate_attention_routine, so it is safe to launch now. This PR is now ready to be reviewed!

fhl2000 marked this pull request as ready for review June 25, 2025 14:53
fhl2000 changed the title from "[Core][Bugfix] new way for full cudagraph, add support for FA2 and FlashInfer; a minor bug fixed" to "[Core][Bugfix] New way for full cudagraph, add support for FA2 and FlashInfer; A minor bug fixed" Jun 25, 2025
@fhl2000 (Contributor, Author) commented Jun 26, 2025

Here is the workflow. At the initialization of torch.compile, the vllm_backend wraps the split_gm into a full cudagraph wrapper class if compilation_config.full_cuda_graph is on. This wrapper class then takes responsibility for dispatching to the cudagraph entries of the separate attention routines. At runtime, this dispatching is based on two key flags in the global forward_context: skip_attention_cuda_graphs and is_pure_decoding. When skip_attention_cuda_graphs is true, which implies using full cudagraph, this wrapper class takes care of it; that is, when separate_attention_routine is on, the wrapper class further dispatches to the decode-only full cudagraph or the mixed prefill-decode full cudagraph, according to the is_pure_decoding flag. On the other hand, if skip_attention_cuda_graphs is false, the wrapper class immediately falls back to the piecewise fx graph (the original split_gm), which relies on the CUDAPiecewiseBackend class to handle the piecewise cudagraph logic.

@fhl2000 (Contributor, Author) commented Jun 26, 2025

Please let me know if you have any questions or suggestions. I am currently planning to add some unit tests.

@ProExpertProg (Collaborator) left a comment:

I think this is a good approach overall!
My initial feedback:

  • I think we should try to consolidate CUDAGraph logic into a single class.
  • CUDAGraph logic is complex on main already, and this PR increases complexity significantly. We should add significantly more documentation. I also think we should consolidate various config flags and states.
  • There are benefits to compilation without splitting the graph (e.g. attention+quant fusion). We should add a new flag that maintains that ability (and assert the attention backend supports full cudagraph only). CUDAGraph logic can stay in the wrapper class.
  • This is a large PR, so it might help to split it. e.g. FlashInfer cg support can be added in a follow-up. But I'll let others chime in here.

Okay, this is plenty for now :D - thanks for the PR!

@Isotr0py (Member) commented:

> So my understanding is, you can basically treat the usage of cudagraph_mode of a multi-modal model the same as a language model, as long as it has a language model component, and only this component benefits from cudagraph (also torch.compile) speed up.

That's right, we currently only support torch.compile on a multimodal model's text backbone. But support for torch.compile on the ViT is also on our roadmap and under investigation! (Multi-modality Core)

@zejunchen-zejun (Contributor) commented Aug 29, 2025

Hi, @fhl2000

Sorry to disturb you. We are confused about the FULL mode. When the FULL mode is set, a dummy attention input is created for capturing the mixed graph (uniform_decode=False). The code here seems to assume that a pure prefill is computed for the mixed graph. For example, according to the following code, if num_tokens in the batch_descriptor is 512 and uniform_decode is False, then max_query_len is 512 and all 512 tokens in this batch are computed as prefill, with 0 tokens as decode. Is my understanding right? 😄 Previously I thought the dummy input for a mixed batch would be, e.g., 70% prefill tokens and 30% decode tokens.
https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu_model_runner.py#L2248

        max_query_len = self.uniform_decode_query_len if uniform_decode else \
                                                                num_tokens

Thank you !

@fhl2000 (Contributor, Author) commented Aug 29, 2025

> I find the code here has an assumption that a pure prefill will be computed for the mixed graph.

That's right. We capture a cudagraph of this edge case (pure prefill) for the general mixed batch, because it is always compatible with a mixed batch, no matter how it is mixed, and even works for a pure decode batch.
