# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Callable
+from unittest.mock import patch

+import pandas as pd
import torch

-from vllm import _custom_ops as ops
-from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+
+
+def with_triton_mode(fn):
+    """Temporarily force the Triton fallback path by patching is_cuda to False."""
+
+    def wrapped(*args, **kwargs):
+        with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
+            return fn(*args, **kwargs)
+
+    return wrapped
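
# Usage sketch (illustrative, assuming a QuantFP8 instance `quant_fp8` and an
# input tensor `x`): route a single call down the Triton path even on CUDA:
#   triton_fn = with_triton_mode(quant_fp8.forward_cuda)
#   out, scale = triton_fn(x)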


# TODO(luka): use standalone_compile utility
@@ -21,78 +32,236 @@ def inner(*args):
    return inner


-torch._dynamo.config.recompile_limit = 8888
-compilation_config = CompilationConfig(custom_ops=["none"])
-with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
-    torch_per_token_quant_fp8 = torch.compile(
-        QuantFP8(False, GroupShape.PER_TOKEN),
-        fullgraph=True,
-        dynamic=False,  # recompile for different shapes
-    )
+def bench_compile(fn: Callable):
+    # recompile for different shapes
+    fwd = torch.compile(fn, fullgraph=True, dynamic=False)

    # First dim is explicitly dynamic to simulate vLLM usage
-torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
+    return with_dyn_arg(fwd, 0, 0)
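
# Note: with_dyn_arg is defined in the unchanged code above; per the comment, it
# presumably marks dim 0 of argument 0 as dynamic (e.g. via
# torch._dynamo.mark_dynamic), so the batch dimension stays symbolic while other
# shape changes trigger recompiles. Illustrative use:
#   quant = bench_compile(quant_fp8.forward_native)
#   out, scale = quant(x)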


-def cuda_per_token_quant_fp8(
-    input: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    return ops.scaled_fp8_quant(input)
+torch._dynamo.config.recompile_limit = 8888


-def calculate_diff(batch_size: int, seq_len: int):
-    """Calculate difference between Triton and CUDA implementations."""
+def calculate_diff(
+    batch_size: int,
+    hidden_size: int,
+    group_shape: GroupShape,
+    dtype: torch.dtype,
+):
+    """Compare the compiled (Inductor), eager, and CUDA implementations."""
    device = torch.device("cuda")
-    x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
+    x = torch.rand((batch_size * hidden_size, 4096), dtype=dtype, device=device)
+
+    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)

-    torch_out, torch_scale = torch_per_token_quant_fp8(x)
-    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
+    torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
+    torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
+    cuda_out, cuda_scale = quant_fp8.forward_cuda(x)

-    if torch.allclose(
-        cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
-    ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
+    out_allclose = lambda o1, o2: torch.allclose(
+        o1.to(torch.float32),
+        o2.to(torch.float32),
+        rtol=1e-3,
+        atol=1e-5,
+    )
+    scale_allclose = lambda s1, s2: torch.allclose(s1, s2, rtol=1e-3, atol=1e-5)
+
+    if (
+        out_allclose(cuda_out, torch_out)
+        and scale_allclose(cuda_scale, torch_scale)
+        and out_allclose(cuda_out, torch_eager_out)
+        and scale_allclose(cuda_scale, torch_eager_scale)
+    ):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


-batch_size_range = [1, 16, 32, 64, 128]
-seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
-
-configs = list(itertools.product(batch_size_range, seq_len_range))
+configs = []


-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=["batch_size", "seq_len"],
-        x_vals=configs,
-        line_arg="provider",
-        line_vals=["torch", "cuda"],
-        line_names=["Torch", "CUDA"],
-        styles=[("blue", "-"), ("green", "-")],
-        ylabel="us",
-        plot_name="per-token-dynamic-quant-fp8-performance",
-        args={},
-    )
-)
-def benchmark_quantization(batch_size, seq_len, provider):
-    dtype = torch.float16
+def benchmark_quantization(
+    batch_size,
+    hidden_size,
+    provider,
+    group_shape: GroupShape,
+    col_major: bool,
+    dtype: torch.dtype,
+):
    device = torch.device("cuda")

-    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
+    x = torch.randn(batch_size * hidden_size, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
+    quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)

    if provider == "torch":
-        fn = lambda: torch_per_token_quant_fp8(x.clone())
+        fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone())
    elif provider == "cuda":
-        fn = lambda: cuda_per_token_quant_fp8(x.clone())
+        fn = lambda: quant_fp8.forward_cuda(x.clone())
+    elif provider == "triton":
+        if not group_shape.is_per_group():
+            # Triton is only supported for per-group quantization
+            return 0, 0, 0
+
+        fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone())

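    # do_bench_cudagraph replays fn under a CUDA graph and reports milliseconds;
    # the 1000x scaling below converts to microseconds to match the "us" ylabel.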
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


+# TODO(luka) extract to utils
+def compute_geomean_speedups(
+    df: pd.DataFrame,
+    baseline_col: str,
+    speedup_cols: list[str],
+    groupby_cols: list[str] | None = None,
+) -> pd.DataFrame:
+    """
+    Compute geometric mean speedups over a baseline column.
+
+    Args:
+        df: Input dataframe
+        baseline_col: Column to use as baseline
+        speedup_cols: Columns to compute speedups for
+        groupby_cols: Columns to group by. If None, compute over the entire df.
+
+    Returns:
+        pd.DataFrame with geometric mean speedups
+    """
+    from scipy.stats import gmean
+
+    def geo_speedup(group: pd.DataFrame) -> pd.Series:
+        ratios = {
+            col: (group[baseline_col] / group[col]).values for col in speedup_cols
+        }
+        return pd.Series({col: gmean(vals) for col, vals in ratios.items()})
+
+    if groupby_cols is None:
+        result = geo_speedup(df).to_frame().T
+    else:
+        result = (
+            df.groupby(groupby_cols)
+            .apply(geo_speedup, include_groups=False)
+            .reset_index()
+        )
+
+    return result
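
# Example (illustrative): given result columns "Torch (Compiled)" and "CUDA" plus
# a "group_shape" column,
#   compute_geomean_speedups(df, "Torch (Compiled)", ["CUDA"], ["group_shape"])
# returns one row per group_shape holding the geometric-mean CUDA speedup over
# the baseline.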


if __name__ == "__main__":
-    calculate_diff(batch_size=4, seq_len=4096)
-    benchmark_quantization.run(print_data=True)
+    parser = FlexibleArgumentParser(
+        description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
+    )
+    parser.add_argument("-c", "--check", action="store_true")
+    parser.add_argument(
+        "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
+    )
+    parser.add_argument(
+        "--hidden-sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Hidden sizes to benchmark (default: 1,16,64,128,256,512,1024,2048,4096)",
+    )
+    parser.add_argument(
+        "--batch-sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Batch sizes to benchmark (default: 1,16,32,64,128)",
+    )
+    parser.add_argument(
+        "--group-sizes",
+        type=int,
+        nargs="+",
+        default=None,
+        help="Group sizes for GroupShape(1,N) to benchmark. "
+        "Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)",
+    )
+    parser.add_argument(
+        "--no-column-major",
+        action="store_true",
+        help="Disable column-major scales testing",
+    )
+
+    args = parser.parse_args()
+    assert args
+
+    dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
+
+    hidden_sizes = args.hidden_sizes or [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
+    batch_sizes = args.batch_sizes or [1, 16, 32, 64, 128]
+
+    if args.group_sizes is not None:
+        group_shapes = []
+        for size in args.group_sizes:
+            if size == 0:
+                group_shapes.append(GroupShape.PER_TENSOR)
+            elif size == -1:
+                group_shapes.append(GroupShape.PER_TOKEN)
+            else:
+                group_shapes.append(GroupShape(1, size))
+    else:
+        group_shapes = [
+            GroupShape.PER_TENSOR,
+            GroupShape.PER_TOKEN,
+            GroupShape(1, 64),
+            GroupShape(1, 128),
+        ]
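    # For example (illustrative): `--group-sizes 0 -1 128` selects
    # [GroupShape.PER_TENSOR, GroupShape.PER_TOKEN, GroupShape(1, 128)].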
+
+    column_major_scales = [False] if args.no_column_major else [True, False]
+
+    config_gen = itertools.product(
+        group_shapes,
+        column_major_scales,
+        batch_sizes,
+        hidden_sizes,
+    )
+
+    # Drop column-major-scale configs for non-group shapes, then reverse each
+    # tuple to match the x_names order of the Benchmark below
+    configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))
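    # Illustrative: (GroupShape(1, 128), True, 16, 256) becomes
    # (256, 16, True, GroupShape(1, 128)), i.e. the
    # (hidden_size, batch_size, col_major, group_shape) order used as x_names.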
+
+    print(f"Running {len(configs)} configurations:")
+    print(f"  Hidden sizes: {hidden_sizes}")
+    print(f"  Batch sizes: {batch_sizes}")
+    print(f"  Group shapes: {[str(g) for g in group_shapes]}")
+    print(f"  Column major scales: {column_major_scales}")
+    print()
+
+    if args.check:
+        for group_shape in group_shapes:
+            group_size = group_shape[1]
+            print(f"{group_size=}")
+            calculate_diff(
+                batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
+            )
+
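    # Note: perf_report is applied as a plain call rather than a decorator so
    # that the Benchmark captures the CLI-derived `configs` built above.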
+    benchmark = triton.testing.perf_report(
+        triton.testing.Benchmark(
+            x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
+            x_vals=configs,
+            line_arg="provider",
+            line_vals=["torch", "cuda", "triton"],
+            line_names=["Torch (Compiled)", "CUDA", "Triton"],
+            styles=[("blue", "-"), ("green", "-"), ("black", "-")],
+            ylabel="us",
+            plot_name="QuantFP8 performance",
+            args={},
+        )
+    )(benchmark_quantization)
+
+    df = benchmark.run(print_data=True, dtype=dtype, return_df=True)
+
+    # Print geomean speedups
+    geo_table_grouped = compute_geomean_speedups(
+        df,
+        baseline_col="Torch (Compiled)",
+        speedup_cols=["CUDA", "Triton"],
+        groupby_cols=["col_major", "group_shape"],
+    )
+
+    print("Speedup over Torch (Compiled)")
+    print(geo_table_grouped.to_string(index=False))