import torch

import vllm.envs as envs
-from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
+from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
+    GroupShape)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    Fp8LinearOp)
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.platforms import current_platform

from .backend import TestBackend

-OPS_IN_MODEL = [
-    torch.ops._C.rotary_embedding.default,
-    torch.ops._C.fused_add_rms_norm.default,
-]
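+# The FP8 branches below (Fp8LinearOp, fused *_quant custom ops) are only
+# exercised on platforms that support FP8.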
+TEST_FP8 = current_platform.supports_fp8()
+FP8_DTYPE = current_platform.fp8_dtype()
+
+
+class TestSiluMul(torch.nn.Module):
+
+    def __init__(self, hidden_size: int = 128):
+        super().__init__()
+        self.silu_and_mul = SiluAndMul()
+        self.wscale = torch.rand(1, dtype=torch.float32)
+        self.scale = torch.rand(1, dtype=torch.float32)
+
+        if TEST_FP8:
+            self.w = torch.rand(hidden_size,
+                                hidden_size).to(dtype=FP8_DTYPE).t()
+            self.fp8_linear = Fp8LinearOp(
+                act_quant_static=True,
+                act_quant_group_shape=GroupShape.PER_TENSOR,
+            )
+
+    def forward(self, x):
+        y = self.silu_and_mul(x)
+        if TEST_FP8:
+            x2 = self.fp8_linear.apply(y,
+                                       self.w,
+                                       self.wscale,
+                                       input_scale=self.wscale)
+            return x2
+        else:
+            return y
+
+    def example_inputs(self, num_tokens=32, hidden_size=128):
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+        return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )
+
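+    # Custom ops expected to remain in the compiled graph: with fusion enabled
+    # on an FP8 platform, silu_and_mul is replaced by the fused
+    # silu_and_mul_quant op.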
+    def ops_in_model(self, do_fusion):
+        if TEST_FP8 and do_fusion:
+            return [torch.ops._C.silu_and_mul_quant.default]
+        else:
+            return [torch.ops._C.silu_and_mul.default]
+
+    def ops_not_in_model(self):
+        return []
+
+
+class TestFusedAddRMSNorm(torch.nn.Module):
+
+    def __init__(self, hidden_size=16, intermediate_size=32):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+
+        self.gate_proj = torch.nn.Parameter(
+            torch.empty((intermediate_size, hidden_size), dtype=dtype))
+        self.norm = RMSNorm(intermediate_size, 1e-05)
+        self.norm.weight = torch.nn.Parameter(
+            torch.ones(intermediate_size, dtype=dtype))
+
+        torch.nn.init.normal_(self.gate_proj, std=0.02)
+
+        if TEST_FP8:
+            self.fp8_linear = Fp8LinearOp(act_quant_static=True)
+
+            self.scale = torch.rand(1, dtype=torch.float32)
+            self.w = torch.rand(hidden_size,
+                                intermediate_size).to(dtype=FP8_DTYPE).t()
+            self.wscale = torch.rand(1, dtype=torch.float32)
+
+    def forward(self, hidden_states, residual):
+        # Reshape input
+        view = hidden_states.reshape(-1, self.hidden_size)
+
+        # matrix multiplication
+        permute = self.gate_proj.permute(1, 0)
+        mm = torch.mm(view, permute)
+
+        # layer normalization
+        norm_output, residual_output = self.norm(mm, residual)
+
+        if TEST_FP8:
+            # scaled_mm with static input quantization
+            fp8_linear_result = self.fp8_linear.apply(
+                norm_output,
+                self.w,
+                self.wscale,
+                input_scale=self.scale.to(norm_output.device),
+            )
+
+            return fp8_linear_result, residual_output
+
+        else:
+            return norm_output, residual_output
+
+    def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
+        dtype = torch.float16 if TEST_FP8 else torch.float32
+        hidden_states = torch.randn((batch_size * seq_len, hidden_size),
+                                    dtype=dtype)
+        residual = torch.randn((batch_size * seq_len, hidden_size),
+                               dtype=dtype)
+        return (hidden_states, residual)

-RMS_OP = torch.ops._C.rms_norm.default
+    def ops_in_model(self, do_fusion):
+        if TEST_FP8 and do_fusion:
+            return [torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
+        else:
+            return [torch.ops._C.fused_add_rms_norm.default]

-RMS_QUANT_OPS = {
-    "static_fp8": [
-        torch.ops._C.rms_norm_static_fp8_quant.default,
-        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
-    ],
-}
+    def ops_not_in_model(self):
+        return []

-SILU_MUL_OP = torch.ops._C.silu_and_mul.default

-SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
-prompts = [
-    "Hello, my name is",
-    "The president of the United States is",
-    "The capital of France is",
-    "The future of AI is",
+class TestRotaryEmbedding(torch.nn.Module):
+
+    def __init__(self,
+                 head_dim=64,
+                 rotary_dim=None,
+                 max_position=2048,
+                 base=10000):
+        super().__init__()
+        self.head_dim = head_dim
+        self.rotary_dim = rotary_dim or head_dim
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position,
+            base=base,
+        )
+
+    def forward(self, positions, q, k):
+        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
+        return q_rotated, k_rotated
+
+    def example_inputs(self, num_tokens=32, head_dim=64):
+        dtype = torch.float16
+        positions = torch.arange(num_tokens, dtype=torch.long)
+        q = torch.randn(num_tokens, head_dim, dtype=dtype)
+        k = torch.randn(num_tokens, head_dim, dtype=dtype)
+        return (positions, q, k)
+
+    def ops_in_model(self, do_fusion):
+        return [torch.ops._C.rotary_embedding.default]
+
+    def ops_not_in_model(self):
+        return []
+
+
+class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
+
+    def __init__(self,
+                 head_dim=64,
+                 num_heads=4,
+                 max_position=2048,
+                 base=10000):
+        super().__init__()
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+        self.hidden_size = head_dim * num_heads
+
+        self.qkv_proj = torch.nn.Linear(self.hidden_size,
+                                        self.hidden_size * 3,
+                                        bias=False,
+                                        dtype=torch.float16)
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=base,
+        )
+
+    def forward(self, positions, hidden_states):
+        # Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
+        # -> slice_scatter -> split_with_sizes
+
+        qkv = self.qkv_proj(hidden_states)
+        split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size]
+        q, k, v = torch.split(qkv, split_sizes, dim=-1)
+
+        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
+
+        qkv_updated = torch.cat([q_rotated, k_rotated, v], dim=-1)
+        return qkv_updated
+
+    def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
+        dtype = torch.float16
+        hidden_size = head_dim * num_heads
+        positions = torch.arange(num_tokens, dtype=torch.long)
+        hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
+        return (positions, hidden_states)
+
+    def ops_in_model(self, do_fusion):
+        return [torch.ops._C.rotary_embedding.default]
+
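+    # aten.slice_scatter is introduced when the in-place rotary update of the
+    # q/k slices is functionalized; it must not remain in the graph once
+    # FixFunctionalizationPass has run.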
+    def ops_not_in_model(self):
+        return [torch.ops.aten.slice_scatter.default]
+
+
+MODELS = [
+    TestSiluMul,
+    TestFusedAddRMSNorm,
+    TestRotaryEmbedding,
+    TestRotaryEmbeddingSliceScatter,
]


-@pytest.mark.parametrize(
-    "model, quant_key",
-    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
-     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
-      kFp8DynamicTokenSym)])
+@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
-def test_fix_functionalization(model: str, quant_key: QuantKey,
-                               do_fusion: bool):
+def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
    torch.set_default_device("cuda")

    vllm_config = VllmConfig()
@@ -63,56 +246,31 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
    cleanup_pass = PostCleanupPass(vllm_config)
    act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

-    passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
-              ] if do_fusion else [noop_pass, cleanup_pass]
+    passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
+              if do_fusion else [noop_pass, cleanup_pass])
    func_pass = FixFunctionalizationPass(vllm_config)
+
    backend_func = TestBackend(*passes, func_pass)
    backend_no_func = TestBackend(*passes)

-    # instantiate a full engine and manually compile the model 2x
-    # (with and without FixFunctionalizationPass)
-    llm = LLM(model=model, enforce_eager=True)
-    model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
-    orig_model = model_runner.model
-    # TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
-    # Can only do that by using the decorator but then we'd have to instantiate
-    # 2 LLM instances.
-
-    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
-    model_runner.model = torch.compile(orig_model,
-                                       fullgraph=True,
-                                       backend=backend_func)
-    gen_func = llm.generate(prompts, sampling_params)
-
-    model_runner.model = torch.compile(orig_model,
-                                       fullgraph=True,
-                                       backend=backend_no_func)
-
-    gen_no_func = llm.generate(prompts, sampling_params)
-
-    for output_func, output_no_func in zip(gen_func, gen_no_func):
-        assert output_func.outputs[0].text == output_no_func.outputs[0].text
-
-    # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
-    # and replaced by fused quantized ops in RMS_QUANT_OPS.
-    rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
-               ] if do_fusion else [RMS_OP]
-    silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
-        quant_key == kFp8StaticTensorSym else [
-        SILU_MUL_OP
-    ]
-
-    ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
-
-    for op in ops:
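+    # Compile the test module twice so the resulting graphs can be compared:
+    # backend_func applies FixFunctionalizationPass on top of the shared
+    # passes, backend_no_func applies only the shared passes.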
+    model = model_class()
+    torch.compile(model, backend=backend_func)(*model.example_inputs())
+    torch.compile(model, backend=backend_no_func)(*model.example_inputs())
+
+    # check if the functionalization pass is applied
+    for op in model.ops_in_model(do_fusion):
        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
-        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
-                                  op) is None  # noqa: E501
+        assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
+                is None)  # noqa: E501

    # make sure the ops were all de-functionalized
    found = dict()
    for node in backend_func.graph_post_pass.nodes:
-        for op in ops:
+        for op in model.ops_in_model(do_fusion):
+            if is_func(node, op):
+                found[op] = True
+        for op in model.ops_not_in_model():
            if is_func(node, op):
                found[op] = True
-    assert all(found[op] for op in ops)
+    assert all(found[op] for op in model.ops_in_model(do_fusion))
+    assert all(not found.get(op) for op in model.ops_not_in_model())