
[Bugfix] Re-enable use_cudagraph in vLLM v1 #19299


Merged (1 commit, merged on Jun 8, 2025)
2 changes: 1 addition & 1 deletion tests/compile/piecewise/test_simple.py
@@ -95,7 +95,7 @@ def _test_simple_piecewise_compile(*, use_inductor):
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
- num_cudagraph_caputured=
+ num_cudagraph_captured=
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):

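As a rough sanity check of the counts asserted above, the relations in the comments can be worked out directly. This is a sketch only; num_layers = 2 and num_cudagraph_sizes = 2 are inferred from the expected values and comments, not stated explicitly in this diff.

# Back-of-envelope check of the expected counters (values inferred, see note above)
num_layers = 2            # from num_piecewise_graphs_seen = 2 * num_layers + 1 = 5
num_cudagraph_sizes = 2   # from num_cudagraph_captured = sizes * capturable = 6
assert 2 * num_layers + 1 == 5                      # num_piecewise_graphs_seen
assert 1 + num_layers == 3                          # num_piecewise_capturable_graphs_seen
assert num_cudagraph_sizes * (1 + num_layers) == 6  # num_cudagraph_captured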
6 changes: 3 additions & 3 deletions tests/compile/piecewise/test_toy_llama.py
@@ -327,7 +327,7 @@ def _test_toy_llama(*, use_inductor):
num_piecewise_graphs_seen=0,
num_piecewise_capturable_graphs_seen=0,
num_backend_compilations=0,
- num_cudagraph_caputured=0,
+ num_cudagraph_captured=0,
):
outputs.append(
run_model(llama_config, use_inductor=False, use_compile=False))
@@ -343,7 +343,7 @@ def _test_toy_llama(*, use_inductor):
num_piecewise_graphs_seen=1,
num_piecewise_capturable_graphs_seen=1,
num_backend_compilations=1, # num_piecewise_capturable_graphs_seen
- num_cudagraph_caputured=
+ num_cudagraph_captured=
2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
**kwargs,
):
@@ -361,7 +361,7 @@ def _test_toy_llama(*, use_inductor):
llama_config.num_layers, # 1 + num_layers
num_backend_compilations=1 +
llama_config.num_layers, # num_piecewise_capturable_graphs_seen
- num_cudagraph_caputured=2 *
+ num_cudagraph_captured=2 *
(1 + llama_config.num_layers
), # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):
43 changes: 43 additions & 0 deletions tests/compile/test_config.py
@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

import vllm
from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                         set_current_vllm_config)

from .piecewise.test_simple import SillyModel


@pytest.fixture(scope="function", autouse=True)
def use_v1(monkeypatch):
    """
    TODO(rzou): The rest of tests/compile runs VLLM_USE_V1=0 right now,
    I'll switch them over later.
    """
    monkeypatch.setenv('VLLM_USE_V1', '1')


@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(enabled):
    assert vllm.envs.VLLM_USE_V1
    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_cudagraph=enabled,
        cudagraph_capture_sizes=[100],
    ))
    with set_current_vllm_config(vllm_config):
        model = SillyModel(vllm_config=vllm_config, prefix='')

    inputs = torch.randn(100, device="cuda")

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_cudagraph_captured=1 if enabled else 0,
    ):
        # first run is warmup
        model(inputs)
        # second run does CUDAGraphs recording (if enabled)
        model(inputs)
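For reference, a likely way to run just this new test locally is the usual pytest invocation from the repository root (this assumes a CUDA-capable GPU, since the test allocates tensors with device="cuda", and a vLLM source checkout with the test dependencies installed):

pytest tests/compile/test_config.py -k test_use_cudagraphs -v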
2 changes: 1 addition & 1 deletion vllm/compilation/counter.py
@@ -15,7 +15,7 @@ class CompilationCounter:
# not including the splitting ops
num_piecewise_capturable_graphs_seen: int = 0
num_backend_compilations: int = 0
- num_cudagraph_caputured: int = 0
+ num_cudagraph_captured: int = 0
# InductorAdapter.compile calls
num_inductor_compiles: int = 0
# EagerAdapter.compile calls
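The tests above rely on compilation_counter.expect(...) as a context manager that checks how much each counter changed inside the with-block. A minimal sketch of how such a delta-checking context manager can work is shown below; it is an illustration only, not the actual vLLM implementation, and the field list is trimmed down.

import copy
import dataclasses
from contextlib import contextmanager

@dataclasses.dataclass
class CounterSketch:
    num_graphs_seen: int = 0
    num_cudagraph_captured: int = 0

    @contextmanager
    def expect(self, **expected_deltas):
        # Snapshot every field, run the body, then assert the per-field deltas.
        before = copy.deepcopy(self)
        yield
        for name, delta in expected_deltas.items():
            actual = getattr(self, name) - getattr(before, name)
            assert actual == delta, f"{name}: expected +{delta}, got +{actual}"

counter = CounterSketch()
with counter.expect(num_cudagraph_captured=1):
    counter.num_cudagraph_captured += 1  # stands in for a real graph capture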
2 changes: 1 addition & 1 deletion vllm/compilation/cuda_piecewise_backend.py
@@ -193,7 +193,7 @@ def __call__(self, *args) -> Any:
entry.output = weak_ref_tensors(output)
entry.cudagraph = cudagraph

- compilation_counter.num_cudagraph_caputured += 1
+ compilation_counter.num_cudagraph_captured += 1

# important: we need to return the output, rather than
# the weak ref of the output, so that pytorch can correctly
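The piecewise backend increments this counter once per CUDA graph capture. For context, a minimal sketch of the general capture-then-replay pattern with torch.cuda.CUDAGraph follows; it is a simplified illustration of the technique (inputs must live at fixed addresses, as the config.py docstring below also notes), not the vLLM backend itself.

import torch

def make_graphed(fn, static_input):
    # Warm up on a side stream so one-time allocations are not captured.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        fn(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Capture: kernels launched inside this block are recorded, not executed.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = fn(static_input)
    # (This is the point where a counter like num_cudagraph_captured would tick.)

    def replay(values):
        static_input.copy_(values)  # reuse the fixed input buffer
        graph.replay()
        return static_output        # output buffer is reused across replays

    return replay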
5 changes: 3 additions & 2 deletions vllm/config.py
@@ -3918,12 +3918,14 @@ class CompilationConfig:
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""

# CudaGraph compilation
- use_cudagraph: bool = False
+ use_cudagraph: bool = envs.VLLM_USE_V1
Collaborator: I don't think this is the right default - this should probably depend on the compilation level.

@zou3519 (Collaborator, Author, Jun 7, 2025): False was the right default for V0.

For V1, we just override all compilation levels to be PIECEWISE, so True is the right default for VLLM_USE_V1. We force override -O1 and -O2 to be -O3. I'll fix this in a separate PR (I'm slowly fixing the TODO linked there), but this PR should have no behavior changes.

Also, the use_cudagraph flag only applies to PIECEWISE, otherwise it is ignored. We could figure out how to not ignore it for the other compilation levels, but that is separate work.

Collaborator: Wondering if we should give it a better name like piece_wise_use_cudagraph or not.

Collaborator: It can be done as a follow-up PR, let's merge it to fix the issue first.

Collaborator: I think this name is fine because this flag also controls full cudagraphs. But perhaps we can consolidate them.

"""Whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses, and all
splitting ops write their outputs to input buffers.
+ In the vLLM V1 Engine, this flag only applies for
+ CompilationLevel.PIECEWISE (aka -O3).
Note that this is orthogonal to the cudagraph capture logic
outside of compilation.
TODO: move outside cudagraph logic into compilation.
@@ -4425,7 +4427,6 @@ def __post_init__(self):
# FIXME(rob): Add function to set all of these.
if not self.compilation_config.custom_ops:
self.compilation_config.custom_ops = ["none"]
- self.compilation_config.use_cudagraph = True
self.compilation_config.cudagraph_num_of_warmups = 1
self.compilation_config.pass_config.enable_fusion = False
self.compilation_config.pass_config.enable_noop = False
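Taken together, the two config.py hunks mean that under V1 the cudagraph default is decided when the field is defined (from envs.VLLM_USE_V1) rather than being force-set in __post_init__, so an explicit user opt-out is respected. A small sketch of the resulting behavior; it assumes VLLM_USE_V1 is set in the environment before vllm.config is imported, since the field default is evaluated at import time.

import os
os.environ["VLLM_USE_V1"] = "1"  # must be set before vllm.config is imported

from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

# New default under V1: use_cudagraph is True unless explicitly disabled.
assert CompilationConfig().use_cudagraph

# With the __post_init__ override removed, an explicit opt-out now sticks.
cfg = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=False,
))
assert cfg.compilation_config.use_cudagraph is False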