# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Callable
from unittest.mock import patch

import pandas as pd
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


def with_triton_mode(fn):
    """Temporarily force the Triton fallback path."""

    def wrapped(*args, **kwargs):
        with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
            return fn(*args, **kwargs)

    return wrapped
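
# NOTE: this assumes QuantFP8.forward_cuda consults current_platform.is_cuda()
# when choosing a kernel; with the patch active, the same call should take the
# Triton fallback instead of the CUDA custom op.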


# TODO(luka): use standalone_compile utility
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
    def inner(*args):
        # treat this dim as symbolic so one compiled graph serves all its sizes
        torch._dynamo.mark_dynamic(args[arg_index], dim_index)
        return fn(*args)

    return inner


def bench_compile(fn: Callable):
    # recompile for different shapes
    fwd = torch.compile(fn, fullgraph=True, dynamic=False)

    # First dim is explicitly dynamic to simulate vLLM usage
    return with_dyn_arg(fwd, 0, 0)
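
# dynamic=False makes torch.compile specialize on each new shape; marking the
# first dim dynamic keeps the token count symbolic, as it is in vLLM serving.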


# Allow many recompilations: the benchmark sweeps many distinct shapes.
torch._dynamo.config.recompile_limit = 8888


def calculate_diff(
    batch_size: int,
    hidden_size: int,
    group_shape: GroupShape,
    dtype: torch.dtype,
):
    """Compare the compiled (Inductor), eager, and CUDA implementations."""
    device = torch.device("cuda")
    x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)

    # static=False: this benchmark covers dynamic quantization only
    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)

    torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
    torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
    cuda_out, cuda_scale = quant_fp8.forward_cuda(x)

    out_allclose = lambda o1, o2: torch.allclose(
        o1.to(torch.float32),
        o2.to(torch.float32),
        rtol=1e-3,
        atol=1e-5,
    )
    scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)

    if (
        out_allclose(cuda_out, torch_out)
        and scale_allclose(cuda_scale, torch_scale)
        and out_allclose(cuda_out, torch_eager_out)
        and scale_allclose(cuda_scale, torch_eager_scale)
    ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


configs = []
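# `configs` is filled in __main__ from the CLI arguments; each tuple follows the
# (hidden_size, batch_size, col_major, group_shape) order of x_names below.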


def benchmark_quantization(
    batch_size,
    hidden_size,
    provider,
    group_shape: GroupShape,
    col_major: bool,
    dtype: torch.dtype,
):
    device = torch.device("cuda")

    x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)

    if provider == "torch":
        fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone())
    elif provider == "cuda":
        fn = lambda: quant_fp8.forward_cuda(x.clone())
    elif provider == "triton":
        if not group_shape.is_per_group():
            # Triton is only supported for per-group quantization;
            # return zeros so unsupported configs still fit the report.
            return 0, 0, 0

        fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone())

    # do_bench_cudagraph captures fn in a CUDA graph before timing; results are
    # converted from triton's milliseconds to the microseconds used as ylabel.
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


# TODO(luka) extract to utils
def compute_geomean_speedups(
    df: pd.DataFrame,
    baseline_col: str,
    speedup_cols: list[str],
    groupby_cols: list[str] | None = None,
) -> pd.DataFrame:
    """
    Compute geometric mean speedups over a baseline column.

    Args:
        df: Input dataframe
        baseline_col: Column to use as baseline
        speedup_cols: Columns to compute speedups for
        groupby_cols: Columns to group by. If None, compute over entire df.

    Returns:
        pd.DataFrame with geometric mean speedups
    """
    from scipy.stats import gmean

    def geo_speedup(group: pd.DataFrame) -> pd.Series:
        ratios = {
            col: (group[baseline_col] / group[col]).values for col in speedup_cols
        }
        return pd.Series({col: gmean(vals) for col, vals in ratios.items()})

    if groupby_cols is None:
        result = geo_speedup(df).to_frame().T
    else:
        result = (
            df.groupby(groupby_cols)
            .apply(geo_speedup, include_groups=False)
            .reset_index()
        )

    return result
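
# Example (using the columns produced below):
#   compute_geomean_speedups(df, "Torch (Compiled)", ["CUDA"], ["group_shape"])
# returns one row per group_shape with the geometric-mean CUDA speedup.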


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
    )
    parser.add_argument("-c", "--check", action="store_true")
    parser.add_argument(
        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
    )
    parser.add_argument(
        "--hidden-sizes",
        type=int,
        nargs="+",
        default=None,
        help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)",
    )
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=None,
        help="Batch sizes to benchmark (default: 1,16,32,64,128)",
    )
    parser.add_argument(
        "--group-sizes",
        type=int,
        nargs="+",
        default=None,
        help="Group sizes for GroupShape(1,N) to benchmark. "
        "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)",
    )
    parser.add_argument(
        "--no-column-major",
        action="store_true",
        help="Disable column-major scales testing",
    )

    args = parser.parse_args()
    assert args

    dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]

    hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
    batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128]

    if args.group_sizes is not None:
        group_shapes = []
        for size in args.group_sizes:
            if size == 0:
                group_shapes.append(GroupShape.PER_TENSOR)
            elif size == -1:
                group_shapes.append(GroupShape.PER_TOKEN)
            else:
                group_shapes.append(GroupShape(1, size))
    else:
        group_shapes = [
            GroupShape.PER_TENSOR,
            GroupShape.PER_TOKEN,
            GroupShape(1, 64),
            GroupShape(1, 128),
        ]

    column_major_scales = [False] if args.no_column_major else [True, False]

    config_gen = itertools.product(
        group_shapes,
        column_major_scales,
        batch_sizes,
        hidden_sizes,
    )

    # Skip column-major scales for non-grouped shapes, and reverse each tuple
    # into the (hidden_size, batch_size, col_major, group_shape) x_names order.
    configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))

    print(f"Running {len(configs)} configurations:")
    print(f"  Hidden sizes: {hidden_sizes}")
    print(f"  Batch sizes: {batch_sizes}")
    print(f"  Group shapes: {[str(g) for g in group_shapes]}")
    print(f"  Column major scales: {column_major_scales}")
    print()

    if args.check:
        for group_shape in group_shapes:
            group_size = group_shape[1]
            print(f"{group_size=}")
            calculate_diff(
                batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
            )

    benchmark = triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
            x_vals=configs,
            line_arg="provider",
            line_vals=["torch", "cuda", "triton"],
            line_names=["Torch (Compiled)", "CUDA", "Triton"],
            styles=[("blue", "-"), ("green", "-"), ("black", "-")],
            ylabel="us",
            plot_name="QuantFP8 performance",
            args={},
        )
    )(benchmark_quantization)
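    # perf_report is applied as a plain function rather than a decorator so
    # that `configs` is already populated from the CLI arguments.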

    df = benchmark.run(print_data=True, dtype=dtype, return_df=True)

    # Print geomean speedups
    geo_table_grouped = compute_geomean_speedups(
        df,
        baseline_col="Torch (Compiled)",
        speedup_cols=["CUDA", "Triton"],
        groupby_cols=["col_major", "group_shape"],
    )

    print("Speedup over Torch (Compiled)")
    print(geo_table_grouped.to_string(index=False))
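
# Example invocation (script name assumed):
#   python bench_per_token_quant_fp8.py --check --dtype bfloat16 \
#       --hidden-sizes 128 4096 --batch-sizes 16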