[Core][Bugfix] New way for full cudagraph, add support for FA2 and FlashInfer #20059

Open
fhl2000 wants to merge 12 commits into main from full_cudagraph_FA2_FlashInfer

Conversation


@fhl2000 fhl2000 commented Jun 25, 2025

Purpose

1. This PR introduces a new implementation for full cuda graph, and adds support for FA2 and FlashInfer.

Previous limitations

The original design in PR #16072 sets compilation_config.splitting_ops to an empty list and captures the full cudagraph inside the flattened fx graph, which supports FA3 only. In the later PR #18581, full cudagraph support for FlashMLA only captures the pure decode stage and bypasses the mixed prefill-decode stages, i.e., it runs the eager code of the compiled flattened fx graph in that stage. However, from the profiling results (see below), I found that this flattened graph has performance issues when called eagerly: it is about 2x slower on the CPU side than running the compiled piecewise fx graph (possibly a Python-level issue). This can lead to potential performance degradation for prefill stages with small batch sizes.

Also, attention backends like FA2, FlashInfer, and FlashMLA have two distinct attention routines for mixed prefill-decode stages and pure decode stages, which makes it difficult to contain both in a unified graph while keeping only one set of captured cudagraphs.

Solution of this PR

So, the new trick is: we keep the piecewise compiled fx graph structure overall, but capture the full cudagraph outside the fx graph via a wrapper. With this at hand, we can dispatch to two sets of cudagraphs. For the pure decode stage, we directly use full cudagraphs, since this stage is compatible with most attention backends. For mixed prefill-decode stages, we can either fall back to piecewise cudagraphs for incompatible routines in backends like FlashMLA and FlashInfer, or use another set of full cudagraphs for compatible backends (varlen support in FA2).

Note that keeping the piecewise compiled fx graph is at least better than a full but flattened one from the viewpoint of reducing CPU overhead, even if we do not capture the mixed prefill-decode stage. It also keeps the flexibility to switch between full cudagraphs and piecewise cudagraphs for future extensions, for example a seamless fallback to piecewise cudagraphs when cascade attention is needed.

The limitation is increased startup time and more GPU memory required for the additional cudagraph captures. We could optimize this by shrinking the list of batch sizes captured for the prefill-decode stage.
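Conceptually, the dispatch looks roughly like this (a minimal sketch; class and attribute names are illustrative, not the exact code in this PR):

```python
# Hedged sketch of the wrapper-based dispatch described above. Class and
# attribute names here are illustrative; the actual PR code differs in detail.
class FullCudagraphWrapperSketch:
    def __init__(self, split_gm, separate_attention_routine: bool):
        self.split_gm = split_gm            # piecewise-compiled fx graph
        self.separate_attention_routine = separate_attention_routine
        # Two independent sets of captured full cudagraphs, keyed by batch size:
        self.decode_graphs = {}             # pure decode batches
        self.mixed_graphs = {}              # mixed prefill-decode batches (e.g. FA2 varlen)

    def __call__(self, batch_size: int, use_piecewise: bool,
                 is_pure_decode: bool, *args):
        if use_piecewise:
            # Incompatible routine (e.g. FlashMLA/FlashInfer mixed batches):
            # fall back to the piecewise cudagraph logic inside split_gm.
            return self.split_gm(*args)
        graphs = (self.decode_graphs
                  if self.separate_attention_routine and is_pure_decode
                  else self.mixed_graphs)
        entry = graphs.get(batch_size)
        if entry is None:
            # Batch size not captured: run the piecewise graph instead.
            return self.split_gm(*args)
        graph, static_output = entry
        graph.replay()                      # inputs live in persistent buffers
        return static_output
```

At capture time, each set would be populated by replaying the dummy run for every batch size in the capture list, once per routine.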

Profile of the compiled flattened fx graph on eager execution, mixed prefill-decode stage:

It takes roughly 56 ms to fully launch the model, with an additional 5 ms of latency spent on safety checks before launching the first kernel. It seems Python is slow at executing a large, flattened module without submodules.
[profiler trace screenshot]

Note: the only way to use the flattened fx graph in this PR is to hardcode splitting_ops = [] in set_splitting_ops_for_v1 (around line 4200 in vllm/config.py).
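For reference, that hardcode would look roughly like the following (a paraphrased fragment, not the actual vLLM source):

```python
class CompilationConfig:  # paraphrased fragment, not the real vLLM class
    splitting_ops: list = None

    def set_splitting_ops_for_v1(self):
        # Forcing an empty list disables piecewise splitting at attention ops,
        # so torch.compile produces the single flattened fx graph whose
        # eager-call overhead is profiled above.
        self.splitting_ops = []
```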

Profile of the compiled piecewise fx graph on eager execution, mixed prefill-decode stage:

It takes 28 ms to fully launch, and the latency above almost disappears; in fact, it is hidden inside each submodule.
[profiler trace screenshot]

The patterns above were verified on two different machines (ignoring the GPU difference here, as this is CPU-bound), tested on Qwen2.5-7B-Instruct-GPTQ-Int4 by profiling benchmark_serving (ShareGPT, unlimited request rate).

So, if a prefill batch size is slightly larger than the maximum captured size (say 512) but not too big, the lower bound of model forward time is likely set by the CPU side: around 56 ms when running the flattened graph versus 28 ms for the piecewise one.

Details for supporting FA2:

The previous code did not recognize the two routines inside FA2. FA2 launches a standard varlen fwd kernel for mixed prefill-decode batches, and a separate routine for pure decode batches that includes an optimization for GQA/MQA and potential flash-decode kernels (split_kv > 1). By setting max_query_len = 1 or > 1 during the CUDA graph capturing phase, we can activate the desired attention routine so that it is correctly captured. (Strictly speaking, the prefill-decode kernel is of course compatible with pure decode, but it is not fully optimized for the decode phase. The actual reason PR #16072 did not support FA2 is a bug where seq_lens was a zero tensor in dummy_run in the early code, which bypassed launching any attention kernel during capture, leading to zero-tensor outputs.)

  • FA2 runs both mixed prefill-decode and pure decode batches with full cudagraphs, but on two separate sets of cudagraphs (see the sketch below).
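A simplified sketch of the capture-time selection described above, mirroring the max_query_len logic added to the dummy run (the helper name is illustrative):

```python
def select_max_query_len(num_tokens: int, is_pure_decoding: bool) -> int:
    # max_query_len == 1 steers FA2 into its decode-optimized routine
    # (GQA/MQA optimization plus flash-decode kernels with split_kv > 1);
    # any larger value steers it into the general varlen prefill-decode
    # kernel, so each cudagraph set captures the routine it will replay.
    return 1 if is_pure_decoding else num_tokens
```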

Details for supporting FlashInfer:

  • Uses the persistent buffer trick.
  • Creates many decode wrappers, one per cudagraph batch size, as required by the FlashInfer API (see the sketch below).
  • Runs pure decode batches with full cudagraphs, and falls back to piecewise cudagraphs for mixed prefill-decode batches.
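A hedged sketch of the per-batch-size decode wrappers, assuming FlashInfer's BatchDecodeWithPagedKVCacheWrapper with use_cuda_graph=True and preallocated index buffers (buffer sizes and the capture-size list are illustrative; check the FlashInfer docs for the exact signature):

```python
import torch
import flashinfer

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
max_num_pages = 8192                    # illustrative KV-cache page budget
capture_sizes = (1, 2, 4, 8, 16, 32)    # example cudagraph batch sizes

# One wrapper per captured batch size, each bound to persistent buffers so
# the captured graph always reads from the same memory at replay time.
decode_wrappers = {
    bs: flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace,
        "NHD",
        use_cuda_graph=True,
        paged_kv_indptr_buffer=torch.zeros(bs + 1, dtype=torch.int32, device="cuda"),
        paged_kv_indices_buffer=torch.zeros(max_num_pages, dtype=torch.int32, device="cuda"),
        paged_kv_last_page_len_buffer=torch.zeros(bs, dtype=torch.int32, device="cuda"),
    )
    for bs in capture_sizes
}
```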

Launching command examples:

For FA2:

VLLM_FLASH_ATTN_VERSION=2 python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --compilation-config '{"full_cuda_graph":true, "separate_attention_routine":true}'

For FlashInfer:

VLLM_ATTENTION_BACKEND=FLASHINFER python -m ... --compilation-config '{"full_cuda_graph":true,"separate_attention_routine":true}'

Others:
  • FlashMLA: --compilation-config '{"full_cuda_graph":true,"separate_attention_routine":true}'
  • FA3: set VLLM_FLASH_ATTN_VERSION=3 and --compilation-config '{"full_cuda_graph":true}'

Test Plan

Benchmark serving and lm_eval performance for FA2 and FlashInfer.

I have no plan to test FlashMLA and FA3 since I have no Hopper GPU at hand, but they should be fine as the current design is compatible with them. However, it would be very nice if somebody could help test them.

Test Result

Summary of results

Output token throughput is improved by 5% for FA2 and 2% for FlashInfer on Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4. TPOT is reduced by 2.9% and 3.1%, respectively. The lm_eval results are unchanged for both.

Details

Machine: A100 40G, torch 2.6, CUDA 12.4

Benchmark serving command:

python benchmarks/benchmark_serving.py --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 100 --request-rate 20

FA2 benchmark serving:

piecewise cudagraph before this PR

python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 11.41
Total input tokens: 23260
Total generated tokens: 21657
Request throughput (req/s): 8.77
Output token throughput (tok/s): 1898.67
Total Token throughput (tok/s): 3937.88
---------------Time to First Token----------------
Mean TTFT (ms): 76.37
Median TTFT (ms): 71.08
P99 TTFT (ms): 191.53
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 17.08
Median TPOT (ms): 15.22
P99 TPOT (ms): 67.68
---------------Inter-token Latency----------------
Mean ITL (ms): 13.45
Median ITL (ms): 11.05
P99 ITL (ms): 72.61
==================================================

full cudagraph + piecewise fx graph in this PR

python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9 --compilation-config '{"full_cuda_graph": true,"separate_attention_routine": true}'

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 10.87
Total input tokens: 23260
Total generated tokens: 21657
Request throughput (req/s): 9.20
Output token throughput (tok/s): 1992.27
Total Token throughput (tok/s): 4132.01
---------------Time to First Token----------------
Mean TTFT (ms): 78.69
Median TTFT (ms): 75.10
P99 TTFT (ms): 195.90
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 16.57
Median TPOT (ms): 14.78
P99 TPOT (ms): 78.21
---------------Inter-token Latency----------------
Mean ITL (ms): 12.83
Median ITL (ms): 10.34
P99 ITL (ms): 72.37
==================================================

FA2 lm_eval

piecewise cudagraph before this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

Tasks   Version  Filter            n-shot  Metric       Value   Stderr
gsm8k   3        flexible-extract  5       exact_match  0.8074  ± 0.0109
                 strict-match      5       exact_match  0.7619  ± 0.0117

full cudagraph + piecewise fx graph after this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9, 'compilation_config': {'full_cuda_graph': True, 'separate_attention_routine': True}}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

Tasks   Version  Filter            n-shot  Metric       Value   Stderr
gsm8k   3        flexible-extract  5       exact_match  0.8074  ± 0.0109
                 strict-match      5       exact_match  0.7619  ± 0.0117

FlashInfer benchmark serving

piecewise cudagraph before this PR

VLLM_ATTENTION_BACKEND=FLASHINFER python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 11.36
Total input tokens: 23260
Total generated tokens: 21660
Request throughput (req/s): 8.81
Output token throughput (tok/s): 1907.38
Total Token throughput (tok/s): 3955.65
---------------Time to First Token----------------
Mean TTFT (ms): 73.61
Median TTFT (ms): 69.59
P99 TTFT (ms): 184.62
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 16.85
Median TPOT (ms): 15.13
P99 TPOT (ms): 65.75
---------------Inter-token Latency----------------
Mean ITL (ms): 13.34
Median ITL (ms): 11.09
P99 ITL (ms): 71.82
==================================================

full cudagraph + piecewise fx graph after this PR

VLLM_ATTENTION_BACKEND=FLASHINFER python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4 --gpu-memory-utilization 0.9 --compilation-config '{"full_cuda_graph": true,"separate_attention_routine": true}'

============ Serving Benchmark Result ============
Successful requests: 100
Benchmark duration (s): 11.13
Total input tokens: 23260
Total generated tokens: 21660
Request throughput (req/s): 8.99
Output token throughput (tok/s): 1946.35
Total Token throughput (tok/s): 4036.48
---------------Time to First Token----------------
Mean TTFT (ms): 76.03
Median TTFT (ms): 67.04
P99 TTFT (ms): 192.56
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 16.34
Median TPOT (ms): 14.96
P99 TPOT (ms): 58.86
---------------Inter-token Latency----------------
Mean ITL (ms): 13.11
Median ITL (ms): 10.71
P99 ITL (ms): 71.69
==================================================

FlashInfer lm_eval

piecewise cudagraph before this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

Tasks   Version  Filter            n-shot  Metric       Value   Stderr
gsm8k   3        flexible-extract  5       exact_match  0.8105  ± 0.0108
                 strict-match      5       exact_match  0.7635  ± 0.0117

full cudagraph + piecewise fx graph after this PR

vllm ({'pretrained': '/root/models/Qwen2.5-7B-Instruct-GPTQ-Int4', 'gpu_memory_utilization': 0.9, 'compilation_config': {'full_cuda_graph': True, 'separate_attention_routine': True}}), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 1

Tasks   Version  Filter            n-shot  Metric       Value   Stderr
gsm8k   3        flexible-extract  5       exact_match  0.8105  ± 0.0108
                 strict-match      5       exact_match  0.7635  ± 0.0117

One more thing: after merging some code from the main branch recently, I ran into a potential deadlock when testing this PR. This appears to be caused by earlier merged code, and PR #19927 seems to solve the problem.

fhl2000 added 2 commits June 25, 2025 13:36
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl <2410591650@qq.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @fhl2000, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a refined approach to full CUDA graph integration within vLLM, moving from a flattened FX graph to a wrapper-based strategy that preserves the piecewise graph structure. This new method facilitates broader full CUDA graph support for attention backends like FlashAttention 2 and FlashInfer, leading to measurable performance gains. Additionally, it includes a fix for a specific kernel compilation issue, enhancing overall system stability and compatibility.

Highlights

  • Enhanced Full CUDA Graph Implementation: Introduces a new strategy for full CUDA graph capture that wraps the piecewise compiled FX graph, rather than flattening it. This aims to reduce CPU overhead for non-captured batch sizes and offers greater flexibility, allowing dispatch to different CUDA graph sets for prefill-decode and pure decode stages.
  • FA2 and FlashInfer Support: Extends full CUDA graph support to FlashAttention 2 (FA2) and FlashInfer backends. This includes specific adaptations for their distinct prefill-decode and pure decode routines, enabling performance benefits for these attention backends.
  • Performance Improvements: Benchmarking results indicate a 5% improvement in output token throughput for FA2 and a 2% improvement for FlashInfer, with corresponding reductions in Time Per Output Token (TPOT) by 2.9% and 3.1% respectively.
  • Marlin Kernel Compilation Bug Fix: Addresses a minor bug where Marlin kernels were incorrectly compiled for unsupported GPU architectures (e.g., 8.7 for RTX 4090, which is 8.9), resolving 'RuntimeError: CUDA error: no kernel image is available for execution on the device' errors.
  • Separate Attention Routine Configuration: Adds a new separate_attention_routine flag to CompilationConfig, allowing for distinct CUDA graph capturing for prefill-decode and pure decode stages within attention backends that implement different branches for these cases.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a new implementation for full cuda graph, adds support for FA2 and FlashInfer, and fixes a bug for Marlin kernels on Ada architecture. The core idea is to keep the piecewise graph structure and wrap it to capture the full CUDA graph, which avoids the CPU overhead of a large flattened graph. The changes are well-motivated, and the performance improvements are clearly demonstrated.

Comment on lines 1918 to 1920
# [Bugfix] This lets FA2 to correctly activate the optimized routine
# for pure decoding, i.e., Flashdecoding + an optimization for GQA/MQA.
max_query_len = 1 if is_pure_decoding else num_tokens
Contributor (severity: high)

This change ensures that the max_query_len is set to 1 during pure decoding, which is a bug fix for FA2 to correctly activate the optimized routine.

CMakeLists.txt Outdated
@@ -308,7 +308,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}")
Contributor (severity: medium)

This change adds support for CUDA architecture 8.9, which corresponds to the Ada Lovelace architecture. This resolves a runtime error encountered when compiling Marlin kernels for this architecture.

Collaborator

Please put these in a separate PR, they are not related to CUDAGraph fixes.

Also @LucasWilkinson I thought we removed these arches on purpose?

Author

Removed now. I'll put it in a new PR.

CMakeLists.txt Outdated
@@ -684,7 +684,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")

list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}")
# 9.0 for latest bf16 atomicAdd PTX
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.7;9.0+PTX" "${CUDA_ARCHS}")
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.7;8.9;9.0+PTX" "${CUDA_ARCHS}")
Contributor (severity: medium)

This change adds support for CUDA architecture 8.9, which corresponds to the Ada Lovelace architecture. This resolves a runtime error encountered when compiling Marlin kernels for this architecture.

Collaborator

ditto

@@ -139,7 +139,7 @@ def _get_sliding_window_configs(

class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() >= 2
Contributor (severity: medium)

The FlashAttentionMetadataBuilder now supports FA2 and later versions, not just FA3.

Collaborator

Just set it to True, there's no FA1 integration

vllm/config.py Outdated
Comment on lines 4189 to 4190
if self.separate_attention_routine:
assert self.full_cuda_graph, "separate_attention_routine requires full_cuda_graph to be True"
Contributor (severity: medium)

This assertion ensures that separate_attention_routine is only enabled when full_cuda_graph is also enabled, maintaining the intended usage pattern.

Comment on lines 2244 to 2251
# Skip capturing batch sizes of 1 in mix prefill-decode if
# separate_attention_routine is on. As bs=1 can treat as a
# pure decode.
start_idx = 0
if self.vllm_config.compilation_config.separate_attention_routine \
and len(self.cudagraph_batch_sizes) > 0 \
and self.cudagraph_batch_sizes[0] == 1:
start_idx = 1
Contributor (severity: medium)

This change skips capturing batch sizes of 1 in mix prefill-decode if separate_attention_routine is on, as bs=1 can be treated as a pure decode.

Signed-off-by: fhl <2410591650@qq.com>
@fhl2000 fhl2000 force-pushed the full_cudagraph_FA2_FlashInfer branch from bcf7cb9 to c2c5fea Compare June 25, 2025 08:33

mergify bot commented Jun 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 25, 2025
fhl2000 and others added 2 commits June 25, 2025 16:52
@mergify mergify bot removed the needs-rebase label Jun 25, 2025
fhl2000 added 2 commits June 25, 2025 10:03
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@fhl2000
Author

fhl2000 commented Jun 25, 2025

I have incorporated some checks for the new flag separate_attention_routine, so it is safe to launch now. This PR is now ready to be reviewed!

@fhl2000 fhl2000 marked this pull request as ready for review June 25, 2025 14:53
@fhl2000 fhl2000 changed the title [Core][Bugfix] new way for full cudagraph, add support for FA2 and FlashInfer; a minor bug fixed [Core][Bugfix] New way for full cudagraph, add support for FA2 and FlashInfer; A minor bug fixed Jun 25, 2025
@fhl2000
Author

fhl2000 commented Jun 26, 2025

Here is the workflow. At torch.compile initialization, the vllm_backend wraps the split_gm into a full cudagraph wrapper class if compilation_config.full_cuda_graph is on. This wrapper class then takes responsibility for dispatching to the cudagraph entries of the separate attention routines. At runtime, this dispatching is based on two key flags in the global forward_context: skip_attention_cuda_graphs and is_pure_decoding. When skip_attention_cuda_graphs is false, which implies using full cudagraphs, the wrapper class takes care of it; that is, when separate_attention_routine is on, the wrapper further dispatches to the decode-only full cudagraph or the mixed prefill-decode full cudagraph, according to the is_pure_decoding flag. On the other hand, if skip_attention_cuda_graphs is true, the wrapper class immediately falls back to the piecewise fx graph (the original split_gm), which relies on the CUDAPiecewiseBackend class to handle the piecewise cudagraph logic.
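For concreteness, the two flags could be derived roughly as follows before the forward pass (a sketch based on the discussion in this thread; the flag names follow the PR, the helper itself is hypothetical):

```python
def compute_dispatch_flags(full_cuda_graph: bool,
                           attention_cuda_graphs: bool,
                           max_query_len: int) -> tuple[bool, bool]:
    # True  -> the wrapper falls through to the piecewise path, where
    #          CUDAPiecewiseBackend handles piecewise cudagraphs.
    # False -> the wrapper replays (or captures) a full cudagraph.
    skip_attention_cuda_graphs = not full_cuda_graph or not attention_cuda_graphs
    # A batch where every request contributes a single query token is a
    # pure decode batch and can use the decode-only set of full cudagraphs.
    is_pure_decoding = max_query_len == 1
    return skip_attention_cuda_graphs, is_pure_decoding
```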

@fhl2000
Author

fhl2000 commented Jun 26, 2025


Please let me know if you have any questions or suggestions. I am currently planning to add some unit tests.

Signed-off-by: fhl <2410591650@qq.com>
@ProExpertProg (Collaborator) left a comment

I think this is a good approach overall!
My initial feedback:

  • I think we should try to consolidate CUDAGraph logic into a single class.
  • CUDAGraph logic is complex on main already, and this PR increases complexity significantly. We should add significantly more documentation. I also think we should consolidate various config flags and states.
  • There are benefits to compilation without splitting the graph (e.g. attention+quant fusion). We should add a new flag that maintains that ability (and assert the attention backend supports full cudagraph only). CUDAGraph logic can stay in the wrapper class.
  • This is a large PR, so it might help to split it. e.g. FlashInfer cg support can be added in a follow-up. But I'll let others chime in here.

Okay, this is plenty for now :D - thanks for the PR!


)

def __call__(self, *args) -> Any:
if not self.first_run_finished:
Collaborator

Pretty much all of this logic is shared with piecewise; could we extract it into CUDAGraph(Runner|Manager|Wrapper...) or something like that?

Collaborator

Btw, I know there is some logic that is different (for skipping cudagraphs and avoiding double capture), which will probably require an extra flag: can you still rearchitect this and we can bikeshed the exact flags/logic after?

Author

can you still rearchitect this and we can bikeshed the exact flags/logic after

Yes, of course. I will give it a try tomorrow

Author

@ProExpertProg I have just redesigned this. Would you kindly review it?

"full_cuda_graph mode requires use_cudagraph to be True"
fullgraph_wrapper = resolve_obj_by_qualname(
current_platform.get_fullgraph_wrapper_cls())
self.split_gm = fullgraph_wrapper(self.split_gm, self.vllm_config,
Collaborator

I don't see why this has to be platform-specific. If it doesn't, let's create it directly?

Author

Here, I follow the convention in the class PiecewiseCompileInterpreter, where the piecewise_backend is resolved as platform-specific. It seems CUDAPiecewiseBackend support is limited to the CUDA and ROCm platforms.


@@ -50,6 +50,9 @@ class CommonAttentionMetadata:
class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder support CUDA Graphs for attention.
full_cudagraph_supported: ClassVar[bool] = False
# If full cudagraph support, select if this attention backend
# enforce separate rountine to be True, False or None (free).
Collaborator

I don't understand the difference between False and None here?

Author

I have to admit that this flag is a bit confusing. My intention is: if force_separate_routine is True (or False), the backend only supports separate_attention_routine = True (or False). If force_separate_routine is None, the user is free to set separate_attention_routine to either True or False; e.g., FA2 (and FA3) are compatible with both settings.

Since this flag now seems confusing, do you have suggestions for clearer naming?
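To make the intended tri-state semantics concrete, a resolution helper could look like this (a hypothetical sketch, not code from this PR):

```python
from typing import Optional

def resolve_separate_attention_routine(
        user_value: bool,
        force_separate_routine: Optional[bool]) -> bool:
    # True/False: the backend only works with that setting and overrides
    # the user's compilation-config choice.
    # None: the backend supports both (e.g. FA2/FA3), so keep the user's value.
    if force_separate_routine is None:
        return user_value
    return force_separate_routine
```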

Comment on lines 1338 to 1339
skip_attention_cuda_graphs = not attention_cuda_graphs \
if self.full_cuda_graph else True
Collaborator

Suggested change
skip_attention_cuda_graphs = not attention_cuda_graphs \
if self.full_cuda_graph else True
skip_attention_cuda_graphs = not self.full_cuda_graph or not attention_cuda_graphs

# cudagraphs that skip the attention part. By default true, we use piecewise
# cudagraphs.
skip_attention_cuda_graphs: bool = True
is_pure_decoding: bool = False
Collaborator

I would name this is_pure_decode or is_pure_decode_batch (but matter of preference)

@@ -1907,8 +1916,17 @@ def _dummy_run(
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)

# [Bugfix] This lets FA2 to correctly activate the optimized routine
Collaborator

Please remove [Bugfix] from comment - when the code is merged, it's no longer a bugfix as it's static code right?

@@ -3984,6 +3984,14 @@ class CompilationConfig:
splitting certain operations such as attention into subgraphs. Thus this
flag cannot be used together with splitting_ops. This may provide
performance benefits for smaller models."""
separate_attention_routine: bool = False
Collaborator

I think this should be named better. Perhaps split_attn_cudagraph? I also don't understand why this has to be a flag and we can't just ask the attention backend what it wants?

Author

I think we must leave such a flag in the global config, which tells the compiler backend to do the right thing. Otherwise, how is the attention backend supposed to communicate its requirements to the compiler? At least for now, the force_separate_routine flag of an attention backend has the ability to enforce its preference during the initialize_attn_backend phase of the gpu model runner.

Author

I think this should be named better. Perhaps split_attn_cudagraph?

I am not sure what name would be better. Btw, I'm afraid split_attn_cudagraph is not a good name; it sounds like splitting the full graph into a piecewise graph with attn ops as the splitting ops, which is what we already do.

@zou3519
Collaborator

zou3519 commented Jun 26, 2025

cc @BoyuanFeng

@fhl2000
Author

fhl2000 commented Jun 27, 2025

@ProExpertProg Thanks for the numerous valuable comments.

  • I think we should try to consolidate CUDAGraph logic into a single class.
  • CUDAGraph logic is complex on main already, and this PR increases complexity significantly. We should add significantly more documentation. I also think we should consolidate various config flags and states.

Yes, I totally agree with you. I am trying to figure out how to put the common CUDAGraph logic into a single class while maintaining its relations with CUDAPiecewiseBackend and FullCudagraphWrapper. Maybe just a functional class is enough: the list of ConcreteSizeEntry would still be managed within the previous two classes, which call this functional class after the right entry is selected (a rough sketch of what I mean is at the end of this comment). Does that match your thoughts? If it does, I will give it a try soon.

  • There are benefits to compilation without splitting the graph (e.g. attention+quant fusion). We should add a new flag that maintains that ability (and assert the attention backend supports full cudagraph only). CUDAGraph logic can stay in the wrapper class.

Okay, I just noticed the recent PR on attention+quant fusion. It adds a fused_output_quant_supported flag, so it must work in a graph without splits. Btw, I am a bit confused by the words "attention backend supports full cudagraph only": would that mean an attention backend that supports cudagraph and supports separate_attention_routine=False (like FA2 and FA3), but may also support separate_attention_routine=True?

Overall, these suggestions are beneficial. And I'll reply to the other comments you left soon.
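As a rough illustration of the functional-class idea above, a minimal shared capture/replay helper might look like this (names are illustrative; the ConcreteSizeEntry bookkeeping would stay in CUDAPiecewiseBackend and the full-cudagraph wrapper):

```python
import torch

class CudagraphRunnerSketch:
    """Owns only the capture/replay mechanics for a single entry."""

    def __init__(self, runnable, pool=None):
        self.runnable = runnable
        self.pool = pool                    # optional shared cudagraph memory pool
        self.graph = None
        self.static_output = None

    def capture(self, *args):
        # Caller must pass inputs that live in persistent buffers.
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph, pool=self.pool):
            self.static_output = self.runnable(*args)
        return self.static_output

    def replay(self):
        assert self.graph is not None, "capture() must be called first"
        self.graph.replay()
        return self.static_output
```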


mergify bot commented Jun 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @fhl2000.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jun 27, 2025
fhl2000 added 3 commits June 28, 2025 06:06
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
@fhl2000 fhl2000 requested a review from ProExpertProg June 28, 2025 08:00
@fhl2000 fhl2000 changed the title [Core][Bugfix] New way for full cudagraph, add support for FA2 and FlashInfer; A minor bug fixed [Core][Bugfix] New way for full cudagraph, add support for FA2 and FlashInfer Jun 28, 2025