11# SPDX-License-Identifier: Apache-2.0 
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
3+ from  typing  import  Optional 
34
45import  pytest 
56import  torch 
@@ -34,15 +35,15 @@ class Relu3(ReLUSquaredActivation):
3435    [ 
3536        # Default values based on compile level  
3637        # - All by default (no Inductor compilation)  
37-         ("" , 0 , False , [True ] *  4 , True ), 
38-         ("" , 1 , True , [True ] *  4 , True ), 
39-         ("" , 2 , False , [True ] *  4 , True ), 
38+         (None , 0 , False , [True ] *  4 , True ), 
39+         (None , 1 , True , [True ] *  4 , True ), 
40+         (None , 2 , False , [True ] *  4 , True ), 
4041        # - None by default (with Inductor)  
41-         ("" , 3 , True , [False ] *  4 , False ), 
42-         ("" , 4 , True , [False ] *  4 , False ), 
42+         (None , 3 , True , [False ] *  4 , False ), 
43+         (None , 4 , True , [False ] *  4 , False ), 
4344        # - All by default (without Inductor)  
44-         ("" , 3 , False , [True ] *  4 , True ), 
45-         ("" , 4 , False , [True ] *  4 , True ), 
45+         (None , 3 , False , [True ] *  4 , True ), 
46+         (None , 4 , False , [True ] *  4 , True ), 
4647        # Explicitly enabling/disabling  
4748        #  
4849        # Default: all  
@@ -54,7 +55,7 @@ class Relu3(ReLUSquaredActivation):
5455        # All but SiluAndMul  
5556        ("all,-silu_and_mul" , 2 , True , [1 , 0 , 1 , 1 ], True ), 
5657        # All but ReLU3 (even if ReLU2 is on)  
57-         ("-relu3,relu2" , 3 , False , [1 , 1 , 1 , 0 ], True ), 
58+         ("-relu3,+ relu2" , 3 , False , [1 , 1 , 1 , 0 ], True ), 
5859        # RMSNorm and SiluAndMul  
5960        ("none,-relu3,+rms_norm,+silu_and_mul" , 4 , False , [1 , 1 , 0 , 0 ], False ), 
6061        # All but RMSNorm  
@@ -67,12 +68,13 @@ class Relu3(ReLUSquaredActivation):
6768        # All but RMSNorm  
6869        ("all,-rms_norm" , 4 , True , [0 , 1 , 1 , 1 ], True ), 
6970    ]) 
70- def  test_enabled_ops (env : str , torch_level : int , use_inductor : bool ,
71+ def  test_enabled_ops (env : Optional [ str ] , torch_level : int , use_inductor : bool ,
7172                     ops_enabled : list [int ], default_on : bool ):
73+     custom_ops  =  env .split (',' ) if  env  else  []
7274    vllm_config  =  VllmConfig (
7375        compilation_config = CompilationConfig (use_inductor = bool (use_inductor ),
7476                                             level = torch_level ,
75-                                              custom_ops = env . split ( "," ) ))
77+                                              custom_ops = custom_ops ))
7678    with  set_current_vllm_config (vllm_config ):
7779        assert  CustomOp .default_on () ==  default_on 
7880
0 commit comments