@@ -6,18 +6,19 @@
 import pytest
 import torch._dynamo
 
-from tests.compile.backend import TestBackend
+from tests.compile.backend import LazyInitPass, TestBackend
 from tests.models.utils import check_outputs_equal
 from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_common_attn_metadata)
 from vllm import LLM, SamplingParams
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionMetadata
 from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
 from vllm.compilation.fx_utils import find_op_nodes
 from vllm.compilation.noop_elimination import NoOpEliminationPass
+from vllm.compilation.post_cleanup import PostCleanupPass
 from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
                          ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
                          set_current_vllm_config)
@@ -104,7 +105,7 @@ def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
 
     # AttnFusionPass needs attention layers to be registered in config upon init
     # so we initialize it during compilation.
-    attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
+    attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
     backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
     llm2 = LLM(model,
                enforce_eager=True,
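
For context, a minimal sketch of what the LazyInitPass helper used above might look like. This is inferred only from how the diff uses it (a constructor taking a pass class plus the vllm config, and a pass_ attribute holding the lazily built pass); the real implementation lives in tests/compile/backend.py and may differ:

# Hedged sketch of a lazy-init pass wrapper; everything beyond the
# LazyInitPass name, the constructor arguments, and the .pass_ attribute
# is an assumption, not the actual vllm implementation.
class LazyInitPass:

    def __init__(self, pass_cls, vllm_config):
        self.pass_cls = pass_cls
        self.vllm_config = vllm_config
        self.pass_ = None  # built on first call, once layers are registered

    def __call__(self, graph):
        # Defer construction until compilation, when the attention layers
        # have been registered in the config (see the comment above).
        if self.pass_ is None:
            self.pass_ = self.pass_cls(self.vllm_config)
        return self.pass_(graph)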
@@ -197,7 +198,8 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
             device=self.device,
         )
 
-    def build_attn_metadata(self, batch_size: int, use_hnd: bool):
+    def build_attn_metadata(self, batch_size: int, use_hnd: bool) \
+            -> AttentionMetadata:
         """Initialize attention metadata."""
 
         # Create common attn metadata
@@ -447,9 +449,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
 
         # Create test backend with fusion passes enabled
         noop_pass = NoOpEliminationPass(vllm_config)
-        attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw
-                                                                    )
-        test_backend = TestBackend(noop_pass, attn_pass)
+        attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
+        cleanup_pass = PostCleanupPass(vllm_config)
+
+        test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)
 
         # Compile model with fusion enabled
         model_compiled = torch.compile(model_fused,
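
For orientation, a hedged sketch of how a callable test backend like this plugs into torch.compile. The backend= keyword is standard PyTorch API; fullgraph=True and the exact arguments are illustrative assumptions, since the actual call is truncated by the hunk above:

# Any callable taking (fx.GraphModule, example_inputs) can serve as a
# torch.compile backend, letting TestBackend run its registered passes
# (noop elimination, attention fusion, post-cleanup) over the FX graph
# and snapshot it before and after for the checks below.
model_compiled = torch.compile(model_fused,
                               backend=test_backend,
                               fullgraph=True)  # fullgraph is an assumption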
@@ -485,6 +488,9 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
         test_backend.check_before_ops([QUANT_OPS[quant_key]],
                                       fully_replaced=True)
 
+    # access the underlying `AttnFusionPass` on the `LazyInitPass`
+    assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
+
     # Check attention ops in the graph before and after fusion
     attn_nodes_pre = list(find_op_nodes(ATTN_OP, test_backend.graph_pre_pass))
     attn_nodes_post = list(find_op_nodes(ATTN_OP,
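
One note on the new assert: it relies on attn_fusion_supported being a per-layer list of booleans, so its sum is the expected number of fused layers, and on LazyInitPass exposing the constructed pass via pass_. A tiny illustration with made-up values:

# Hypothetical per-layer flags: True sums as 1, so two of three layers
# supporting fusion means the pass should report matched_count == 2.
attn_fusion_supported = [True, False, True]
assert sum(attn_fusion_supported) == 2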