 import contextlib
 import os
 import weakref
-from contextlib import ExitStack
+from dataclasses import dataclass
+from typing import Optional
 
 import pytest
 
@@ -32,69 +33,133 @@ def temporary_environ(env_vars):
                 os.environ[k] = v
 
 
-@pytest.fixture(scope="class")
-def llm_pair(request):
-    model = request.param
-
-    with temporary_environ({
-            "VLLM_USE_V1": "1",
-            "VLLM_FLASH_ATTN_VERSION": "3"
-    }):
-        full = LLM(
-            model=model,
-            gpu_memory_utilization=0.45,
-            trust_remote_code=True,
-            max_model_len=1024,
-            compilation_config=CompilationConfig(full_cuda_graph=True),
-        )
-        piecewise = LLM(
-            model=model,
-            gpu_memory_utilization=0.45,
-            trust_remote_code=True,
-            max_model_len=1024,
-            compilation_config=CompilationConfig(),
-        )
-
-    # PyTest caches the fixture values so we use weakref.proxy to enable GC
-    yield weakref.proxy(full), weakref.proxy(piecewise)
-    del full
-    del piecewise
-
-    wait_for_gpu_memory_to_clear(
-        devices=[0],
-        threshold_ratio=0.1,
-    )
-
-
-@pytest.fixture(scope="class")
-def cutlass_mla_llm_pair(request):
-    model = request.param
-
-    # force V1 engine and Cutlass MLA backend
-    with temporary_environ({
+@dataclass
+class BackendConfig:
+    name: str
+    env_vars: dict
+    comp_config: dict
+    specific_gpu_arch: Optional[tuple] = None
+
+
+# Define all backend configurations of full cudagraph to be tested
+backend_configs = {
+    # FA3 on Hopper
+    "FA3":
+    BackendConfig(name="FA3",
+                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # FlashMLA on Hopper
+    "FlashMLA":
+    BackendConfig(name="FlashMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # Cutlass MLA on Blackwell
+    "CutlassMLA":
+    BackendConfig(
+        name="CutlassMLA",
+        env_vars={
             "VLLM_USE_V1": "1",
             "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
             "FORCE_NUM_KV_SPLITS":
             "1",  # TODO: remove this when hang issue is fixed
-    }):
+        },
+        comp_config={
+            "cudagraph_mode": "FULL_AND_PIECEWISE",
+            "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
+        },
+        specific_gpu_arch=(10, 0)),
+    # FA2
+    "FA2":
+    BackendConfig(name="FA2",
+                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  }),
+    # Triton Attention
+    "TritonAttn":
+    BackendConfig(name="TritonAttn",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  }),
+    # FlashInfer
+    "FlashInfer":
+    BackendConfig(name="FlashInfer",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+}
+
+test_params_full_cudagraph = []
+
+# deepseek-ai/DeepSeek-V2-Lite with MLA
+MLA_backends = ["FlashMLA", "CutlassMLA"]
+for mla_backend in MLA_backends:
+    test_params_full_cudagraph.append(
+        pytest.param(
+            ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
+
+# Qwen/Qwen2-1.5B-Instruct with other backends
+other_backend_configs = [
+    backend_configs[c] for c in backend_configs if c not in MLA_backends
+]
+for backend_config in other_backend_configs:
+    test_params_full_cudagraph.append(
+        pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
+
+
+@pytest.fixture(scope="class")
+def llm_pair(request):
+    model, backend_config = request.param
+
+    # Dynamically skip test if GPU capability is not met
+    if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch \
+            != current_platform.get_device_capability():
+        if backend_config.specific_gpu_arch == (9, 0):
+            pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
+        elif backend_config.specific_gpu_arch == (10, 0):
+            pytest.skip("Only Blackwell GPUs support Cutlass MLA")
+
+    env_vars = {
+        "VLLM_USE_V1": "1",
+        # Force native sampler to avoid potential nondeterminism in FlashInfer
+        # when per-request generators are not used in V1.
+        "VLLM_USE_FLASHINFER_SAMPLER": "0",
+        **backend_config.env_vars,
+    }
+    with temporary_environ(env_vars):
         full = LLM(
             model=model,
-            gpu_memory_utilization=0.45,
+            gpu_memory_utilization=0.43,
             trust_remote_code=True,
             max_model_len=1024,
-            compilation_config=CompilationConfig(
-                full_cuda_graph=True,
-                cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512],
-            ),
+            max_num_seqs=128,
+            compilation_config=\
+                CompilationConfig(**backend_config.comp_config),
+            generation_config="vllm",
+            seed=42,
         )
         piecewise = LLM(
             model=model,
-            gpu_memory_utilization=0.45,
+            gpu_memory_utilization=0.43,
             trust_remote_code=True,
             max_model_len=1024,
-            compilation_config=CompilationConfig(),
+            max_num_seqs=128,
+            compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
+            generation_config="vllm",
+            seed=42,
         )
 
+    # PyTest caches the fixture values so we use weakref.proxy to enable GC
     yield weakref.proxy(full), weakref.proxy(piecewise)
     del full
     del piecewise
@@ -105,51 +170,7 @@ def cutlass_mla_llm_pair(request):
     )
 
 
-@pytest.mark.parametrize(
-    "cutlass_mla_llm_pair",
-    [
-        # use an MLA model
-        "deepseek-ai/DeepSeek-V2-Lite",
-    ],
-    indirect=True)
-@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
-                    reason="Only Blackwell GPUs support Cutlass MLA")
-class TestFullCUDAGraphCutlassMLA:
-    """
-    Validate full CUDA Graph with Cutlass MLA (decode-only capture).
-    """
-
-    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
-        (8, 8),
-    ])
-    def test_full_cudagraph_sm100_cutlass_mla(
-            self, batch_size, max_tokens, cutlass_mla_llm_pair: tuple[LLM,
-                                                                      LLM]):
-        piecewise_llm, full_cudagraph_llm = cutlass_mla_llm_pair
-
-        prompts = ["Hello, my name is"] * batch_size
-        sampling_params = SamplingParams(temperature=0.0,
-                                         max_tokens=max_tokens,
-                                         top_p=0.95)
-
-        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
-        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
-
-        for piecewise_res, full_res in zip(piecewise_responses,
-                                           full_responses):
-            assert piecewise_res.outputs[0].text == full_res.outputs[0].text
-
-
-@pytest.mark.parametrize(
-    "llm_pair",
-    [
-        # Model names for the llm_pair fixture
-        "deepseek-ai/DeepSeek-V2-Lite",
-        "Qwen/Qwen2-1.5B-Instruct"
-    ],
-    indirect=True)
-@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
-                    reason="Only Hopper GPUs support FA3 and FlashMLA")
+@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
 class TestFullCUDAGraph:
     """
     Use a class such that an llm pair is constructed once for all
@@ -178,55 +199,31 @@ def test_full_cudagraph(self, batch_size, max_tokens,
         full cudagraph compilation works for padded cases too.
         """
 
-        piecewise_llm, full_cudagraph_llm = llm_pair
+        full_cudagraph_llm, piecewise_llm = llm_pair
 
-        prompts = ["Hello, my name is"] * batch_size
+        prompts = ["the quick brown fox"] * batch_size
+        # Use purely greedy decoding to avoid top-p truncation sensitivity
+        # that can amplify tiny numeric differences across runtimes.
         sampling_params = SamplingParams(temperature=0.0,
                                          max_tokens=max_tokens,
-                                         top_p=0.95)
+                                         top_p=1.0)
 
         piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
         full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
 
         # Check that all responses are the same
         for piecewise_res, full_res in zip(piecewise_responses,
                                            full_responses):
-            assert piecewise_res.outputs[0].text == full_res.outputs[0].text
-
-
-@pytest.mark.parametrize(
-    "model, supported",
-    [
-        ("Qwen/Qwen2-1.5B-Instruct", True),
-        # MLA does not support capturing CUDA Graphs with size > max_num_seqs
-        ("deepseek-ai/DeepSeek-V2-Lite", False),
-    ])
-@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
-                    reason="Only Hopper GPUs support FA3 and FlashMLA")
-def test_lower_max_num_seqs(model, supported):
-    with temporary_environ({
-            "VLLM_USE_V1": "1",
-            "VLLM_FLASH_ATTN_VERSION": "3"
-    }), ExitStack() as stack:
-        if not supported:
-            stack.enter_context(pytest.raises(RuntimeError))
-
-        llm = LLM(model=model,
-                  max_num_seqs=256,
-                  trust_remote_code=True,
-                  max_model_len=1024,
-                  compilation_config=CompilationConfig(
-                      full_cuda_graph=True,
-                      cudagraph_capture_sizes=[64, 256, 512]))
-        llm.generate(["Hello, my name is"] * 10)
+            assert piecewise_res.outputs[0].text.lower() == \
+                full_res.outputs[0].text.lower()
 
 
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
 def test_full_cudagraph_with_invalid_backend():
     with temporary_environ({
             "VLLM_USE_V1": "1",
-            "VLLM_FLASH_ATTN_VERSION":
-            "2"  #FA2 not supported with full_cuda_graph
+            "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
+            # Flex_Attention is not supported with full cuda graph
    }), pytest.raises(RuntimeError):
         LLM(model="Qwen/Qwen2-1.5B-Instruct",
-            compilation_config=CompilationConfig(full_cuda_graph=True))
+            compilation_config=CompilationConfig(cudagraph_mode="FULL"))
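
Note: the temporary_environ helper used throughout these tests appears only by its last line in the hunks above. The following is a minimal sketch of how such a context manager is typically written; the actual body in the repository is not shown here, so treat the details as assumptions.

import contextlib
import os


@contextlib.contextmanager
def temporary_environ(env_vars):
    """Sketch: set environment variables for the block, then restore them."""
    # Remember prior values; None marks variables that were previously unset.
    saved = {k: os.environ.get(k) for k in env_vars}
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for k, v in saved.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v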
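The new parametrization relies on pytest's indirect=True so that each (model, BackendConfig) tuple reaches the class-scoped llm_pair fixture via request.param. A small self-contained sketch of that mechanism, using hypothetical fixture and test names, looks like this:

import pytest


@pytest.fixture(scope="class")
def pair_demo(request):
    # With indirect=True, each parametrize value arrives here as request.param.
    model, backend_name = request.param
    yield f"{model} / {backend_name}"


@pytest.mark.parametrize(
    "pair_demo",
    [pytest.param(("model-a", "FA3")),
     pytest.param(("model-b", "FlashMLA"))],
    indirect=True)
class TestPairDemo:

    def test_param_reaches_fixture(self, pair_demo):
        assert " / " in pair_demo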