import torch

import vllm.envs as envs
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform

from .backend import TestBackend

FP8_DTYPE = current_platform.fp8_dtype()
OPS_IN_MODEL = [
    torch.ops._C.rotary_embedding.default,
    torch.ops._C.fused_add_rms_norm.default,
]

RMS_QUANT_OPS = {
    "static_fp8": [
        torch.ops._C.rms_norm_static_fp8_quant.default,
        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
    ],
}

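# Each Test* module below follows the same informal interface, which the
# parametrized test at the bottom of this file relies on:
#   example_inputs()   -> sample inputs used to trace the model
#   ops_in_model()     -> custom ops that must remain in the final graph
#   ops_not_in_model() -> ops the compilation passes are expected to remove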
class TestSiluMul(torch.nn.Module):

    def __init__(self, quant=True, hidden_size: int = 128):
        super().__init__()
        self.quant = quant
        self.silu_and_mul = SiluAndMul()
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        if self.quant:
            self.w = torch.rand(hidden_size,
                                hidden_size).to(dtype=FP8_DTYPE).t()
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=True,
                act_quant_group_shape=GroupShape.PER_TENSOR,
            )

    def forward(self, x):
        y = self.silu_and_mul(x)
        if self.quant:
            # static input quantization uses the activation scale,
            # not the weight scale
            x2 = self.fp8_linear.apply(y,
                                       self.w,
                                       self.wscale,
                                       input_scale=self.scale)
            return x2
        else:
            return y

    def example_inputs(self, num_tokens=32, hidden_size=128):
        dtype = torch.float16 if self.quant else torch.float32
        return (torch.rand(num_tokens, hidden_size * 2, dtype=dtype), )

    def ops_in_model(self):
        return ([torch.ops._C.silu_and_mul_quant.default]
                if self.quant else [torch.ops._C.silu_and_mul.default])

    def ops_not_in_model(self):
        return []

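# Note: with do_fusion=True, ActivationQuantFusionPass is expected to rewrite
# SiluAndMul followed by the static FP8 input quantization of Fp8LinearOp
# into the single torch.ops._C.silu_and_mul_quant op checked above.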
class TestFusedAddRMSNorm(torch.nn.Module):

    def __init__(self, quant=True, hidden_size=16, intermediate_size=32):
        super().__init__()
        self.quant = quant
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        dtype = torch.float16 if self.quant else torch.float32

        self.gate_proj = torch.nn.Parameter(
            torch.empty((intermediate_size, hidden_size), dtype=dtype))
        self.norm = RMSNorm(intermediate_size, 1e-05)
        self.norm.weight = torch.nn.Parameter(
            torch.ones(intermediate_size, dtype=dtype))

        torch.nn.init.normal_(self.gate_proj, std=0.02)

        if self.quant:
            self.fp8_linear = Fp8LinearOp(act_quant_static=True)

            self.scale = torch.rand(1, dtype=torch.float32)
            self.w = torch.rand(hidden_size,
                                intermediate_size).to(dtype=FP8_DTYPE).t()
            self.wscale = torch.rand(1, dtype=torch.float32)

    def forward(self, hidden_states, residual):
        # reshape input
        view = hidden_states.reshape(-1, self.hidden_size)

        # matrix multiplication
        permute = self.gate_proj.permute(1, 0)
        mm = torch.mm(view, permute)

        # layer normalization
        norm_output, residual_output = self.norm(mm, residual)

        if self.quant:
            # scaled_mm with static input quantization
            fp8_linear_result = self.fp8_linear.apply(
                norm_output,
                self.w,
                self.wscale,
                input_scale=self.scale.to(norm_output.device),
            )

            return fp8_linear_result, residual_output

        else:
            return norm_output, residual_output

    def example_inputs(self, batch_size=8, hidden_size=16, seq_len=16):
        dtype = torch.float16 if self.quant else torch.float32
        hidden_states = torch.randn((batch_size * seq_len, hidden_size),
                                    dtype=dtype)
        # the residual is added after gate_proj, so it must match the
        # intermediate dimension of the mm output
        residual = torch.randn((batch_size * seq_len, self.intermediate_size),
                               dtype=dtype)
        return (hidden_states, residual)

    def ops_in_model(self):
        return ([torch.ops._C.fused_add_rms_norm_static_fp8_quant.default]
                if self.quant else [torch.ops._C.fused_add_rms_norm.default])

    def ops_not_in_model(self):
        return []

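# Note: with do_fusion=True, RMSNormQuantFusionPass is expected to collapse
# fused_add_rms_norm followed by static FP8 quantization into the
# fused_add_rms_norm_static_fp8_quant op listed in RMS_QUANT_OPS above.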
class TestRotaryEmbedding(torch.nn.Module):

    def __init__(
            self,
            quant=False,  # unused; kept so all models share a signature
            head_dim=64,
            rotary_dim=None,
            max_position=2048,
            base=10000):
        super().__init__()
        self.head_dim = head_dim
        self.rotary_dim = rotary_dim or head_dim

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_dim,
            max_position=max_position,
            base=base,
        )

    def forward(self, positions, q, k):
        q_rotated, k_rotated = self.rotary_emb(positions, q, k)
        return q_rotated, k_rotated

    def example_inputs(self, num_tokens=32, head_dim=64):
        dtype = torch.float16
        positions = torch.arange(num_tokens, dtype=torch.long)
        q = torch.randn(num_tokens, head_dim, dtype=dtype)
        k = torch.randn(num_tokens, head_dim, dtype=dtype)
        return (positions, q, k)

    def ops_in_model(self):
        return [torch.ops._C.rotary_embedding.default]

    def ops_not_in_model(self):
        return []

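# torch.ops._C.rotary_embedding mutates q and k in place; functionalized
# graphs wrap it in auto_functionalized, and FixFunctionalizationPass (the
# pass under test) should lower it back to the raw in-place op.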

class TestRotaryEmbeddingSliceScatter(torch.nn.Module):

    def __init__(
            self,
            quant=False,  # unused; kept so all models share a signature
            head_dim=64,
            num_heads=4,
            max_position=2048,
            base=10000):
        super().__init__()
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.hidden_size = head_dim * num_heads

        self.qkv_proj = torch.nn.Linear(self.hidden_size,
                                        self.hidden_size * 3,
                                        bias=False,
                                        dtype=torch.float16)

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=base,
        )

    def forward(self, positions, hidden_states):
        # Simulate the pattern: mm -> split_with_sizes -> rotary_embedding
        # -> slice_scatter -> split_with_sizes
        qkv = self.qkv_proj(hidden_states)
        split_sizes = [self.hidden_size, self.hidden_size, self.hidden_size]
        q, k, v = torch.split(qkv, split_sizes, dim=-1)

        q_rotated, k_rotated = self.rotary_emb(positions, q, k)

        qkv_updated = torch.cat([q_rotated, k_rotated, v], dim=-1)
        return qkv_updated

    def example_inputs(self, num_tokens=32, head_dim=64, num_heads=4):
        dtype = torch.float16
        hidden_size = head_dim * num_heads
        positions = torch.arange(num_tokens, dtype=torch.long)
        hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
        return (positions, hidden_states)

    def ops_in_model(self):
        return [torch.ops._C.rotary_embedding.default]

    def ops_not_in_model(self):
        return [torch.ops.aten.slice_scatter.default]

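# After de-functionalization the rotary op updates q and k in place, so the
# aten.slice_scatter copy-back that functionalization introduced should end
# up dead and be eliminated, which is why it appears in ops_not_in_model().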

MODELS = [
    TestSiluMul,
    TestFusedAddRMSNorm,
    TestRotaryEmbedding,
    TestRotaryEmbeddingSliceScatter,
]


@pytest.mark.parametrize("model_class", MODELS)
@pytest.mark.parametrize("quant", [True, False])
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fix_functionalization(model_class: type[torch.nn.Module],
                               quant: bool, do_fusion: bool):
    torch.set_default_device("cuda")
    if quant and not current_platform.supports_fp8():
        pytest.skip("FP8 is not supported on this platform")

    vllm_config = VllmConfig()
    vllm_config.compilation_config = CompilationConfig(
        pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
    noop_pass = NoOpEliminationPass(vllm_config)
    fusion_pass = RMSNormQuantFusionPass(vllm_config)
    cleanup_pass = PostCleanupPass(vllm_config)
    act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

    passes = ([noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
              if do_fusion else [noop_pass, cleanup_pass])
    func_pass = FixFunctionalizationPass(vllm_config)

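    # Build two otherwise-identical backends: one appends
    # FixFunctionalizationPass, the other does not, so the two post-pass
    # graphs can be compared below.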
    backend_func = TestBackend(*passes, func_pass)
    backend_no_func = TestBackend(*passes)

    model = model_class(quant=quant)
    torch.compile(model, backend=backend_func)(*model.example_inputs())
    torch.compile(model, backend=backend_no_func)(*model.example_inputs())

    # check if the functionalization pass is applied
    for op in model.ops_in_model():
        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
        assert (find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op)
                is None)

    # make sure the ops were all de-functionalized
    found = dict()
    for node in backend_func.graph_post_pass.nodes:
        for op in model.ops_in_model():
            if is_func(node, op):
                found[op] = True
        for op in model.ops_not_in_model():
            if is_func(node, op):
                found[op] = True
    assert all(found[op] for op in model.ops_in_model())
    assert all(not found.get(op) for op in model.ops_not_in_model())
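

# A minimal sketch of running one case by hand outside pytest (assumes a
# CUDA build of vLLM with custom ops compiled; TestBackend is the local
# helper imported from .backend above):
#
#   torch.set_default_device("cuda")
#   config = VllmConfig()
#   backend = TestBackend(NoOpEliminationPass(config),
#                         FixFunctionalizationPass(config))
#   model = TestRotaryEmbedding(quant=False)
#   torch.compile(model, backend=backend)(*model.example_inputs())
#   assert find_auto_fn_maybe(backend.graph_post_pass.nodes,
#                             torch.ops._C.rotary_embedding.default) is None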