44Test the piecewise compilation with a simple model so that we 
55can exactly calculate the expected output and side effects. 
66""" 
7+ 
78import  pytest 
89import  torch 
910from  torch  import  nn 
10- from  torch .library  import  Library 
1111
1212from  vllm .compilation .counter  import  compilation_counter 
1313from  vllm .compilation .decorators  import  support_torch_compile 
1414from  vllm .config  import  (CompilationConfig , CompilationLevel , CUDAGraphMode ,
1515                         VllmConfig , set_current_vllm_config )
1616from  vllm .envs  import  VLLM_USE_V1 
1717from  vllm .forward_context  import  BatchDescriptor , set_forward_context 
18- from  vllm .utils  import  direct_register_custom_op 
19- 
20- global_counter  =  0 
21- 
22- # create a library to hold the custom op 
23- silly_lib  =  Library ("silly" , "FRAGMENT" )  # noqa 
24- 
25- 
26- def  silly_attention (q : torch .Tensor , k : torch .Tensor , v : torch .Tensor ,
27-                     out : torch .Tensor ) ->  None :
28-     global  global_counter 
29-     global_counter  +=  1 
30-     print (f"{ global_counter = }  )
31-     out .copy_ (q )
32-     out [0 ] +=  1 
33- 
34- 
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    """Fake (meta) implementation for tracing: no computation, no mutation."""
    pass
38- 
3918
# Register `torch.ops.silly.attention` in the fragment library above.
# `mutates_args=["out"]` declares that the op writes `out` in place;
# `fake_impl` supplies the no-op meta implementation used during
# tracing/compilation so no real computation happens at trace time.
direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    target_lib=silly_lib,
)
19+ # This import automatically registers `torch.ops.silly.attention` 
20+ from  ..silly_attention  import  get_global_counter , reset_global_counter 
4721
4822
4923@support_torch_compile  
@@ -59,8 +33,7 @@ def __init__(self,
5933    def  forward (self , x : torch .Tensor ) ->  torch .Tensor :
6034        """ 
6135        Overall effect: 
62-         x += 1 
63-         x[0] += 2 
36+         x = 3 * x + 19 
6437        global_counter += 2 
6538        """ 
6639        x  =  x  +  1 
@@ -78,6 +51,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
7851
7952
8053@pytest .mark .parametrize ("use_inductor" , [True , False ]) 
54+ @torch .inference_mode () 
8155def  test_simple_piecewise_compile (use_inductor ):
8256    assert  VLLM_USE_V1 
8357
@@ -121,13 +95,12 @@ def test_simple_piecewise_compile(use_inductor):
12195            model (torch .randn (1 ).cuda ())
12296
12397        input  =  torch .zeros (2 ).cuda ()
124-         global  global_counter 
125-         global_counter  =  0 
98+         reset_global_counter ()
12699        with  set_forward_context (
127100                None ,
128101                vllm_config = vllm_config ,
129102                cudagraph_runtime_mode = CUDAGraphMode .PIECEWISE ,
130103                batch_descriptor = BatchDescriptor (num_tokens = 2 , )):
131104            output  =  model (input )
132-         assert  global_counter  ==  2 
133-         assert  torch .allclose (output .cpu (), torch .tensor ([3. ,  1. ]))
105+         assert  get_global_counter ()  ==  2 
106+         assert  torch .allclose (output .cpu (), torch .tensor ([19.0 ,  19.0 ]))
0 commit comments