
Commit b124e10

[Bugfix] Fix FA3 full cuda graph correctness (#19106)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 41aa578 commit b124e10

4 files changed: +32 −10 lines changed

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -320,6 +320,7 @@ steps:
   # these tests need to be separated, cannot combine
   - pytest -v -s compile/piecewise/test_simple.py
   - pytest -v -s compile/piecewise/test_toy_llama.py
+  - pytest -v -s compile/piecewise/test_full_cudagraph.py
 
 - label: PyTorch Fullgraph Test # 18min
   mirror_hardwares: [amdexperimental, amdproduction]
```

tests/compile/piecewise/test_full_cudagraph.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -7,6 +7,7 @@
 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
+from vllm.platforms import current_platform
 
 MODEL = "Qwen/Qwen2-1.5B-Instruct"
 
@@ -37,7 +38,7 @@ def full_cudagraph_llm():
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
         return LLM(model=MODEL,
-                   gpu_memory_utilization=0.2,
+                   gpu_memory_utilization=0.3,
                    compilation_config=CompilationConfig(full_cuda_graph=True))
 
 
@@ -48,7 +49,7 @@ def piecewise_llm():
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
         return LLM(model=MODEL,
-                   gpu_memory_utilization=0.5,
+                   gpu_memory_utilization=0.6,
                    compilation_config=CompilationConfig())
 
 
@@ -61,6 +62,8 @@ def generate_text(llm: LLM, batch_size: int, max_tokens: int):
     return llm.generate(prompts, sampling_params)
 
 
+@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
+                    reason="Only Hopper GPUs support FlashAttention 3")
 @pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
                                                          (16, 10), (25, 10),
                                                          (32, 10), (45, 10),
```
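
The new `skipif` gate keeps the test off non-Hopper GPUs, since FlashAttention 3 (and therefore AOT scheduling) only runs on SM 9.0. For readers who want to try the feature outside the test harness, here is a minimal sketch, assuming a Hopper GPU with an FA3-enabled build; the prompt and sampling settings are illustrative and not part of this commit:

```python
# Minimal sketch (not part of this diff): enabling full CUDA graph mode
# through the public API the test exercises. Assumes a Hopper (SM 9.0) GPU
# with FlashAttention 3 available.
import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# Force FA3, mirroring the environment the test sets up.
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          gpu_memory_utilization=0.3,
          compilation_config=CompilationConfig(full_cuda_graph=True))

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)
```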

vllm/v1/attention/backends/flash_attn.py

Lines changed: 21 additions & 8 deletions
```diff
@@ -307,13 +307,14 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
 
-        if get_flash_attn_version() == 3:
-            self.aot_schedule = not compilation_config.full_cuda_graph
-            if not self.aot_schedule:
-                logger.warning(
-                    "AOT Schedule is disabled when using full_cuda_graph")
-        else:
-            self.aot_schedule = False
+        self.aot_schedule = (get_flash_attn_version() == 3)
+        self.use_full_cuda_graph = compilation_config.full_cuda_graph
+        if self.use_full_cuda_graph and not self.aot_schedule:
+            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
+                             "which requires FlashAttention 3.")
+        self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
+                                              dtype=torch.int32,
+                                              device=self.runner.device)
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
@@ -326,7 +327,7 @@ def reorder_batch(self, input_batch: "InputBatch",
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata):
-        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
+        max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
@@ -448,6 +449,18 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                                       max_seq_len=max_seq_len,
                                       causal=True)
 
+        if self.use_full_cuda_graph:
+            assert scheduler_metadata is not None
+            n = scheduler_metadata.shape[0]
+            self.scheduler_metadata[:n].copy_(scheduler_metadata,
+                                              non_blocking=True)
+            # NOTE(woosuk): We should zero out the rest of the scheduler
+            # metadata to guarantee the correctness. Otherwise, some thread
+            # blocks may use the invalid scheduler metadata and overwrite the
+            # output buffer.
+            self.scheduler_metadata[n:] = 0
+            scheduler_metadata = self.scheduler_metadata[:n]
+
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
```
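
The key change here is the persistent `scheduler_metadata` buffer. With full CUDA graphs, the captured graph always reads the same device tensor, so `build()` copies each freshly computed AOT schedule into a preallocated buffer and zeroes the unused tail; otherwise stale entries from an earlier, larger batch could be treated as valid work items and overwrite the output buffer. A standalone sketch of this pattern in plain PyTorch (class and variable names are hypothetical, not vLLM APIs):

```python
import torch


class PersistentScheduleBuffer:
    """Sketch of the fixed-buffer pattern needed for CUDA graph replay.

    A captured graph reads `self.buf` by address, so later steps must write
    into the same storage rather than swapping in a new tensor.
    """

    def __init__(self, max_entries: int, device: str = "cuda"):
        # Allocated once; its memory address stays stable across replays.
        self.buf = torch.zeros(max_entries, dtype=torch.int32, device=device)

    def update(self, fresh: torch.Tensor) -> torch.Tensor:
        n = fresh.shape[0]
        # Copy the newly computed schedule into the persistent buffer.
        self.buf[:n].copy_(fresh, non_blocking=True)
        # Zero the tail so leftovers from a previous, larger schedule cannot
        # be misread as valid entries when the graph is replayed.
        self.buf[n:] = 0
        return self.buf[:n]


# Usage sketch: the graph would be captured against `sched.buf`; each step
# only refreshes its contents.
sched = PersistentScheduleBuffer(max_entries=1024)
view = sched.update(torch.arange(8, dtype=torch.int32, device="cuda"))
```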

vllm/v1/worker/gpu_model_runner.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -1750,6 +1750,11 @@ def _dummy_run(
             attn_metadata: Optional[dict[str, Any]] = None
         else:
             query_start_loc = self.query_start_loc[:num_reqs + 1]
+            # Make sure max_model_len is used at the graph capture time.
+            self.seq_lens_np[:num_reqs] = self.max_model_len
+            self.seq_lens_np[num_reqs:] = 0
+            self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                           non_blocking=True)
             seq_lens = self.seq_lens[:num_reqs]
 
             common_attn_metadata = CommonAttentionMetadata(
```
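
During graph capture the dummy run now pads `seq_lens` to `max_model_len`, so the captured AOT schedule is sized for the worst case rather than for whatever lengths happened to be left in the buffers. The writes go through `seq_lens_np`, which backs a pinned CPU copy that is then pushed to the GPU asynchronously; a small self-contained sketch of that pattern (variable names are illustrative, not the runner's actual attributes):

```python
import torch

max_num_reqs, max_model_len, num_reqs = 8, 4096, 3

# Pinned CPU tensor with a NumPy view: writes through the view land in the
# tensor, and pinned memory lets the host-to-device copy run asynchronously.
seq_lens_cpu = torch.zeros(max_num_reqs, dtype=torch.int32, pin_memory=True)
seq_lens_np = seq_lens_cpu.numpy()
seq_lens_gpu = torch.zeros(max_num_reqs, dtype=torch.int32, device="cuda")

# At capture time, pretend every active request is at the maximum length so
# the captured schedule remains valid for any real sequence length later.
seq_lens_np[:num_reqs] = max_model_len
seq_lens_np[num_reqs:] = 0
seq_lens_gpu[:num_reqs].copy_(seq_lens_cpu[:num_reqs], non_blocking=True)
```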
