
Commit a775c5c

ProExpertProg and gjc0824 authored and committed
[torch.compile] Cleanup compilation tests and custom passes, add debug utils, fix DCE bug (vllm-project#23091), fix test (vllm-project#24376), and prep for custom op matching (vllm-project#24604) (vllm-project#24542)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: luka <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: gaojc <1055866782@qq.com>
1 parent 7db1caf · commit a775c5c

24 files changed: +407 −464 lines

tests/compile/backend.py

Lines changed: 27 additions & 1 deletion
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import weakref
 from collections.abc import Sequence
 from copy import deepcopy
 from typing import Callable, Union
@@ -10,7 +11,26 @@

 from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.inductor_pass import InductorPass
-from vllm.config import get_current_vllm_config
+from vllm.compilation.pass_manager import with_pattern_match_debug
+from vllm.compilation.vllm_inductor_pass import VllmInductorPass
+from vllm.config import VllmConfig, get_current_vllm_config
+
+
+class LazyInitPass(InductorPass):
+    """
+    If there's a pass that we want to initialize lazily in a test,
+    we can wrap it in LazyInitPass, which will initialize the pass when invoked
+    and then immediately invoke it.
+    """
+
+    def __init__(self, pass_cls: type[VllmInductorPass],
+                 vllm_config: VllmConfig):
+        self.pass_cls = pass_cls
+        self.vllm_config = weakref.proxy(vllm_config)  # avoid cycle
+
+    def __call__(self, graph: fx.Graph) -> None:
+        self.pass_ = self.pass_cls(self.vllm_config)
+        self.pass_(graph)


 class TestBackend:
@@ -40,10 +60,16 @@ def __call__(self, graph: fx.GraphModule, example_inputs):
                           example_inputs,
                           config_patches=self.inductor_config)

+    @with_pattern_match_debug
     def post_pass(self, graph: fx.Graph):
         self.graph_pre_pass = deepcopy(graph)
+
+        VllmInductorPass.dump_prefix = 0
         for pass_ in self.custom_passes:
             pass_(graph)
+            VllmInductorPass.dump_prefix += 1
+
+        VllmInductorPass.dump_prefix = None

         self.graph_post_pass = deepcopy(graph)
         # assign by reference, will reflect the final state of the graph

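Note (not part of the diff): a minimal sketch of how a test could use the new LazyInitPass wrapper together with TestBackend. The class and import names come from the diff above; the trivial model, the bare VllmConfig(), and the choice of NoOpEliminationPass as the wrapped pass are assumptions made purely for illustration.

    import torch
    from torch import nn

    from tests.compile.backend import LazyInitPass, TestBackend
    from vllm.compilation.noop_elimination import NoOpEliminationPass
    from vllm.config import VllmConfig, set_current_vllm_config


    class TinyModel(nn.Module):

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # reshape round-trip plus an add, enough to produce an FX graph
            return x.reshape(-1).reshape(x.shape) + 1


    vllm_config = VllmConfig()  # assumed: defaults are enough for this pass

    with set_current_vllm_config(vllm_config):
        # LazyInitPass defers constructing the wrapped pass until the backend
        # first invokes it on a graph, which is useful when the pass needs
        # state that only exists at compile time (e.g. registered attn layers).
        lazy_pass = LazyInitPass(NoOpEliminationPass, vllm_config)
        backend = TestBackend(lazy_pass)

        compiled = torch.compile(TinyModel(), backend=backend, fullgraph=True)
        compiled(torch.randn(8, 16))

        # after the first call, the concrete pass instance is exposed as .pass_
        assert isinstance(lazy_pass.pass_, NoOpEliminationPass)

With this change, TestBackend.post_pass also numbers each custom pass via VllmInductorPass.dump_prefix and runs under with_pattern_match_debug, so per-pass graph dumps can be correlated with pattern-match debug output.
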
tests/compile/test_async_tp.py

Lines changed: 2 additions & 0 deletions
@@ -294,6 +294,8 @@ def async_tp_pass_on_test_model(local_rank: int, world_size: int,
     compiled_model = torch.compile(model, backend=backend)
     compiled_model(hidden_states)

+    assert async_tp_pass.matched_count == 1
+
     # In pre-nodes, all gather or reduce scatter should exist,
     # fused_matmul_reduce_scatter or fused_all_gather_matmul should not
     backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)

tests/compile/test_config.py

Lines changed: 9 additions & 1 deletion
@@ -4,7 +4,7 @@

 import vllm
 from vllm.compilation.counter import compilation_counter
-from vllm.config import VllmConfig
+from vllm.config import CompilationConfig, VllmConfig
 from vllm.utils import _is_torch_equal_or_newer


@@ -26,6 +26,14 @@ def test_use_cudagraphs_dynamic(monkeypatch):
     assert not vllm_config.compilation_config.use_cudagraph


+def test_custom_op():
+    # proper syntax
+    _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])
+
+    with pytest.raises(ValueError, match="Invalid syntax '"):
+        _ = CompilationConfig(custom_ops=["quant_fp8"])
+
+
 # forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
 @pytest.mark.forked
 # NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
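
Note (not part of the diff): test_custom_op above pins down the custom_ops syntax — each entry must carry a '+' prefix (enable the custom implementation of an op) or a '-' prefix (disable it), and a bare op name is now rejected. A small sketch of what the test exercises:

    from vllm.config import CompilationConfig

    # explicit prefixes: enable custom quant_fp8, disable custom silu_and_mul
    cfg = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

    # a bare name raises, matching the ValueError checked in the test
    try:
        CompilationConfig(custom_ops=["quant_fp8"])
    except ValueError as err:
        print(f"rejected as expected: {err}")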

tests/compile/test_functionalization.py

Lines changed: 6 additions & 4 deletions
@@ -8,9 +8,10 @@
 from vllm import LLM, SamplingParams
 from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import FUSED_OPS, FusionPass
+from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import CompilationConfig, PassConfig, VllmConfig
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
@@ -58,11 +59,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     vllm_config.compilation_config = CompilationConfig(
         pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
     noop_pass = NoOpEliminationPass(vllm_config)
-    fusion_pass = FusionPass.instance(vllm_config)
+    fusion_pass = RMSNormQuantFusionPass(vllm_config)
+    cleanup_pass = PostCleanupPass(vllm_config)
     act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

-    passes = [noop_pass, fusion_pass, act_quant_fusion_pass
-              ] if do_fusion else [noop_pass]
+    passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
+              ] if do_fusion else [noop_pass, cleanup_pass]
     func_pass = FixFunctionalizationPass(vllm_config)
     backend_func = TestBackend(*passes, func_pass)
     backend_no_func = TestBackend(*passes)

tests/compile/test_fusion.py

Lines changed: 10 additions & 7 deletions
@@ -4,11 +4,11 @@
 import pytest
 import torch

-import vllm.envs as envs
 import vllm.plugins
 from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
-                                     FusionPass)
+                                     RMSNormQuantFusionPass)
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
                          VllmConfig)
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -79,15 +79,15 @@ def ops_in_model_after(self):


 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
-@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
+@pytest.mark.parametrize("hidden_size", [64])
+@pytest.mark.parametrize("num_tokens", [257])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
 @pytest.mark.parametrize("static", [True, False])
 # cuda_force_torch used to test torch code path on platforms that
 # cutlass_fp8_supported() == True.
 @pytest.mark.parametrize("cuda_force_torch",
                          [True, False] if cutlass_fp8_supported() else [True])
-@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
+@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                     reason="Only test on CUDA and ROCm")
 def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
                               cuda_force_torch):
@@ -104,9 +104,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
     with vllm.config.set_current_vllm_config(vllm_config):
         # Reshape pass is needed for the fusion pass to work
         noop_pass = NoOpEliminationPass(vllm_config)
-        fusion_pass = FusionPass.instance(vllm_config)
+        fusion_pass = RMSNormQuantFusionPass(vllm_config)
+        cleanup_pass = PostCleanupPass(vllm_config)

-        backend = TestBackend(noop_pass, fusion_pass)
+        backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
         model = TestModel(hidden_size, eps, static, cuda_force_torch)

         # First dimension dynamic
@@ -128,6 +129,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,

         torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

+        assert fusion_pass.matched_count == 2
+
         # In pre-nodes, fp8 quant should be there and fused kernels should not
         backend.check_before_ops(model.ops_in_model_before())

tests/compile/test_fusion_all_reduce.py

Lines changed: 5 additions & 1 deletion
@@ -9,6 +9,7 @@
 from vllm.compilation.collective_fusion import AllReduceFusionPass
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
                          ModelConfig, PassConfig, VllmConfig)
 from vllm.distributed import tensor_model_parallel_all_reduce
@@ -215,8 +216,10 @@ all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
     all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
     noop_pass = NoOpEliminationPass(vllm_config)
     func_pass = FixFunctionalizationPass(vllm_config)
+    cleanup_pass = PostCleanupPass(vllm_config)

-    backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
+    backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
+                          cleanup_pass)

     token_num = batch_size * seq_len
     model = test_model_cls(hidden_size, token_num)
@@ -227,6 +230,7 @@ all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
     compiled_model = torch.compile(model, backend=backend)
     compiled_model(hidden_states, residual)

+    assert all_reduce_fusion_pass.matched_count == 1
     backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
     backend.check_after_ops(model.ops_in_model_after())
     del all_reduce_fusion_pass

tests/compile/test_fusion_attn.py

Lines changed: 13 additions & 7 deletions
@@ -6,18 +6,19 @@
 import pytest
 import torch._dynamo

-from tests.compile.backend import TestBackend
+from tests.compile.backend import LazyInitPass, TestBackend
 from tests.models.utils import check_outputs_equal
 from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_common_attn_metadata)
 from vllm import LLM, SamplingParams
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
 from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
                          ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
                          set_current_vllm_config)
@@ -104,7 +105,7 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,

     # AttnFusionPass needs attention layers to be registered in config upon init
     # so we initialize it during compilation.
-    attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
+    attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
     backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
     llm2 = LLM(model,
                enforce_eager=True,
@@ -197,7 +198,8 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
             device=self.device,
         )

-    def build_attn_metadata(self, batch_size: int, use_hnd: bool):
+    def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
+        -> AttentionMetadata:
         """Initialize attention metadata."""

         # Create common attn metadata
@@ -447,9 +449,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,

         # Create test backend with fusion passes enabled
         noop_pass = NoOpEliminationPass(vllm_config)
-        attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw
-                                                                    )
-        test_backend = TestBackend(noop_pass, attn_pass)
+        attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
+        cleanup_pass = PostCleanupPass(vllm_config)
+
+        test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)

         # Compile model with fusion enabled
         model_compiled = torch.compile(model_fused,
@@ -485,6 +488,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
         test_backend.check_before_ops([QUANT_OPS[quant_key]],
                                       fully_replaced=True)

+        # access the underlying `AttnFusionPass` on the `LazyInitPass`
+        assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
+
         # Check attention ops in the graph before and after fusion
         attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
         attn_nodes_post = list(find_op_nodes(ATTN_OP,

tests/compile/test_sequence_parallelism.py

Lines changed: 14 additions & 7 deletions
@@ -6,10 +6,12 @@

 import vllm.envs as envs
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import FusionPass
+from vllm.compilation.fusion import RMSNormQuantFusionPass
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.compilation.sequence_parallelism import SequenceParallelismPass
+from vllm.compilation.vllm_inductor_pass import VllmInductorPass
 from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
                          PassConfig, VllmConfig)
 from vllm.distributed import tensor_model_parallel_all_reduce
@@ -104,7 +106,7 @@ def __init__(self,
         # Initialize weights
         torch.nn.init.normal_(self.gate_proj, std=0.02)

-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
+        self.fp8_linear = Fp8LinearOp(act_quant_static=True)

         self.scale = torch.rand(1, dtype=torch.float32)
         # Create a weight that is compatible with torch._scaled_mm,
@@ -137,8 +139,7 @@ def forward(self, hidden_states, residual):
         # layer normalization
         norm_output, residual_output = self.norm(all_reduce, residual)

-        # for static input quantization
-        # self.fp8_linear is initialized with use_per_token_if_dynamic=False
+        # scaled_mm with static input quantization
         fp8_linear_result = self.fp8_linear.apply(norm_output,
                                                   self.w,
                                                   self.wscale,
@@ -253,16 +254,20 @@ def sequence_parallelism_pass_on_test_model(
                              dtype=dtype,
                              seed=42)

-    sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
     noop_pass = NoOpEliminationPass(vllm_config)
+    sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
     func_pass = FixFunctionalizationPass(vllm_config)
+    cleanup_pass = PostCleanupPass(vllm_config)

-    passes_for_backend = [noop_pass, sequence_parallelism_pass]
+    passes_for_backend: list[VllmInductorPass] = \
+        [noop_pass, sequence_parallelism_pass]

     if enable_fusion:
-        fusion_pass = FusionPass.instance(vllm_config)
+        fusion_pass = RMSNormQuantFusionPass(vllm_config)
         passes_for_backend.append(fusion_pass)

+    passes_for_backend.append(cleanup_pass)
+
     backend_no_func = TestBackend(*passes_for_backend)
     backend_func = TestBackend(*passes_for_backend, func_pass)

@@ -279,6 +284,8 @@ def sequence_parallelism_pass_on_test_model(
     compiled_model_func = torch.compile(model, backend=backend_func)
     compiled_model_func(hidden_states, residual)

+    assert sequence_parallelism_pass.matched_count == 1
+
     # In pre-nodes, all reduce should be there,
     # reduce scatter and all gather should not
     backend_no_func.check_before_ops(model.ops_in_model_before())

tests/compile/test_silu_mul_quant_fusion.py

Lines changed: 12 additions & 1 deletion
@@ -15,6 +15,7 @@
 # yapf: enable
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import CompilationConfig, PassConfig, VllmConfig
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -69,6 +70,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):

     def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
         super().__init__()
+        from vllm.compilation.activation_quant_fusion import (
+            silu_and_mul_nvfp4_quant_supported)
+        assert silu_and_mul_nvfp4_quant_supported
+
         self.silu_and_mul = SiluAndMul()

         # create nvfp4 weight
@@ -127,7 +132,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
         pass_config=PassConfig(enable_fusion=True, enable_noop=True))
     fusion_pass = ActivationQuantFusionPass(config)

-    backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
+    passes = [
+        NoOpEliminationPass(config), fusion_pass,
+        PostCleanupPass(config)
+    ]
+    backend = TestBackend(*passes)
     model = model_class(hidden_size=hidden_size,
                         cuda_force_torch=cuda_force_torch,
                         x=x)
@@ -151,6 +160,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, dtype, model_class,
                                atol=atol,
                                rtol=rtol)

+    assert fusion_pass.matched_count == 1
+
     # In pre-nodes, quant op should be present and fused kernels should not
     backend.check_before_ops(model.ops_in_model_before())
156167

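Note (not part of the diff): taken together, these test updates converge on one backend pipeline shape — NoOpEliminationPass, then the fusion pass(es) under test, then the new PostCleanupPass, with FixFunctionalizationPass appended only in the functionalization variants — followed by an assertion on the pass's matched_count. A hedged sketch of that shape; the make_backend helper, the choice of RMSNormQuantFusionPass, and the model/inputs/expected count are illustrative, not part of the commit:

    from tests.compile.backend import TestBackend
    from vllm.compilation.fusion import RMSNormQuantFusionPass
    from vllm.compilation.noop_elimination import NoOpEliminationPass
    from vllm.compilation.post_cleanup import PostCleanupPass


    def make_backend(vllm_config, fusion_pass_cls=RMSNormQuantFusionPass):
        """Pipeline shape used by the tests above: noop -> fusion -> cleanup."""
        noop_pass = NoOpEliminationPass(vllm_config)
        fusion_pass = fusion_pass_cls(vllm_config)
        cleanup_pass = PostCleanupPass(vllm_config)
        return TestBackend(noop_pass, fusion_pass, cleanup_pass), fusion_pass


    # usage (model, inputs and expected_matches depend on the specific test):
    # backend, fusion_pass = make_backend(vllm_config)
    # compiled = torch.compile(model, backend=backend)
    # compiled(*inputs)
    # assert fusion_pass.matched_count == expected_matches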