Commit f075693

[V1] address post issues related to #20059 (part 1) (#23046)
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
1 parent f708bd4 commit f075693

File tree: 13 files changed (+346, -290 lines)


tests/compile/piecewise/test_full_cudagraph.py

Lines changed: 1 addition & 85 deletions
@@ -3,12 +3,11 @@
 import contextlib
 import os
 import weakref
-from dataclasses import dataclass
-from typing import Optional

 import pytest

 from tests.utils import wait_for_gpu_memory_to_clear
+from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
 from vllm.platforms import current_platform
@@ -33,89 +32,6 @@ def temporary_environ(env_vars):
         os.environ[k] = v


-@dataclass
-class BackendConfig:
-    name: str
-    env_vars: dict
-    comp_config: dict
-    specific_gpu_arch: Optional[tuple] = None
-
-
-# Define all backend configurations of full cudagraph to be tested
-backend_configs = {
-    # FA3 on Hopper
-    "FA3":
-    BackendConfig(name="FA3",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "3",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashMLA on Hopper
-    "FlashMLA":
-    BackendConfig(name="FlashMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # FlashAttention MLA on Hopper
-    "FlashAttentionMLA":
-    BackendConfig(name="FlashAttentionMLA",
-                  env_vars={
-                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL_DECODE_ONLY",
-                  },
-                  specific_gpu_arch=(9, 0)),
-    # Cutlass MLA on Blackwell
-    "CutlassMLA":
-    BackendConfig(
-        name="CutlassMLA",
-        env_vars={
-            "VLLM_USE_V1": "1",
-            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
-            "FORCE_NUM_KV_SPLITS":
-            "1",  # TODO: remove this when hang issue is fixed
-        },
-        comp_config={
-            "cudagraph_mode": "FULL_AND_PIECEWISE",
-            "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
-        },
-        specific_gpu_arch=(10, 0)),
-    # FA2
-    "FA2":
-    BackendConfig(name="FA2",
-                  env_vars={
-                      "VLLM_FLASH_ATTN_VERSION": "2",
-                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
-                  },
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  }),
-    # Triton Attention
-    "TritonAttn":
-    BackendConfig(name="TritonAttn",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
-                  comp_config={
-                      "cudagraph_mode": "FULL",
-                  }),
-    # FlashInfer
-    "FlashInfer":
-    BackendConfig(name="FlashInfer",
-                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
-                  comp_config={
-                      "cudagraph_mode": "FULL_AND_PIECEWISE",
-                  }),
-}
-
 test_params_full_cudagraph = []

 # deepseek-ai/DeepSeek-V2-Lite with MLA
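
Note: the backend table this test iterates over now lives in tests/v1/attention/utils.py and is pulled in through the aliased import above. A rough, hypothetical sketch of the pattern this module builds on, pairing an entry's env_vars with the temporary_environ helper named in the hunk context and its comp_config with CompilationConfig (the model name and surrounding code are illustrative only, not part of the diff):

    backend = backend_configs["FA3"]
    with temporary_environ(backend.env_vars):
        # Placeholder model purely for illustration; the real tests build
        # their parameters from test_params_full_cudagraph below.
        llm = LLM(model="facebook/opt-125m",
                  compilation_config=CompilationConfig(**backend.comp_config),
                  gpu_memory_utilization=0.4)
        outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))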

tests/compile/test_config.py

Lines changed: 65 additions & 2 deletions
@@ -4,7 +4,7 @@

 import vllm
 from vllm.compilation.counter import compilation_counter
-from vllm.config import CompilationConfig, VllmConfig
+from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
 from vllm.utils import _is_torch_equal_or_newer


@@ -106,7 +106,6 @@ def test_dynamo_as_is(vllm_runner, monkeypatch):
 def test_no_compilation(vllm_runner, monkeypatch):
     # Disable multiprocessing so that the counter is in the same process
     monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
-
     with (
             compilation_counter.expect(num_graphs_seen=0,
                                        dynamo_as_is_count=0),
@@ -131,3 +130,67 @@ def test_enforce_eager(vllm_runner, monkeypatch):
                         enforce_eager=True,
                         gpu_memory_utilization=0.4) as _):
         pass
+
+
+def test_splitting_ops_dynamic():
+    # Default config
+    config = VllmConfig()
+    assert config.compilation_config.cudagraph_mode == \
+        CUDAGraphMode.FULL_AND_PIECEWISE
+    assert config.compilation_config.splitting_ops_contain_attention()
+
+    # When use_inductor_graph_partition=True
+    if _is_torch_equal_or_newer('2.9.0.dev'):
+        # Inductor graph partition is only available in PyTorch 2.9+.
+        # This is a fast config check, so we do not use pytest.skip.
+        config = VllmConfig(compilation_config=CompilationConfig(
+            use_inductor_graph_partition=True,
+            splitting_ops=["silly_attention"]))
+        # should ignore splitting_ops
+        assert config.compilation_config.splitting_ops == []
+
+    # When the attn_fusion pass is enabled.
+    config = VllmConfig(compilation_config=CompilationConfig(
+        pass_config={
+            "enable_attn_fusion": True,
+            "enable_noop": True
+        },
+        custom_ops=["+quant_fp8"],
+        cudagraph_mode=CUDAGraphMode.PIECEWISE,
+    ))
+    assert config.compilation_config.splitting_ops == []
+    # cudagraph mode also falls back to FULL
+    assert config.compilation_config.cudagraph_mode == \
+        CUDAGraphMode.FULL
+
+    # splitting_ops cannot contain attention ops when the attn_fusion
+    # pass is enabled.
+    with pytest.raises(AssertionError):
+        config = VllmConfig(compilation_config=CompilationConfig(
+            pass_config={
+                "enable_attn_fusion": True,
+                "enable_noop": True
+            },
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+            # workaround for accessing all attention ops
+            splitting_ops=CompilationConfig()._attention_ops,
+        ))
+
+    # When both use_inductor_graph_partition and the attn_fusion pass
+    # are enabled.
+    if _is_torch_equal_or_newer('2.9.0.dev'):
+        config = VllmConfig(compilation_config=CompilationConfig(
+            use_inductor_graph_partition=True,
+            pass_config={
+                "enable_attn_fusion": True,
+                "enable_noop": True
+            },
+            custom_ops=["+quant_fp8"],
+            cudagraph_mode=CUDAGraphMode.PIECEWISE,
+        ))
+        assert config.compilation_config.splitting_ops == []
+        # enable_attn_fusion is directly supported under
+        # use_inductor_graph_partition=True, and cudagraph_mode
+        # is unchanged.
+        assert config.compilation_config.cudagraph_mode == \
+            CUDAGraphMode.PIECEWISE
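
The headline behavior pinned down by test_splitting_ops_dynamic: enabling the attention-fusion pass while requesting PIECEWISE cudagraphs clears splitting_ops and falls back to FULL cudagraph mode. A standalone sketch of that check, mirroring the assertions above rather than documenting the config API independently:

    from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig

    # With attention fusion enabled, attention ops may not be used as graph
    # split points, so the config drops them and falls back from PIECEWISE
    # to FULL cudagraphs (as asserted in the test above).
    config = VllmConfig(compilation_config=CompilationConfig(
        pass_config={"enable_attn_fusion": True, "enable_noop": True},
        custom_ops=["+quant_fp8"],
        cudagraph_mode=CUDAGraphMode.PIECEWISE,
    ))
    assert config.compilation_config.splitting_ops == []
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL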

tests/v1/attention/utils.py

Lines changed: 86 additions & 1 deletion
@@ -3,7 +3,7 @@
 """Utility functions for attention-related v1 tests."""

 from dataclasses import dataclass
-from typing import Union
+from typing import Optional, Union

 import pytest
 import torch
@@ -260,3 +260,88 @@ def create_dummy_kv_cache(block_size: int,
                           dtype=dtype,
                           device=device)
     return kv_cache
+
+
+@dataclass
+class BackendConfig:
+    name: str
+    env_vars: dict
+    comp_config: dict  # compilation config
+    specific_gpu_arch: Optional[tuple] = None
+
+
+# Define all backend configurations of full cudagraph to be tested
+full_cg_backend_configs = {
+    # FA3 on Hopper
+    "FA3":
+    BackendConfig(name="FA3",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
+                      "VLLM_FLASH_ATTN_VERSION": "3",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # FlashMLA on Hopper
+    "FlashMLA":
+    BackendConfig(name="FlashMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # Cutlass MLA on Blackwell
+    "CutlassMLA":
+    BackendConfig(
+        name="CutlassMLA",
+        env_vars={
+            "VLLM_USE_V1": "1",
+            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
+            "FORCE_NUM_KV_SPLITS":
+            "1",  # TODO: remove this when hang issue is fixed
+        },
+        comp_config={
+            "cudagraph_mode": "FULL_AND_PIECEWISE",
+        },
+        specific_gpu_arch=(10, 0)),
+    # FlashAttention MLA on Hopper
+    "FlashAttentionMLA":
+    BackendConfig(name="FlashAttentionMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_DECODE_ONLY",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # FA2
+    "FA2":
+    BackendConfig(name="FA2",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
+                      "VLLM_FLASH_ATTN_VERSION": "2",
+                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+    # Triton Attention
+    "TritonAttn":
+    BackendConfig(name="TritonAttn",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+    # FlashInfer
+    "FlashInfer":
+    BackendConfig(name="FlashInfer",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+}
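
A small sketch of how the specific_gpu_arch field is meant to gate these entries (assumed usage, not part of this commit; skip_unsupported is a hypothetical helper name):

    import pytest
    import torch

    from tests.v1.attention.utils import full_cg_backend_configs

    def skip_unsupported(name: str) -> None:
        # Hypothetical helper: skip a test when the entry is pinned to a GPU
        # compute capability (major, minor) other than the one we run on.
        cfg = full_cg_backend_configs[name]
        if cfg.specific_gpu_arch is None:
            return  # entry is not architecture-specific
        if torch.cuda.get_device_capability() != cfg.specific_gpu_arch:
            pytest.skip(f"{name} requires GPU capability {cfg.specific_gpu_arch}")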

tests/v1/cudagraph/test_cudagraph_dispatch.py

Lines changed: 26 additions & 35 deletions
@@ -45,39 +45,22 @@ def _create_vllm_config(compilation_config: CompilationConfig,
 class TestCudagraphDispatcher:

     @pytest.mark.parametrize(
-        "params",
+        "case_id,cudagraph_mode_str,compilation_level",
         [
             # Test case 0: Full CG for mixed batches, no separate routine
-            {
-                "case_id": 0,
-                "cudagraph_mode": "FULL",
-                "compilation_level": CompilationLevel.NO_COMPILATION,
-            },
+            (0, "FULL", CompilationLevel.NO_COMPILATION),
             # Test case 1: Full CG for uniform batches, piecewise for mixed
-            {
-                "case_id": 1,
-                "cudagraph_mode": "FULL_AND_PIECEWISE",
-                "compilation_level": CompilationLevel.PIECEWISE,
-            },
+            (1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION),
             # Test case 2: Full CG for uniform batches, no CG for mixed
-            {
-                "case_id": 2,
-                "cudagraph_mode": "FULL_DECODE_ONLY",
-                "compilation_level": CompilationLevel.NO_COMPILATION,
-            },
+            (2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION),
             # Test case 3: Piecewise for all
-            {
-                "case_id": 3,
-                "cudagraph_mode": "PIECEWISE",
-                "compilation_level": CompilationLevel.PIECEWISE,
-            },
+            (3, "PIECEWISE", CompilationLevel.PIECEWISE),
         ])
-    def test_dispatcher(self, params):
+    def test_dispatcher(self, case_id, cudagraph_mode_str, compilation_level):
         # Setup dispatcher
-        comp_config = CompilationConfig(
-            cudagraph_mode=params["cudagraph_mode"],
-            level=params["compilation_level"],
-            cudagraph_capture_sizes=[1, 8])
+        comp_config = CompilationConfig(cudagraph_mode=cudagraph_mode_str,
+                                        level=compilation_level,
+                                        cudagraph_capture_sizes=[1, 8])

         config = _create_vllm_config(comp_config, max_num_seqs=8)
         dispatcher = CudagraphDispatcher(config)
@@ -86,11 +69,11 @@ def test_dispatcher(self, params):
                                             uniform_decode_query_len=1)

         # Verify the key is initialized correctly
-        if params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
+        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
         else:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
-        if params["cudagraph_mode"] not in ["NONE", "PIECEWISE"]:
+        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
         else:
             assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
@@ -99,10 +82,10 @@
         # 1. non-uniform batch, size in cudagraph size list
         desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
         rt_mode, key = dispatcher.dispatch(desc_full_exact)
-        if params["cudagraph_mode"] == "FULL":
+        if cudagraph_mode_str == "FULL":
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_full_exact
-        elif params["cudagraph_mode"] in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
+        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
             assert rt_mode == CUDAGraphMode.PIECEWISE
             assert key == desc_full_exact
         else:
@@ -111,15 +94,13 @@
         # 2. uniform decode batch, size in cudagraph size list
         desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
         rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
-        if params["cudagraph_mode"] == "FULL":
+        if cudagraph_mode_str == "FULL":
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_uniform_exact.non_uniform
-        elif params["cudagraph_mode"] in [
-                "FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"
-        ]:
+        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
             assert rt_mode == CUDAGraphMode.FULL
             assert key == desc_uniform_exact
-        elif params["cudagraph_mode"] == "PIECEWISE":
+        elif cudagraph_mode_str == "PIECEWISE":
             assert rt_mode == CUDAGraphMode.PIECEWISE
             assert key == desc_uniform_exact.non_uniform
         else:
@@ -131,6 +112,16 @@
             assert rt_mode == CUDAGraphMode.NONE
             assert key is None

+        # 4. Cascade attention should have a fallback mode
+        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
+        rt_mode, key = dispatcher.dispatch(desc_full_exact,
+                                           use_cascade_attn=True)
+        if "PIECEWISE" in cudagraph_mode_str:  # string containment check
+            assert rt_mode == CUDAGraphMode.PIECEWISE
+            assert key == desc_full_exact.non_uniform
+        else:
+            assert rt_mode == CUDAGraphMode.NONE
+

 @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
 class TestCUDAGraphWrapper:
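
The newly added case 4 exercises the cascade-attention path of the dispatcher. In rough terms (a fragment that assumes the dispatcher and BatchDescriptor objects set up earlier in the test; the comments restate the assertions above rather than documenting the dispatcher independently):

    # Cascade attention is never captured as a full cudagraph, so dispatch()
    # degrades gracefully for such batches.
    desc = BatchDescriptor(num_tokens=8, uniform_decode=False)
    rt_mode, key = dispatcher.dispatch(desc, use_cascade_attn=True)
    # PIECEWISE-capable modes: run piecewise graphs keyed by desc.non_uniform.
    # Other modes: rt_mode is CUDAGraphMode.NONE, i.e. no cudagraph is used.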
