[Bugfix] Fix FA3 full cuda graph correctness #19106

Merged: 7 commits, Jun 4, 2025
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -320,6 +320,7 @@ steps:
# these tests need to be separated, cannot combine
- pytest -v -s compile/piecewise/test_simple.py
- pytest -v -s compile/piecewise/test_toy_llama.py
- pytest -v -s compile/piecewise/test_full_cudagraph.py

- label: PyTorch Fullgraph Test # 18min
mirror_hardwares: [amdexperimental, amdproduction]
7 changes: 5 additions & 2 deletions tests/compile/piecewise/test_full_cudagraph.py
@@ -7,6 +7,7 @@

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform

MODEL = "Qwen/Qwen2-1.5B-Instruct"

@@ -37,7 +38,7 @@ def full_cudagraph_llm():
"VLLM_FLASH_ATTN_VERSION": "3"
}):
return LLM(model=MODEL,
gpu_memory_utilization=0.2,
gpu_memory_utilization=0.3,
compilation_config=CompilationConfig(full_cuda_graph=True))


@@ -48,7 +49,7 @@ def piecewise_llm():
"VLLM_FLASH_ATTN_VERSION": "3"
}):
return LLM(model=MODEL,
gpu_memory_utilization=0.5,
gpu_memory_utilization=0.6,
compilation_config=CompilationConfig())


@@ -61,6 +62,8 @@ def generate_text(llm: LLM, batch_size: int, max_tokens: int):
return llm.generate(prompts, sampling_params)


@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FlashAttention 3")
@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
(16, 10), (25, 10),
(32, 10), (45, 10),
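For context, the new skip above keys off compute capability (9, 0): FlashAttention 3 kernels are only built for Hopper GPUs, which report SM 9.0. A minimal standalone sketch of the same gating pattern, using plain torch.cuda rather than vLLM's current_platform helper (names here are illustrative, not part of this PR):

```python
import pytest
import torch

# Hypothetical helper: FlashAttention 3 kernels target Hopper,
# which reports CUDA compute capability (9, 0).
requires_hopper = pytest.mark.skipif(
    not torch.cuda.is_available()
    or torch.cuda.get_device_capability() != (9, 0),
    reason="Only Hopper GPUs (SM 9.0) support FlashAttention 3")


@requires_hopper
def test_runs_only_on_hopper():
    assert torch.cuda.get_device_capability() == (9, 0)
```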
29 changes: 21 additions & 8 deletions vllm/v1/attention/backends/flash_attn.py
@@ -307,13 +307,14 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table

if get_flash_attn_version() == 3:
self.aot_schedule = not compilation_config.full_cuda_graph
if not self.aot_schedule:
logger.warning(
"AOT Schedule is disabled when using full_cuda_graph")
else:
self.aot_schedule = False
self.aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = compilation_config.full_cuda_graph
if self.use_full_cuda_graph and not self.aot_schedule:
raise ValueError("Full CUDA graph mode requires AOT scheduling, "
"which requires FlashAttention 3.")
self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
dtype=torch.int32,
device=self.runner.device)

# Sliding window size to be used with the AOT scheduler will be
# populated on first build() call.
@@ -326,7 +327,7 @@ def reorder_batch(self, input_batch: "InputBatch",
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table = self.block_table
@@ -448,6 +449,18 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
max_seq_len=max_seq_len,
causal=True)

if self.use_full_cuda_graph:
assert scheduler_metadata is not None
n = scheduler_metadata.shape[0]
self.scheduler_metadata[:n].copy_(scheduler_metadata,
non_blocking=True)
# NOTE(woosuk): We should zero out the rest of the scheduler
# metadata to guarantee the correctness. Otherwise, some thread
# blocks may use the invalid scheduler metadata and overwrite the
# output buffer.
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]

attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
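The core of the fix above: under full CUDA graph capture, the FA3 AOT scheduler metadata must live in a persistent, fixed-size buffer whose storage the captured graph can reuse, and the unused tail must be zeroed so stale entries from a previous, larger batch cannot send thread blocks to invalid work. A minimal sketch of that pattern, with hypothetical names and sizes (not vLLM's actual classes):

```python
import torch


class PersistentSchedulerBuffer:
    """Sketch of the fixed-buffer pattern used with CUDA graphs.

    The graph is captured against `self.buf`, so every replay must reuse
    that exact storage; per-step metadata is copied into a prefix of it
    and the remainder is zeroed.
    """

    def __init__(self, max_entries: int, device: str = "cuda"):
        # Allocated once, up front, at the largest size ever needed.
        self.buf = torch.zeros(max_entries, dtype=torch.int32, device=device)

    def update(self, step_metadata: torch.Tensor) -> torch.Tensor:
        n = step_metadata.shape[0]
        # Copy this step's metadata into the persistent buffer ...
        self.buf[:n].copy_(step_metadata, non_blocking=True)
        # ... and zero the tail so entries left over from a larger batch
        # cannot be read as valid scheduling work during graph replay.
        self.buf[n:] = 0
        return self.buf[:n]
```

A caller would hand the returned slice to the attention kernel, just as the diff does with `self.scheduler_metadata[:n]`.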
5 changes: 5 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1750,6 +1750,11 @@ def _dummy_run(
attn_metadata: Optional[dict[str, Any]] = None
else:
query_start_loc = self.query_start_loc[:num_reqs + 1]
# Make sure max_model_len is used at the graph capture time.
self.seq_lens_np[:num_reqs] = self.max_model_len
self.seq_lens_np[num_reqs:] = 0
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
seq_lens = self.seq_lens[:num_reqs]

common_attn_metadata = CommonAttentionMetadata(
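The `_dummy_run` change makes the capture-time batch pretend every request is at `max_model_len`, so the scheduler metadata produced during capture is sized for the worst case any replay can encounter, while padded slots are zeroed out. A rough sketch of the idea, with hypothetical names rather than vLLM's actual runner attributes:

```python
import torch


def prepare_capture_seq_lens(seq_lens_cpu: torch.Tensor,
                             seq_lens_gpu: torch.Tensor,
                             num_reqs: int,
                             max_model_len: int) -> torch.Tensor:
    """Fill the first `num_reqs` sequence lengths with the model maximum
    before capturing a CUDA graph, so the captured attention schedule
    covers the worst case; zero the rest so inactive slots do no work."""
    seq_lens_np = seq_lens_cpu.numpy()       # numpy view of the CPU tensor
    seq_lens_np[:num_reqs] = max_model_len   # worst-case lengths
    seq_lens_np[num_reqs:] = 0               # padded / inactive requests
    # Mirror the CPU values into the persistent GPU tensor the graph reads.
    seq_lens_gpu[:num_reqs].copy_(seq_lens_cpu[:num_reqs], non_blocking=True)
    return seq_lens_gpu[:num_reqs]
```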