
Commit d4932d0

[CI] execute all piecewise compilation tests together (vllm-project#24502)

ZJY0516 authored and skyloevil committed
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>

1 parent a48b56a commit d4932d0

File tree: 6 files changed, +81 −117 lines

.buildkite/test-pipeline.yaml
tests/compile/piecewise/test_multiple_graphs.py
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_toy_llama.py
tests/compile/silly_attention.py
tests/compile/test_decorator.py

.buildkite/test-pipeline.yaml
Lines changed: 1 addition & 5 deletions

@@ -379,11 +379,7 @@ steps:
   - tests/compile
   commands:
   - pytest -v -s compile/test_basic_correctness.py
-  # these tests need to be separated, cannot combine
-  - pytest -v -s compile/piecewise/test_simple.py
-  - pytest -v -s compile/piecewise/test_toy_llama.py
-  - pytest -v -s compile/piecewise/test_full_cudagraph.py
-  - pytest -v -s compile/piecewise/test_multiple_graphs.py
+  - pytest -v -s compile/piecewise/
 
 - label: PyTorch Fullgraph Test # 20min
   timeout_in_minutes: 30
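
The deleted comment ("these tests need to be separated, cannot combine") reflects why the pipeline previously ran each piecewise test in its own pytest process: every file registered its own `silly.attention` custom op, and a second registration of the same op in one process fails. A minimal standalone sketch of that assumed failure mode (plain PyTorch, no vLLM involved):

# Assumed failure mode behind the old "cannot combine" comment: two test
# modules each defining silly.attention in the same process collide.
from torch.library import Library

lib_a = Library("silly", "FRAGMENT")
lib_a.define("attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out) -> ()")

# A second fragment in the same namespace may be created, but re-defining
# the same operator schema is rejected at registration time.
lib_b = Library("silly", "FRAGMENT")
try:
    lib_b.define("attention(Tensor q, Tensor k, Tensor v, Tensor(a!) out) -> ()")
except RuntimeError as exc:
    print(f"duplicate registration rejected: {exc}")

With the shared tests/compile/silly_attention.py module introduced below, registration happens exactly once on first import, so `pytest -v -s compile/piecewise/` can run every file in a single process.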

tests/compile/piecewise/test_multiple_graphs.py
Lines changed: 3 additions & 25 deletions

@@ -4,9 +4,9 @@
 Test (piecewise) compilation with a simple model where multiple submodules
 are compiled and graph captured separately.
 """
+
 import torch
 from torch import nn
-from torch.library import Library
 
 from vllm.compilation.backends import set_model_tag
 from vllm.compilation.counter import compilation_counter
@@ -15,38 +15,16 @@
 from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                          VllmConfig, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import direct_register_custom_op
 
-# create a library to hold the custom op
-silly_lib = Library("silly", "FRAGMENT")  # noqa
+# This import automatically registers `torch.ops.silly.attention`
+from .. import silly_attention  # noqa: F401
 
 BATCH_SIZE = 32
 MLP_SIZE = 128
 HIDDEN_SIZE = 1024
 RANDOM_SEED = 0
 
 
-def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                    out: torch.Tensor) -> None:
-    out.copy_(q)
-    out += k
-    out += v
-
-
-def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                         out: torch.Tensor) -> None:
-    return
-
-
-direct_register_custom_op(
-    op_name="attention",
-    op_func=silly_attention,
-    mutates_args=["out"],
-    fake_impl=silly_attention_fake,
-    target_lib=silly_lib,
-)
-
-
 @support_torch_compile
 class ParentModel(nn.Module):
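
Behavior is unchanged for this file: the deleted inline op wrote `q` into `out` and then accumulated `k` and `v` in place, which is exactly what the shared implementation's `out.copy_(q + k + v)` computes (the same holds for test_toy_llama.py and test_decorator.py below). A quick standalone equivalence check with plain tensors:

# Equivalence of the removed inline op and the shared silly_attention body.
import torch

q, k, v = torch.randn(3, 8)

out_old = torch.empty(8)
out_old.copy_(q)  # old: copy q ...
out_old += k      # ... then accumulate k
out_old += v      # ... and v

out_new = q + k + v  # shared module: out.copy_(q + k + v)
assert torch.allclose(out_old, out_new)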

tests/compile/piecewise/test_simple.py
Lines changed: 8 additions & 35 deletions

@@ -4,46 +4,20 @@
 Test the piecewise compilation with a simple model so that we
 can exactly calculate the expected output and side effects.
 """
+
 import pytest
 import torch
 from torch import nn
-from torch.library import Library
 
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                          VllmConfig, set_current_vllm_config)
 from vllm.envs import VLLM_USE_V1
 from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import direct_register_custom_op
-
-global_counter = 0
-
-# create a library to hold the custom op
-silly_lib = Library("silly", "FRAGMENT")  # noqa
-
-
-def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                    out: torch.Tensor) -> None:
-    global global_counter
-    global_counter += 1
-    print(f"{global_counter=}")
-    out.copy_(q)
-    out[0] += 1
-
-
-def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                         out: torch.Tensor) -> None:
-    return
-
 
-direct_register_custom_op(
-    op_name="attention",
-    op_func=silly_attention,
-    mutates_args=["out"],
-    fake_impl=silly_attention_fake,
-    target_lib=silly_lib,
-)
+# This import automatically registers `torch.ops.silly.attention`
+from ..silly_attention import get_global_counter, reset_global_counter
 
 
 @support_torch_compile
@@ -59,8 +33,7 @@ def __init__(self,
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Overall effect:
-        x += 1
-        x[0] += 2
+        x = 3 * x + 19
         global_counter += 2
         """
         x = x + 1
@@ -78,6 +51,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 @pytest.mark.parametrize("use_inductor", [True, False])
+@torch.inference_mode()
 def test_simple_piecewise_compile(use_inductor):
     assert VLLM_USE_V1
 
@@ -121,13 +95,12 @@ def test_simple_piecewise_compile(use_inductor):
     model(torch.randn(1).cuda())
 
     input = torch.zeros(2).cuda()
-    global global_counter
-    global_counter = 0
+    reset_global_counter()
     with set_forward_context(
             None,
            vllm_config=vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
            batch_descriptor=BatchDescriptor(num_tokens=2, )):
         output = model(input)
-    assert global_counter == 2
-    assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
+    assert get_global_counter() == 2
+    assert torch.allclose(output.cpu(), torch.tensor([19.0, 19.0]))
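
This is the one call site where semantics actually change: the old inline op bumped `out[0]` and printed a counter, while the shared op computes `q + k + v`, so the docstring and assertions are updated to match. Per the new docstring the compiled forward amounts to `x = 3 * x + 19`; with the all-zeros input the test feeds in, every element becomes 19, and the attention op still fires twice per forward pass, hence the counter assertion of 2. A standalone check of the arithmetic, assuming the docstring's stated overall effect:

# Sanity-check the updated expectation, assuming forward(x) == 3 * x + 19.
import torch

x = torch.zeros(2)       # same input the test uses (before .cuda())
expected = 3 * x + 19
assert torch.allclose(expected, torch.tensor([19.0, 19.0]))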

tests/compile/piecewise/test_toy_llama.py
Lines changed: 2 additions & 25 deletions

@@ -14,38 +14,15 @@
 import pytest
 import torch
 from torch import nn
-from torch.library import Library
 
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
                          VllmConfig, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import direct_register_custom_op
 
-# create a library to hold the custom op
-silly_lib = Library("silly", "FRAGMENT")  # noqa
-
-
-def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                    out: torch.Tensor) -> None:
-    out.copy_(q)
-    out += k
-    out += v
-
-
-def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                         out: torch.Tensor) -> None:
-    return
-
-
-direct_register_custom_op(
-    op_name="attention",
-    op_func=silly_attention,
-    mutates_args=["out"],
-    fake_impl=silly_attention_fake,
-    target_lib=silly_lib,
-)
+# This import automatically registers `torch.ops.silly.attention`
+from .. import silly_attention  # noqa: F401
 
 
 @dataclass

tests/compile/silly_attention.py (new file)
Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Shared PyTorch custom silly attention for compilation tests.
+Centralizes custom operation definitions to avoid duplicate registrations.
+"""
+
+import torch
+from torch.library import Library
+
+from vllm.utils import direct_register_custom_op
+
+# Shared library for all compilation test operations.
+# Uses the "silly" namespace to match existing test expectations;
+# importing this file automatically registers the torch ops used for
+# testing (like silly.attention).
+silly_lib = Library("silly", "FRAGMENT")
+
+# Global counter that counts the number of times attention is invoked
+_global_counter = 0
+
+
+def get_global_counter():
+    """Get the current global counter value"""
+    return _global_counter
+
+
+def reset_global_counter():
+    """Reset the global counter to 0"""
+    global _global_counter
+    _global_counter = 0
+
+
+def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                    out: torch.Tensor) -> None:
+    """
+    Unified attention implementation that depends on
+    all inputs and affects the output.
+    Always increments a global counter that tests can use or ignore.
+    """
+    global _global_counter
+
+    # Always increment the global counter
+    _global_counter += 1
+
+    # Unified implementation that depends on all inputs
+    out.copy_(q + k + v)
+
+
+def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+                         out: torch.Tensor) -> None:
+    """Fake implementation for testing"""
+    return
+
+
+# Register the unified attention operation
+direct_register_custom_op(
+    op_name="attention",
+    op_func=silly_attention,
+    mutates_args=["out"],
+    fake_impl=silly_attention_fake,
+    target_lib=silly_lib,
+)
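
A hedged sketch of how a sibling test consumes this module, mirroring the imports added above. The relative imports only resolve from within the tests/compile package, and the sketch assumes CUDA tensors, matching how the tests invoke the op after registration through vLLM's `direct_register_custom_op`:

# Hypothetical usage from a module under tests/compile/piecewise/; the
# import's side effect registers torch.ops.silly.attention exactly once.
import torch

from .. import silly_attention  # noqa: F401
from ..silly_attention import get_global_counter, reset_global_counter

reset_global_counter()
q = k = v = torch.ones(4, device="cuda")
out = torch.empty(4, device="cuda")
torch.ops.silly.attention(q, k, v, out)  # out <- q + k + v
assert torch.allclose(out, torch.full((4,), 3.0, device="cuda"))
assert get_global_counter() == 1         # one invocation recorded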

tests/compile/test_decorator.py
Lines changed: 4 additions & 27 deletions

@@ -2,44 +2,21 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import torch
 from torch import nn
-from torch.library import Library
 
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import (ignore_torch_compile,
                                          support_torch_compile)
 from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
                          CUDAGraphMode, VllmConfig, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor, set_forward_context
-from vllm.utils import direct_register_custom_op
 
-# create a library to hold the custom op
-silly_lib = Library("silly", "FRAGMENT")  # noqa
+# This import automatically registers `torch.ops.silly.attention`
+from . import silly_attention  # noqa: F401
 
 BATCH_SIZE = 32
 MLP_SIZE = 128
 
 
-def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                    out: torch.Tensor) -> None:
-    out.copy_(q)
-    out += k
-    out += v
-
-
-def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
-                         out: torch.Tensor) -> None:
-    return
-
-
-direct_register_custom_op(
-    op_name="attention",
-    op_func=silly_attention,
-    mutates_args=["out"],
-    fake_impl=silly_attention_fake,
-    target_lib=silly_lib,
-)
-
-
 @torch.inference_mode
 def run_model(vllm_config: VllmConfig, model: nn.Module,
               cudagraph_runtime_mode: CUDAGraphMode):
@@ -151,7 +128,7 @@ class C(B):
     run_model(vllm_config, mod_C, cudagraph_runtime_mode)
 
 
-# Only enable torch.compile if 
+# Only enable torch.compile if
 # vllm_config.cache_config.kv_sharing_fast_prefill=True
 @support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
                        kv_sharing_fast_prefill)
@@ -173,7 +150,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-# Only enable torch.compile if 
+# Only enable torch.compile if
 # vllm_config.cache_config.kv_sharing_fast_prefill=False
 @support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
                        cache_config.kv_sharing_fast_prefill)
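
The last two hunks appear to change only trailing whitespace on the comment lines (the removed and added text is otherwise identical); the surrounding `enable_if` context is untouched by this commit. For readers unfamiliar with the pattern those context lines show, here is a hypothetical, self-contained illustration of gating a class decorator on a config predicate (simplified stand-ins, not vLLM's real `support_torch_compile`):

# Hypothetical stand-ins for illustration only.
from dataclasses import dataclass
from typing import Callable


@dataclass
class CacheConfig:
    kv_sharing_fast_prefill: bool = False


@dataclass
class Config:  # stand-in for vllm.config.VllmConfig
    cache_config: CacheConfig


def compile_if(enable_if: Callable[[Config], bool]):
    """Attach a predicate deciding whether a class should be compiled."""
    def decorator(cls):
        cls.should_compile = staticmethod(enable_if)
        return cls
    return decorator


@compile_if(lambda cfg: cfg.cache_config.kv_sharing_fast_prefill)
class FastPrefillModel:
    pass


cfg = Config(cache_config=CacheConfig(kv_sharing_fast_prefill=True))
assert FastPrefillModel.should_compile(cfg)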
