Merged
129 commits
92b1733
FA2 and FlashInfer Full cuda graph support
fhl2000 Jun 25, 2025
58ce477
fix the arch support in CMakeLists.txt to include 8.9
fhl2000 Jun 25, 2025
c2c5fea
Refactors
fhl2000 Jun 25, 2025
1606880
refactors
fhl2000 Jun 25, 2025
806432a
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jun 25, 2025
7c5df45
refactor
fhl2000 Jun 25, 2025
c7a9424
Add check for separate_attention_routine flag
fhl2000 Jun 25, 2025
e8b9296
fix typo error
fhl2000 Jun 26, 2025
94d0b79
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jun 27, 2025
a67c698
refactors and rearchitect cuda graph logic
fhl2000 Jun 28, 2025
da110af
Refactors
fhl2000 Jun 28, 2025
deaf0fe
Delect one commit
fhl2000 Jun 28, 2025
02ca154
Add support for force_no_split_graph
fhl2000 Jun 28, 2025
fa0d25c
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 1, 2025
5108bef
Huge refactors to separete cudagraph logic from vllm compilation
fhl2000 Jul 5, 2025
1c1873d
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 5, 2025
7d4667a
refactors
fhl2000 Jul 5, 2025
fedff47
fix errors
fhl2000 Jul 5, 2025
833ac56
fix small error by lazy import
fhl2000 Jul 5, 2025
d57257d
handle lint-and-deploy errors for cpu execution
fhl2000 Jul 5, 2025
8b7ea7a
remove redundents
fhl2000 Jul 5, 2025
328615d
Clear
fhl2000 Jul 6, 2025
debc682
Big refactors
fhl2000 Jul 9, 2025
cad6c39
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 9, 2025
dc455ee
cleanup
fhl2000 Jul 10, 2025
620a728
fix warmup
fhl2000 Jul 10, 2025
b1e6978
Commit suggestion: Update vllm/config.py
fhl2000 Jul 10, 2025
beee69a
commit suggestion2: Update vllm/config.py
fhl2000 Jul 10, 2025
21b1a8d
fix enforce_eager
fhl2000 Jul 10, 2025
ec79af7
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 10, 2025
210359a
small cleanup for pre-commit
fhl2000 Jul 10, 2025
11263e0
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 11, 2025
9a38a4e
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 12, 2025
699aff3
refactors
fhl2000 Jul 13, 2025
ef3d9d9
resolve yapf conflicts with isort
fhl2000 Jul 13, 2025
658565e
fixes
fhl2000 Jul 13, 2025
15e2b4a
fix global graph pool issue
fhl2000 Jul 13, 2025
4253dbf
fix refactors
fhl2000 Jul 13, 2025
2783e26
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 14, 2025
1b54962
more refactors
fhl2000 Jul 14, 2025
fb2a3c7
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 17, 2025
d6269bd
refactors for and more
fhl2000 Jul 17, 2025
2e1304c
fix pre-commit
fhl2000 Jul 17, 2025
db22ca5
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 18, 2025
72d40e6
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 20, 2025
0c79e53
change cudagraph dispatching logics; runtime style->runtime mode
fhl2000 Jul 21, 2025
75db3a6
pass pre-commit
fhl2000 Jul 21, 2025
0bca4c4
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 23, 2025
9d2f148
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 24, 2025
60bdc61
fix bug when cudagraph_separate_routine==False
fhl2000 Jul 24, 2025
9036bd2
recover FlashInfer from main branch
fhl2000 Jul 24, 2025
89ec3aa
address comments and clean up
fhl2000 Jul 26, 2025
4b991a3
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 26, 2025
614f6ea
clean up
fhl2000 Jul 26, 2025
c049627
fix
fhl2000 Jul 26, 2025
e69e488
add tests; more docs
fhl2000 Jul 27, 2025
835086a
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 27, 2025
534410e
clean up
fhl2000 Jul 27, 2025
618f7c0
small fix
fhl2000 Jul 27, 2025
1b343eb
add more docs
fhl2000 Jul 27, 2025
532f245
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 28, 2025
431a726
simplify the logic
fhl2000 Jul 28, 2025
19faeda
fix CI failures
fhl2000 Jul 29, 2025
348a117
fix CI failures again
fhl2000 Jul 29, 2025
fc5e37a
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 29, 2025
4d9829f
fix pre-commit
fhl2000 Jul 29, 2025
7773608
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 29, 2025
a692bb6
fix CI
fhl2000 Jul 29, 2025
543f264
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 30, 2025
3e5959a
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 31, 2025
aa35551
fix errors;move default initialization of cudagraph_mode to __post_in…
fhl2000 Jul 31, 2025
bad2710
fix a potential bug
fhl2000 Jul 31, 2025
f175c16
Merge branch 'vllm-project:main' into full_cudagraph_FA2_FlashInfer
fhl2000 Jul 31, 2025
9916a75
Merge remote-tracking branch 'origin/main' into pr-20059
LucasWilkinson Jul 31, 2025
81d7561
wip rework cudagraph_mode
LucasWilkinson Aug 1, 2025
0137d84
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Aug 2, 2025
1bfb855
fix and re-enable FlashInfer full cudagraph
fhl2000 Aug 2, 2025
24c40ab
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Aug 2, 2025
95d94f8
fix some CI tests
fhl2000 Aug 2, 2025
e7763ef
fallback
LucasWilkinson Aug 4, 2025
803a185
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 4, 2025
645accf
warn perferred
LucasWilkinson Aug 4, 2025
5029a6a
fix bugs and some refactors;temporarily add FULL_DOUBLE mode
fhl2000 Aug 5, 2025
fef7eee
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Aug 5, 2025
e796196
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 5, 2025
816024e
fix incorrectly infering type from CUDAGraphWrapper
fhl2000 Aug 5, 2025
651f729
fix and refactor cudagraph_mode checkings
fhl2000 Aug 6, 2025
38ddeaf
remove full double
LucasWilkinson Aug 6, 2025
9ca04ed
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Aug 7, 2025
a7adfae
Merge remote-tracking branch 'origin/main' into fhl2000/full_cudagrap…
LucasWilkinson Aug 7, 2025
14e83f5
Merge branch 'fhl2000_full_cudagraph_FA2_FlashInfer_merge' into full_…
LucasWilkinson Aug 7, 2025
9cc6b93
fix
LucasWilkinson Aug 7, 2025
1e97920
fix
LucasWilkinson Aug 7, 2025
25b6242
cleanup
LucasWilkinson Aug 7, 2025
85f20bf
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Aug 8, 2025
766eb7c
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 8, 2025
b0374be
deprecation
LucasWilkinson Aug 8, 2025
a160dd4
migrate flags
LucasWilkinson Aug 8, 2025
c2dc791
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 8, 2025
028f119
cleanup
LucasWilkinson Aug 8, 2025
43db16d
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 9, 2025
2cad036
fix some unit tests
LucasWilkinson Aug 9, 2025
6839e88
more cleanup
LucasWilkinson Aug 9, 2025
04ed99a
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 9, 2025
3f2b279
fix more unit tests
LucasWilkinson Aug 9, 2025
d500150
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
fhl2000 Aug 10, 2025
a56d549
fix is_attention_splitting;fix new mamba_attn cg support
fhl2000 Aug 10, 2025
3499d7b
wip
LucasWilkinson Aug 10, 2025
83d4e7c
stabalize unit test
LucasWilkinson Aug 11, 2025
bf8a51d
cleanup
LucasWilkinson Aug 11, 2025
d1f62e4
unit test fix
LucasWilkinson Aug 11, 2025
7e19ca4
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 11, 2025
c722f2c
refactor
LucasWilkinson Aug 11, 2025
1937615
remove accidentally committed file
LucasWilkinson Aug 11, 2025
ce9cc82
fix XPU tests
LucasWilkinson Aug 11, 2025
f3561f9
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 11, 2025
9d6b189
fix xpu
LucasWilkinson Aug 11, 2025
3c4b532
match HPU cudagraph handling + down grade log
LucasWilkinson Aug 11, 2025
bed9576
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 12, 2025
19f7447
fix xpu
LucasWilkinson Aug 12, 2025
0122313
unit test fixes
LucasWilkinson Aug 12, 2025
3a2041b
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 14, 2025
641b10b
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 14, 2025
974c707
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
LucasWilkinson Aug 14, 2025
3805f3f
Apply suggestions from code review
LucasWilkinson Aug 15, 2025
f2c437a
review comments
LucasWilkinson Aug 15, 2025
1ff41d8
Update vllm/v1/worker/gpu_model_runner.py
LucasWilkinson Aug 15, 2025
af2a38c
Merge remote-tracking branch 'origin/main' into full_cudagraph_FA2_Fl…
fhl2000 Aug 15, 2025
f751e50
Merge branch 'main' into full_cudagraph_FA2_FlashInfer
fhl2000 Aug 15, 2025
253 changes: 125 additions & 128 deletions tests/compile/piecewise/test_full_cudagraph.py
@@ -3,7 +3,8 @@
import contextlib
import os
import weakref
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Optional

import pytest

@@ -32,69 +33,133 @@ def temporary_environ(env_vars):
os.environ[k] = v


@pytest.fixture(scope="class")
def llm_pair(request):
model = request.param

with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
full = LLM(
model=model,
gpu_memory_utilization=0.45,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(full_cuda_graph=True),
)
piecewise = LLM(
model=model,
gpu_memory_utilization=0.45,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(),
)

# PyTest caches the fixture values so we use weakref.proxy to enable GC
yield weakref.proxy(full), weakref.proxy(piecewise)
del full
del piecewise

wait_for_gpu_memory_to_clear(
devices=[0],
threshold_ratio=0.1,
)


@pytest.fixture(scope="class")
def cutlass_mla_llm_pair(request):
model = request.param

# force V1 engine and Cutlass MLA backend
with temporary_environ({
@dataclass
class BackendConfig:
name: str
env_vars: dict
comp_config: dict
specific_gpu_arch: Optional[tuple] = None


# Define all backend configurations of full cudagraph to be tested
backend_configs = {
# FA3 on Hopper
"FA3":
BackendConfig(name="FA3",
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
comp_config={
"cudagraph_mode": "FULL",
},
specific_gpu_arch=(9, 0)),
# FlashMLA on Hopper
"FlashMLA":
BackendConfig(name="FlashMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# Cutlass MLA on Blackwell
"CutlassMLA":
BackendConfig(
name="CutlassMLA",
env_vars={
"VLLM_USE_V1": "1",
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
"FORCE_NUM_KV_SPLITS":
"1", # TODO: remove this when hang issue is fixed
}):
},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
"cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
},
specific_gpu_arch=(10, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
comp_config={
"cudagraph_mode": "FULL",
}),
# Triton Attention
"TritonAttn":
BackendConfig(name="TritonAttn",
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
comp_config={
"cudagraph_mode": "FULL",
}),
# FlashInfer
"FlashInfer":
BackendConfig(name="FlashInfer",
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
comp_config={
"cudagraph_mode": "FULL_AND_PIECEWISE",
}),
}

test_params_full_cudagraph = []

# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
test_params_full_cudagraph.append(
pytest.param(
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))

# Qwen/Qwen2-1.5B-Instruct with other backends
other_backend_configs = [
backend_configs[c] for c in backend_configs if c not in MLA_backends
]
for backend_config in other_backend_configs:
test_params_full_cudagraph.append(
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))


@pytest.fixture(scope="class")
def llm_pair(request):
model, backend_config = request.param

# Dynamically skip test if GPU capability is not met
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
!= current_platform.get_device_capability():
if backend_config.specific_gpu_arch == (9, 0):
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
elif backend_config.specific_gpu_arch == (10, 0):
pytest.skip("Only Blackwell GPUs support Cutlass MLA")

env_vars = {
"VLLM_USE_V1": "1",
# Force native sampler to avoid potential nondeterminism in FlashInfer
# when per-request generators are not used in V1.
"VLLM_USE_FLASHINFER_SAMPLER": "0",
**backend_config.env_vars,
}
with temporary_environ(env_vars):
full = LLM(
model=model,
gpu_memory_utilization=0.45,
gpu_memory_utilization=0.43,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(
full_cuda_graph=True,
cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512],
),
max_num_seqs=128,
compilation_config=\
CompilationConfig(**backend_config.comp_config),
generation_config="vllm",
seed=42,
)
piecewise = LLM(
model=model,
gpu_memory_utilization=0.45,
gpu_memory_utilization=0.43,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(),
max_num_seqs=128,
compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
generation_config="vllm",
seed=42,
)

# PyTest caches the fixture values so we use weakref.proxy to enable GC
yield weakref.proxy(full), weakref.proxy(piecewise)
del full
del piecewise
@@ -105,51 +170,7 @@ def cutlass_mla_llm_pair(request):
)


@pytest.mark.parametrize(
"cutlass_mla_llm_pair",
[
# use an MLA model
"deepseek-ai/DeepSeek-V2-Lite",
],
indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
reason="Only Blackwell GPUs support Cutlass MLA")
class TestFullCUDAGraphCutlassMLA:
"""
Validate full CUDA Graph with Cutlass MLA (decode-only capture).
"""

@pytest.mark.parametrize(("batch_size", "max_tokens"), [
(8, 8),
])
def test_full_cudagraph_sm100_cutlass_mla(
self, batch_size, max_tokens, cutlass_mla_llm_pair: tuple[LLM,
LLM]):
piecewise_llm, full_cudagraph_llm = cutlass_mla_llm_pair

prompts = ["Hello, my name is"] * batch_size
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)

piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)

for piecewise_res, full_res in zip(piecewise_responses,
full_responses):
assert piecewise_res.outputs[0].text == full_res.outputs[0].text


@pytest.mark.parametrize(
"llm_pair",
[
# Model names for the llm_pair fixture
"deepseek-ai/DeepSeek-V2-Lite",
"Qwen/Qwen2-1.5B-Instruct"
],
indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FA3 and FlashMLA")
@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
class TestFullCUDAGraph:
"""
Use a class such that an llm pair is constructed once for all
@@ -178,55 +199,31 @@ def test_full_cudagraph(self, batch_size, max_tokens,
full cudagraph compilation works for padded cases too.
"""

piecewise_llm, full_cudagraph_llm = llm_pair
full_cudagraph_llm, piecewise_llm = llm_pair

prompts = ["Hello, my name is"] * batch_size
prompts = ["the quick brown fox"] * batch_size
# Use purely greedy decoding to avoid top-p truncation sensitivity
# that can amplify tiny numeric differences across runtimes.
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)
top_p=1.0)

piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)

# Check that all responses are the same
for piecewise_res, full_res in zip(piecewise_responses,
full_responses):
assert piecewise_res.outputs[0].text == full_res.outputs[0].text


@pytest.mark.parametrize(
"model, supported",
[
("Qwen/Qwen2-1.5B-Instruct", True),
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
("deepseek-ai/DeepSeek-V2-Lite", False),
])
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
reason="Only Hopper GPUs support FA3 and FlashMLA")
def test_lower_max_num_seqs(model, supported):
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}), ExitStack() as stack:
if not supported:
stack.enter_context(pytest.raises(RuntimeError))

llm = LLM(model=model,
max_num_seqs=256,
trust_remote_code=True,
max_model_len=1024,
compilation_config=CompilationConfig(
full_cuda_graph=True,
cudagraph_capture_sizes=[64, 256, 512]))
llm.generate(["Hello, my name is"] * 10)
assert piecewise_res.outputs[0].text.lower() == \
full_res.outputs[0].text.lower()


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION":
"2" #FA2 not supported with full_cuda_graph
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
# Flex_Attention is not supported with full cuda graph
}), pytest.raises(RuntimeError):
LLM(model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(full_cuda_graph=True))
compilation_config=CompilationConfig(cudagraph_mode="FULL"))
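
For orientation, here is a minimal sketch of the usage pattern these tests exercise: pick an attention backend via environment variables and request a CUDA graph mode through CompilationConfig(cudagraph_mode=...). It is assembled from the test code above; the model name, memory settings, and prompt are illustrative assumptions, not something this PR prescribes.

# Minimal sketch (assumptions: FlashAttention v2 selected via env var and the
# small Qwen model used by the non-MLA test parameters above).
import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

os.environ["VLLM_FLASH_ATTN_VERSION"] = "2"  # choose the FA2 backend

llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",
    max_model_len=1024,
    # "FULL" captures the whole forward pass in one CUDA graph; the tests also
    # exercise "FULL_AND_PIECEWISE" and plain "PIECEWISE".
    compilation_config=CompilationConfig(cudagraph_mode="FULL"),
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=8, top_p=1.0)
outputs = llm.generate(["the quick brown fox"] * 4, sampling_params)
print(outputs[0].outputs[0].text)
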
33 changes: 25 additions & 8 deletions tests/compile/piecewise/test_simple.py
@@ -11,10 +11,10 @@

from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
set_current_vllm_config)
from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
VllmConfig, set_current_vllm_config)
from vllm.envs import VLLM_USE_V1
from vllm.forward_context import set_forward_context
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils import direct_register_custom_op

global_counter = 0
@@ -101,16 +101,33 @@ def test_simple_piecewise_compile(use_inductor):
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
), set_forward_context({}, vllm_config=vllm_config):

), set_forward_context(None,
vllm_config=vllm_config): # background context
# warm up with background context
model(inputs)

model(torch.randn(2).cuda())
model(torch.randn(1).cuda())
# capturing/replaying should be done under the context of cudagraph dispatching
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
model(torch.randn(2).cuda())
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=1, )):
model(torch.randn(1).cuda())

input = torch.zeros(2).cuda()
global global_counter
global_counter = 0
output = model(input)
with set_forward_context(
None,
vllm_config=vllm_config,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
batch_descriptor=BatchDescriptor(num_tokens=2, )):
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
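
A condensed sketch of the forward-context dispatching pattern that the test_simple.py changes demonstrate: warm-up runs use a plain background context, while piecewise capture and replay must name the CUDA graph runtime mode and describe the batch. The model, config, and input tensors are assumed to exist already (as in the test setup); only the context-manager usage is taken from the diff.

# Sketch of the dispatching contract from the diff above. Assumed setup:
# `model`, `vllm_config`, and `inputs` are already constructed as in the test.
import torch

from vllm.config import CUDAGraphMode
from vllm.forward_context import BatchDescriptor, set_forward_context

# Warm-up: plain background context, no cudagraph runtime mode.
with set_forward_context(None, vllm_config=vllm_config):
    model(inputs)

# Capture/replay: the context carries the runtime mode and a batch descriptor
# so dispatching can select the graph captured for that batch shape.
for num_tokens in (2, 1):
    with set_forward_context(
            None,
            vllm_config=vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
            batch_descriptor=BatchDescriptor(num_tokens=num_tokens)):
        model(torch.randn(num_tokens).cuda())
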