
[CUDA] Enable full cudagraph for FlashMLA #18581

Merged · 13 commits · Jun 13, 2025
158 changes: 107 additions & 51 deletions tests/compile/piecewise/test_full_cudagraph.py
@@ -2,15 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
import weakref
from contextlib import ExitStack

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform

MODEL = "Qwen/Qwen2-1.5B-Instruct"


@contextlib.contextmanager
def temporary_environ(env_vars):
@@ -31,64 +32,119 @@ def temporary_environ(env_vars):
os.environ[k] = v


@pytest.fixture(scope="module")
def full_cudagraph_llm():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
return LLM(model=MODEL,
gpu_memory_utilization=0.3,
compilation_config=CompilationConfig(full_cuda_graph=True))

@pytest.fixture(scope="class")
def llm_pair(request):
model = request.param

@pytest.fixture(scope="module")
def piecewise_llm():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
return LLM(model=MODEL,
gpu_memory_utilization=0.6,
compilation_config=CompilationConfig())


def generate_text(llm: LLM, batch_size: int, max_tokens: int):
prompts = ["Hi my name is"] * batch_size
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)

return llm.generate(prompts, sampling_params)


full = LLM(
model=model,
gpu_memory_utilization=0.45,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(full_cuda_graph=True),
)
piecewise = LLM(
model=model,
gpu_memory_utilization=0.45,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(),
)

# PyTest caches the fixture values so we use weakref.proxy to enable GC
yield weakref.proxy(full), weakref.proxy(piecewise)
del full
del piecewise

wait_for_gpu_memory_to_clear(
devices=[0],
threshold_ratio=0.1,
)


@pytest.mark.parametrize(
"llm_pair",
[
# Model names for the llm_pair fixture
"deepseek-ai/DeepSeek-V2-Lite",
"Qwen/Qwen2-1.5B-Instruct"
],
indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FlashAttention 3")
@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
(16, 10), (25, 10),
(32, 10), (45, 10),
(64, 10), (8, 5),
(8, 20), (8, 200)])
def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
piecewise_llm):
reason="Only Hopper GPUs support FA3 and FlashMLA")
class TestFullCUDAGraph:
"""
    Load the full cudagraph model and the piecewise model once, at the same
    time, and reuse them across the various test cases.
Use a class such that an llm pair is constructed once for all
batch_size/max_tokens combinations and released immediately after.

Test various batch sizes and max_tokens to ensure that the full cudagraph
compilation works for padded cases too.
Module-scope fixtures would stick around the whole time,
meaning there would be multiple LLM instances hogging memory simultaneously.
"""
piecewise_responses = generate_text(piecewise_llm,
batch_size=batch_size,
max_tokens=max_tokens)
full_cudagraph_responses = generate_text(full_cudagraph_llm,
batch_size=batch_size,
max_tokens=max_tokens)

# Check that all responses are the same
for i in range(len(piecewise_responses)):
assert piecewise_responses[i].outputs[
0].text == full_cudagraph_responses[i].outputs[0].text
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
(1, 10),
(7, 10),
(16, 10),
(25, 10),
(32, 10),
(45, 10),
(64, 10),
(123, 10),
(8, 5),
(8, 30),
])
def test_full_cudagraph(self, batch_size, max_tokens,
llm_pair: tuple[LLM, LLM]):
"""
Test various batch sizes and max_tokens to ensure that the
full cudagraph compilation works for padded cases too.
"""

piecewise_llm, full_cudagraph_llm = llm_pair

prompts = ["Hello, my name is"] * batch_size
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)

piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)

# Check that all responses are the same
for piecewise_res, full_res in zip(piecewise_responses,
full_responses):
assert piecewise_res.outputs[0].text == full_res.outputs[0].text


@pytest.mark.parametrize(
"model, supported",
[
("Qwen/Qwen2-1.5B-Instruct", True),
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
("deepseek-ai/DeepSeek-V2-Lite", False),
])
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FA3 and FlashMLA")
def test_lower_max_num_seqs(model, supported):
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}), ExitStack() as stack:
if not supported:
stack.enter_context(pytest.raises(RuntimeError))

llm = LLM(model=model,
max_num_seqs=256,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(
full_cuda_graph=True,
cudagraph_capture_sizes=[64, 256, 512]))
llm.generate(["Hello, my name is"] * 10)


def test_full_cudagraph_with_invalid_backend():
@@ -97,5 +153,5 @@ def test_full_cudagraph_with_invalid_backend():
"VLLM_FLASH_ATTN_VERSION":
"2" #FA2 not supported with full_cuda_graph
}), pytest.raises(RuntimeError):
LLM(model=MODEL,
LLM(model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(full_cuda_graph=True))
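
For illustration, the new llm_pair fixture above follows a class-scoped construct / yield-weak-proxies / teardown pattern so that pytest's fixture cache does not keep both LLMs alive after the class finishes. Below is a minimal, self-contained sketch of the same pattern; the hypothetical ExpensiveResource class stands in for LLM so the sketch runs without a GPU, and all names in it are illustrative assumptions rather than part of this diff.

import weakref

import pytest


class ExpensiveResource:
    """Hypothetical stand-in for an LLM; exists only to make the sketch runnable."""

    def __init__(self, model: str, mode: str):
        self.model = model
        self.mode = mode

    def generate(self, prompt: str) -> str:
        # Deterministic output that ignores `mode`, so both resources agree.
        return f"{self.model}::{prompt}"


@pytest.fixture(scope="class")
def resource_pair(request):
    model = request.param
    full = ExpensiveResource(model, mode="full")
    piecewise = ExpensiveResource(model, mode="piecewise")
    # pytest caches fixture values, so yield weak proxies and keep the strong
    # references local; deleting them in teardown lets the objects be freed
    # before the next parametrization or test class runs.
    yield weakref.proxy(full), weakref.proxy(piecewise)
    del full
    del piecewise


@pytest.mark.parametrize("resource_pair", ["toy-model"], indirect=True)
class TestResourcePair:

    @pytest.mark.parametrize("prompt", ["Hello", "Hi there"])
    def test_outputs_match(self, resource_pair, prompt):
        full, piecewise = resource_pair
        assert full.generate(prompt) == piecewise.generate(prompt)
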
16 changes: 5 additions & 11 deletions tests/compile/piecewise/test_simple.py
@@ -4,7 +4,7 @@
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""

import pytest
import torch
from torch import nn
from torch.library import Library
@@ -14,6 +14,7 @@
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import set_forward_context
from vllm.utils import direct_register_custom_op

global_counter = 0
@@ -76,7 +77,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


def _test_simple_piecewise_compile(*, use_inductor):
@pytest.mark.parametrize("use_inductor", [True, False])
def test_simple_piecewise_compile(use_inductor):
assert VLLM_USE_V1

vllm_config = VllmConfig(compilation_config=CompilationConfig(
@@ -99,7 +101,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
), set_forward_context({}, vllm_config=vllm_config):

model(inputs)

@@ -112,11 +114,3 @@ def _test_simple_piecewise_compile(*, use_inductor):
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))


def test_simple_piecewise_compile_inductor():
_test_simple_piecewise_compile(use_inductor=True)


def test_simple_piecewise_compile_no_inductor():
_test_simple_piecewise_compile(use_inductor=False)
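
This file and the toy-llama test below receive the same two changes: the use_inductor wrapper tests collapse into a single @pytest.mark.parametrize, and model calls now run inside set_forward_context. The following rough sketch shows that execution pattern with a trivial nn.Module standing in for the piecewise-compiled test model; the stand-in model and inputs are assumptions for illustration, and only the VllmConfig / set_forward_context usage mirrors what the diff does.

import torch
from torch import nn

from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.forward_context import set_forward_context


class TinyModel(nn.Module):
    """Hypothetical stand-in for the compiled test model."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE))
model = TinyModel()

# The tests pass empty attention metadata ({}); the forward context is what
# the updated execution path reads while the model runs.
with set_forward_context({}, vllm_config=vllm_config):
    out = model(torch.zeros(2))

assert torch.equal(out, torch.ones(2))
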
45 changes: 21 additions & 24 deletions tests/compile/piecewise/test_toy_llama.py
@@ -11,6 +11,7 @@
from dataclasses import dataclass
from typing import Any, Optional

import pytest
import torch
from torch import nn
from torch.library import Library
@@ -19,6 +20,7 @@
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.utils import direct_register_custom_op

# create a library to hold the custom op
@@ -285,29 +287,32 @@ def run_model(llama_config,
vllm_config=vllm_config,
prefix="").eval().cuda()

B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
positions = torch.arange(B).cuda()
with set_forward_context({}, vllm_config=vllm_config):
B = 16 # max batch size
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
positions = torch.arange(B).cuda()

model(input_ids, positions)
model(input_ids[:2], positions[:2])
model(input_ids[:1], positions[:1])
model(input_ids, positions)
model(input_ids[:2], positions[:2])
model(input_ids[:1], positions[:1])

input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])
input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])

output = output.cpu()
output = output.cpu()

if llama_config.tractable_init:
expected_output = tractable_computation(input_ids[:2], positions[:2],
llama_config).cpu()
if llama_config.tractable_init:
expected_output = tractable_computation(input_ids[:2],
positions[:2],
llama_config).cpu()

assert torch.allclose(output, expected_output)
else:
return output.cpu()
assert torch.allclose(output, expected_output)
else:
return output.cpu()


def _test_toy_llama(*, use_inductor):
@pytest.mark.parametrize("use_inductor", [True, False])
def test_toy_llama(use_inductor: bool):
# compare output with and without piecewise compilation

llama_config = LlamaConfig(hidden_size=128,
@@ -379,14 +384,6 @@ def _test_toy_llama(*, use_inductor):
assert torch.allclose(outputs[0], outputs[i])


def test_toy_llama_inductor():
_test_toy_llama(use_inductor=True)


def test_toy_no_inductor():
_test_toy_llama(use_inductor=False)


@torch.inference_mode
def benchmark():
from triton.testing import do_bench
30 changes: 21 additions & 9 deletions tests/utils.py
@@ -667,42 +667,54 @@ def get_physical_device_indices(devices):


@_nvml()
def wait_for_gpu_memory_to_clear(devices: list[int],
threshold_bytes: int,
def wait_for_gpu_memory_to_clear(*,
devices: list[int],
threshold_bytes: Optional[int] = None,
threshold_ratio: Optional[float] = None,
timeout_s: float = 120) -> None:
assert threshold_bytes is not None or threshold_ratio is not None
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
devices = get_physical_device_indices(devices)
start_time = time.time()
while True:
output: dict[int, str] = {}
output_raw: dict[int, float] = {}
output_raw: dict[int, tuple[float, float]] = {}
for device in devices:
if current_platform.is_rocm():
dev_handle = amdsmi_get_processor_handles()[device]
mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
gb_used = mem_info["vram_used"] / 2**10
gb_total = mem_info["vram_total"] / 2**10
else:
dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
gb_used = mem_info.used / 2**30
output_raw[device] = gb_used
output[device] = f'{gb_used:.02f}'
gb_total = mem_info.total / 2**30
output_raw[device] = (gb_used, gb_total)
output[device] = f'{gb_used:.02f}/{gb_total:.02f}'

print('gpu memory used (GB): ', end='')
print('gpu memory used/total (GiB): ', end='')
for k, v in output.items():
print(f'{k}={v}; ', end='')
print('')

if threshold_bytes is not None:
is_free = lambda used, total: used <= threshold_bytes / 2**30
threshold = f"{threshold_bytes/2**30} GiB"
else:
is_free = lambda used, total: used / total <= threshold_ratio
threshold = f"{threshold_ratio:.2f}"

dur_s = time.time() - start_time
if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
if all(is_free(used, total) for used, total in output_raw.values()):
print(f'Done waiting for free GPU memory on devices {devices=} '
f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
f'({threshold=}) {dur_s=:.02f}')
break

if dur_s >= timeout_s:
raise ValueError(f'Memory of devices {devices=} not free after '
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
f'{dur_s=:.02f} ({threshold=})')

time.sleep(5)
