
Commit 547fecb

fhl2000, ProExpertProg, and LucasWilkinson authored and committed
[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (vllm-project#20059)
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
1 parent f45390b · commit 547fecb

34 files changed: +1840 −598 lines

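Note: the tests below replace the old CompilationConfig(full_cuda_graph=True) flag with a cudagraph_mode setting. As a rough usage sketch (not part of this commit; the model name, backend, and mode values are simply taken from the tests in this diff), the new knob can be set like this:

# Minimal sketch only, assuming the LLM / CompilationConfig APIs as they
# appear in the diff below; not code shipped by this commit.
import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# Backend selection still happens via environment variables, as in the tests.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",
    max_model_len=1024,
    # "FULL" and "FULL_AND_PIECEWISE" are the new modes exercised by the
    # tests; "PIECEWISE" is used below as the baseline behavior.
    compilation_config=CompilationConfig(cudagraph_mode="FULL_AND_PIECEWISE"),
)
out = llm.generate(["the quick brown fox"],
                   SamplingParams(temperature=0.0, max_tokens=8))
print(out[0].outputs[0].text)
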
tests/compile/piecewise/test_full_cudagraph.py

Lines changed: 125 additions & 128 deletions
@@ -3,7 +3,8 @@
 import contextlib
 import os
 import weakref
-from contextlib import ExitStack
+from dataclasses import dataclass
+from typing import Optional
 
 import pytest
 
@@ -32,69 +33,133 @@ def temporary_environ(env_vars):
                 os.environ[k] = v
 
 
-@pytest.fixture(scope="class")
-def llm_pair(request):
-    model = request.param
-
-    with temporary_environ({
-            "VLLM_USE_V1": "1",
-            "VLLM_FLASH_ATTN_VERSION": "3"
-    }):
-        full = LLM(
-            model=model,
-            gpu_memory_utilization=0.45,
-            trust_remote_code=True,
-            max_model_len=1024,
-            compilation_config=CompilationConfig(full_cuda_graph=True),
-        )
-        piecewise = LLM(
-            model=model,
-            gpu_memory_utilization=0.45,
-            trust_remote_code=True,
-            max_model_len=1024,
-            compilation_config=CompilationConfig(),
-        )
-
-    # PyTest caches the fixture values so we use weakref.proxy to enable GC
-    yield weakref.proxy(full), weakref.proxy(piecewise)
-    del full
-    del piecewise
-
-    wait_for_gpu_memory_to_clear(
-        devices=[0],
-        threshold_ratio=0.1,
-    )
-
-
-@pytest.fixture(scope="class")
-def cutlass_mla_llm_pair(request):
-    model = request.param
-
-    # force V1 engine and Cutlass MLA backend
-    with temporary_environ({
+@dataclass
+class BackendConfig:
+    name: str
+    env_vars: dict
+    comp_config: dict
+    specific_gpu_arch: Optional[tuple] = None
+
+
+# Define all backend configurations of full cudagraph to be tested
+backend_configs = {
+    # FA3 on Hopper
+    "FA3":
+    BackendConfig(name="FA3",
+                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # FlashMLA on Hopper
+    "FlashMLA":
+    BackendConfig(name="FlashMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # Cutlass MLA on Blackwell
+    "CutlassMLA":
+    BackendConfig(
+        name="CutlassMLA",
+        env_vars={
             "VLLM_USE_V1": "1",
             "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
             "FORCE_NUM_KV_SPLITS":
             "1",  # TODO: remove this when hang issue is fixed
-    }):
+        },
+        comp_config={
+            "cudagraph_mode": "FULL_AND_PIECEWISE",
+            "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
+        },
+        specific_gpu_arch=(10, 0)),
+    # FA2
+    "FA2":
+    BackendConfig(name="FA2",
+                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  }),
+    # Triton Attention
+    "TritonAttn":
+    BackendConfig(name="TritonAttn",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  }),
+    # FlashInfer
+    "FlashInfer":
+    BackendConfig(name="FlashInfer",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+}
+
+test_params_full_cudagraph = []
+
+# deepseek-ai/DeepSeek-V2-Lite with MLA
+MLA_backends = ["FlashMLA", "CutlassMLA"]
+for mla_backend in MLA_backends:
+    test_params_full_cudagraph.append(
+        pytest.param(
+            ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
+
+# Qwen/Qwen2-1.5B-Instruct with other backends
+other_backend_configs = [
+    backend_configs[c] for c in backend_configs if c not in MLA_backends
+]
+for backend_config in other_backend_configs:
+    test_params_full_cudagraph.append(
+        pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
+
+
+@pytest.fixture(scope="class")
+def llm_pair(request):
+    model, backend_config = request.param
+
+    # Dynamically skip test if GPU capability is not met
+    if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
+            != current_platform.get_device_capability():
+        if backend_config.specific_gpu_arch == (9, 0):
+            pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
+        elif backend_config.specific_gpu_arch == (10, 0):
+            pytest.skip("Only Blackwell GPUs support Cutlass MLA")
+
+    env_vars = {
+        "VLLM_USE_V1": "1",
+        # Force native sampler to avoid potential nondeterminism in FlashInfer
+        # when per-request generators are not used in V1.
+        "VLLM_USE_FLASHINFER_SAMPLER": "0",
+        **backend_config.env_vars,
+    }
+    with temporary_environ(env_vars):
         full = LLM(
             model=model,
-            gpu_memory_utilization=0.45,
+            gpu_memory_utilization=0.43,
             trust_remote_code=True,
             max_model_len=1024,
-            compilation_config=CompilationConfig(
-                full_cuda_graph=True,
-                cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512],
-            ),
+            max_num_seqs=128,
+            compilation_config=\
+                CompilationConfig(**backend_config.comp_config),
+            generation_config="vllm",
+            seed=42,
         )
         piecewise = LLM(
             model=model,
-            gpu_memory_utilization=0.45,
+            gpu_memory_utilization=0.43,
             trust_remote_code=True,
             max_model_len=1024,
-            compilation_config=CompilationConfig(),
+            max_num_seqs=128,
+            compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
+            generation_config="vllm",
+            seed=42,
         )
 
+    # PyTest caches the fixture values so we use weakref.proxy to enable GC
     yield weakref.proxy(full), weakref.proxy(piecewise)
     del full
     del piecewise
@@ -105,51 +170,7 @@ def cutlass_mla_llm_pair(request):
     )
 
 
-@pytest.mark.parametrize(
-    "cutlass_mla_llm_pair",
-    [
-        # use an MLA model
-        "deepseek-ai/DeepSeek-V2-Lite",
-    ],
-    indirect=True)
-@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
-                    reason="Only Blackwell GPUs support Cutlass MLA")
-class TestFullCUDAGraphCutlassMLA:
-    """
-    Validate full CUDA Graph with Cutlass MLA (decode-only capture).
-    """
-
-    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
-        (8, 8),
-    ])
-    def test_full_cudagraph_sm100_cutlass_mla(
-            self, batch_size, max_tokens, cutlass_mla_llm_pair: tuple[LLM,
-                                                                      LLM]):
-        piecewise_llm, full_cudagraph_llm = cutlass_mla_llm_pair
-
-        prompts = ["Hello, my name is"] * batch_size
-        sampling_params = SamplingParams(temperature=0.0,
-                                         max_tokens=max_tokens,
-                                         top_p=0.95)
-
-        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
-        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
-
-        for piecewise_res, full_res in zip(piecewise_responses,
-                                           full_responses):
-            assert piecewise_res.outputs[0].text == full_res.outputs[0].text
-
-
-@pytest.mark.parametrize(
-    "llm_pair",
-    [
-        # Model names for the llm_pair fixture
-        "deepseek-ai/DeepSeek-V2-Lite",
-        "Qwen/Qwen2-1.5B-Instruct"
-    ],
-    indirect=True)
-@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
-                    reason="Only Hopper GPUs support FA3 and FlashMLA")
+@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
 class TestFullCUDAGraph:
     """
     Use a class such that an llm pair is constructed once for all
@@ -178,55 +199,31 @@ def test_full_cudagraph(self, batch_size, max_tokens,
         full cudagraph compilation works for padded cases too.
         """
 
-        piecewise_llm, full_cudagraph_llm = llm_pair
+        full_cudagraph_llm, piecewise_llm = llm_pair
 
-        prompts = ["Hello, my name is"] * batch_size
+        prompts = ["the quick brown fox"] * batch_size
+        # Use purely greedy decoding to avoid top-p truncation sensitivity
+        # that can amplify tiny numeric differences across runtimes.
         sampling_params = SamplingParams(temperature=0.0,
                                          max_tokens=max_tokens,
-                                         top_p=0.95)
+                                         top_p=1.0)
 
         piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
 
         # Check that all responses are the same
         for piecewise_res, full_res in zip(piecewise_responses,
                                            full_responses):
-            assert piecewise_res.outputs[0].text == full_res.outputs[0].text
-
-
-@pytest.mark.parametrize(
-    "model, supported",
-    [
-        ("Qwen/Qwen2-1.5B-Instruct", True),
-        # MLA does not support capturing CUDA Graphs with size > max_num_seqs
-        ("deepseek-ai/DeepSeek-V2-Lite", False),
-    ])
-@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
-                    reason="Only Hopper GPUs support FA3 and FlashMLA")
-def test_lower_max_num_seqs(model, supported):
-    with temporary_environ({
-            "VLLM_USE_V1": "1",
-            "VLLM_FLASH_ATTN_VERSION": "3"
-    }), ExitStack() as stack:
-        if not supported:
-            stack.enter_context(pytest.raises(RuntimeError))
-
-        llm = LLM(model=model,
-                  max_num_seqs=256,
-                  trust_remote_code=True,
-                  max_model_len=1024,
-                  compilation_config=CompilationConfig(
-                      full_cuda_graph=True,
-                      cudagraph_capture_sizes=[64, 256, 512]))
-        llm.generate(["Hello, my name is"] * 10)
+            assert piecewise_res.outputs[0].text.lower() == \
+                full_res.outputs[0].text.lower()
 
 
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
 def test_full_cudagraph_with_invalid_backend():
     with temporary_environ({
             "VLLM_USE_V1": "1",
-            "VLLM_FLASH_ATTN_VERSION":
-            "2"  #FA2 not supported with full_cuda_graph
+            "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
+            # Flex_Attention is not supported with full cuda graph
    }), pytest.raises(RuntimeError):
         LLM(model="Qwen/Qwen2-1.5B-Instruct",
-            compilation_config=CompilationConfig(full_cuda_graph=True))
+            compilation_config=CompilationConfig(cudagraph_mode="FULL"))

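For orientation, the pairing that the llm_pair fixture and TestFullCUDAGraph set up can be sketched as a standalone script. This is only an illustration built from the parameters visible in the diff above (Qwen/Qwen2-1.5B-Instruct, greedy decoding, FULL vs. PIECEWISE modes), not code shipped by the commit:

# A minimal sketch of the comparison TestFullCUDAGraph performs, assuming the
# same LLM / CompilationConfig / SamplingParams APIs used in the diff above.
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# Shared settings mirroring the llm_pair fixture.
common = dict(model="Qwen/Qwen2-1.5B-Instruct",
              max_model_len=1024,
              max_num_seqs=128,
              gpu_memory_utilization=0.43,
              generation_config="vllm",
              seed=42)

full = LLM(compilation_config=CompilationConfig(cudagraph_mode="FULL"),
           **common)
piecewise = LLM(compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
                **common)

prompts = ["the quick brown fox"] * 8
# Greedy decoding, matching the test's choice to avoid top-p sensitivity.
params = SamplingParams(temperature=0.0, max_tokens=8, top_p=1.0)

full_out = full.generate(prompts, params)
piece_out = piecewise.generate(prompts, params)
for a, b in zip(full_out, piece_out):
    assert a.outputs[0].text.lower() == b.outputs[0].text.lower()
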
tests/compile/piecewise/test_simple.py

Lines changed: 25 additions & 8 deletions
@@ -11,10 +11,10 @@
 
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
-                         set_current_vllm_config)
+from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
+                         VllmConfig, set_current_vllm_config)
 from vllm.envs import VLLM_USE_V1
-from vllm.forward_context import set_forward_context
+from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils import direct_register_custom_op
 
 global_counter = 0
@@ -101,16 +101,33 @@ def test_simple_piecewise_compile(use_inductor):
             num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
             num_cudagraph_captured=
             6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-
+    ), set_forward_context(None,
+                           vllm_config=vllm_config):  # background context
+        # warm up with background context
         model(inputs)
 
-        model(torch.randn(2).cuda())
-        model(torch.randn(1).cuda())
+        # capturing/replaying should under context of cudagraph dispatching
+        with set_forward_context(
+                None,
+                vllm_config=vllm_config,
+                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
+                batch_descriptor=BatchDescriptor(num_tokens=2, )):
+            model(torch.randn(2).cuda())
+        with set_forward_context(
+                None,
+                vllm_config=vllm_config,
+                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
+                batch_descriptor=BatchDescriptor(num_tokens=1, )):
+            model(torch.randn(1).cuda())
 
         input = torch.zeros(2).cuda()
         global global_counter
         global_counter = 0
-        output = model(input)
+        with set_forward_context(
+                None,
+                vllm_config=vllm_config,
+                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
+                batch_descriptor=BatchDescriptor(num_tokens=2, )):
+            output = model(input)
         assert global_counter == 2
         assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))

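The test_simple.py changes show the runtime side of the feature: a forward pass that should capture or replay a piecewise CUDA graph now runs under set_forward_context with an explicit CUDAGraphMode and a BatchDescriptor. A minimal sketch of that dispatch pattern, assuming a compiled vLLM model, a populated vllm_config, and an input already padded to a captured size (all placeholders here):

import torch

from vllm.config import CUDAGraphMode
from vllm.forward_context import BatchDescriptor, set_forward_context


def run_piecewise(model, vllm_config, x: torch.Tensor):
    # The descriptor identifies the captured graph to replay; in the test it
    # is keyed only on the number of tokens in the batch.
    with set_forward_context(None,
                             vllm_config=vllm_config,
                             cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                             batch_descriptor=BatchDescriptor(
                                 num_tokens=x.shape[0])):
        return model(x)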