import pandas as pd
import sympy
import torch
+import torch.nn as nn
import torch.utils.benchmark as benchmark
import tqdm
from torch.profiler import ProfilerActivity, profile
)

from torchao.float8 import (
+    Float8LinearConfig,
    convert_to_float8_training,
)
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.testing.float8.roofline_utils import (
    get_float8_mem_sympy,
    get_gemm_time_sympy,
@@ -93,17 +97,19 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
    return measurement.mean


-def get_gpu_kernel_time(m, x):
+def get_gpu_kernel_time(m, x, grad_output):
    # warm up
    for _ in range(2):
-        m(x).sum().backward()
+        y = m(x)
+        y.backward(grad_output)

    # capture a profiling run
    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    n_iter = 5
    with profile(activities=activities) as prof:
        for _ in range(n_iter):
-            m(x).sum().backward()
+            y = m(x)
+            y.backward(grad_output)
    torch.cuda.synchronize()
    # get the gpu kernel time and aggregate it
    num_leaf_tensors = 1 + len(list(m.parameters()))
@@ -114,7 +120,22 @@ def get_gpu_kernel_time(m, x):
    return total_time_s


-def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
+def get_gemm_times(
+    M,
+    K,
+    N,
+    fast_accum,
+    bf16_memory_formats,
+    float8_recipe_name,
+    mx_recipe_name,
+    cache_filename=None,
+):
+    assert bf16_memory_formats in (
+        "row_major:col_major",
+        "row_major:row_major",
+        "col_major:row_major",
+    ), "unsupported"
+
    # Note: this is definitely not the best way to build a cache,
    # but it will do for now.
    if cache_filename is not None:
@@ -127,23 +148,38 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
            cache = dict()
    else:
        cache = dict()
-    key = f"{M},{K},{N},{fast_accum}"
+    key = f"{M},{K},{N},{fast_accum},{bf16_memory_formats}"
    if key in cache:
        return cache[key]

    device = torch.device("cuda")

    # bf16 time
    x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
+
+    if bf16_memory_formats == "row_major:col_major":
+        w_bf16 = w_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+
    bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

    # f8 time
    d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
    A = torch.zeros(M, K, device=device, dtype=d1)
    B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
-    scale_a = torch.tensor([1.0], device=device)
-    scale_b = torch.tensor([1.0], device=device)
+    if float8_recipe_name == "tensorwise":
+        scale_a = torch.tensor([1.0], device=device)
+        scale_b = torch.tensor([1.0], device=device)
+    elif float8_recipe_name == "rowwise":
+        scale_a = torch.ones(M, 1, device=device)
+        scale_b = torch.ones(1, N, device=device)
+    else:
+        assert False, "TODO add mx gemm here"

    def do_matmul(A, B):
        return torch._scaled_mm(
@@ -164,33 +200,52 @@ def do_matmul(A, B):
def run(
    outfile: str,
    do_benchmarks: bool = True,
-    shape_gen_name: str = "square",
+    shape_gen_name: str = "pow2",
    gemm_cache_filename: Optional[str] = None,
    n_limit: Optional[int] = None,
+    float8_recipe_name: Optional[str] = None,
+    mx_recipe_name: Optional[str] = None,
+    enable_fusion_modeling: bool = False,
):
    """
    Args:
    * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `square`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
    * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
    * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
    """

+    assert not (
+        (float8_recipe_name is not None) and (mx_recipe_name is not None)
+    ), "unsupported"
+    if float8_recipe_name is None and mx_recipe_name is None:
+        float8_recipe_name = "tensorwise"
+
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"do_benchmarks: {do_benchmarks}")
    print(f"shape_gen_name: {shape_gen_name}")
+    print(f"float8_recipe_name: {float8_recipe_name}")
+    print(f"mx_recipe_name: {mx_recipe_name}")
+    print(f"enable_fusion_modeling: {enable_fusion_modeling}")

    M, K, N = sympy.symbols("M K N")

-    fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
+    fp8_ovhd_time_sympy = get_float8_mem_sympy(
        M,
        K,
        N,
+        float8_recipe_name,
+        mx_recipe_name,
+        enable_fusion_modeling,
+    )
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None)
+    fp8_gemm_time_sympy = get_gemm_time_sympy(
+        M, K, N, torch.float8_e4m3fn, mx_recipe_name
    )
-
-    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-    fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
    print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
    print()

    headers = [
@@ -217,6 +272,9 @@ def run(
        # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
        # we don't break them out and don't have a roofline for them.
        "b_fp8_e2e_spdp",
+        # how well benchmarked gemms match roofline predicted gemms
+        "rb_bf16_gemm_ratio",
+        "rb_fp8_gemm_ratio",
    ]
    results = []
@@ -237,43 +295,93 @@ def run(

        # if enabled, also measured observed gemm time
        b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        rb_bf16_gemm_ratio = -1
+        rb_fp8_gemm_ratio = -1
+
        if do_benchmarks:
+            # TODO(future): make the bf16 gemm times exactly match the e2e
+            # benchmarks, there is a slight deviation, probably related to gemm
+            # operand memory formats/transpositions below not exactly matching
+            # what PyTorch core is doing for `torch.mm`
+            # input @ weight_t = output
            bf16_g1, f8_g1 = get_gemm_times(
-                M_val, K_val, N_val, True, gemm_cache_filename
+                M_val,
+                K_val,
+                N_val,
+                True,
+                "row_major:col_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
            )
+            # grad_output @ weight = grad_input
            bf16_g2, f8_g2 = get_gemm_times(
-                M_val, N_val, K_val, False, gemm_cache_filename
+                M_val,
+                N_val,
+                K_val,
+                False,
+                "row_major:row_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
            )
+            # input_t @ grad_output = grad_weight
            bf16_g3, f8_g3 = get_gemm_times(
-                K_val, M_val, N_val, False, gemm_cache_filename
+                K_val,
+                M_val,
+                N_val,
+                False,
+                "col_major:row_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
            )
            b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
            b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
+            rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s

        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
        r_fp8_ovhd_time_s = float(
-            fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
        )

        b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
        if do_benchmarks:
            # create the model
-            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            if enable_fusion_modeling:
+                m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            else:
+                m_orig = (
+                    nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
+                )
            x = torch.randn(
                M_val, K_val, dtype=torch.bfloat16, device="cuda"
            ).requires_grad_()

+            # get the gradient of the right shape (matches the (M_val, N_val) output)
+            grad_output = torch.randn(M_val, N_val, dtype=torch.bfloat16, device="cuda")
+
            # get the bf16 gpu kernel time
            torch._dynamo.reset()
            m_bf16 = torch.compile(copy.deepcopy(m_orig))
-            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
+            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, grad_output)

            # get the float8 dynamic scaling gpu kernel time

            torch._dynamo.reset()
-            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            if float8_recipe_name is not None:
+                config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
+                m_fp8_dyn = convert_to_float8_training(
+                    copy.deepcopy(m_orig), config=config
+                )
+            else:
+                assert mx_recipe_name is not None
+                config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+                m_fp8_dyn = copy.deepcopy(m_orig)
+                swap_linear_with_mx_linear(m_fp8_dyn, config=config)
            m_fp8_dyn = torch.compile(m_fp8_dyn)
-            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
+            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)

        results.append(
            [
@@ -295,6 +403,9 @@ def run(
                b_bf16_e2e_time_s,
                b_fp8_e2e_time_s,
                b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
+                # gemm ratios
+                rb_bf16_gemm_ratio,
+                rb_fp8_gemm_ratio,
            ]
        )
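
For reference, the three `get_gemm_times` calls above correspond to the forward and backward matmuls of a bias-free linear layer. A minimal sketch (not part of the diff; the concrete sizes are illustrative assumptions) showing why the shape arguments are (M, K, N), (M, N, K), and (K, M, N):

import torch

M, K, N = 256, 512, 1024            # example sizes, chosen arbitrarily
x = torch.randn(M, K)               # input
w_t = torch.randn(K, N)             # weight, pre-transposed for the forward gemm
grad_output = torch.randn(M, N)     # same shape as the forward output

output = x @ w_t                    # input @ weight_t:      (M, K) @ (K, N) -> (M, N)
grad_input = grad_output @ w_t.t()  # grad_output @ weight:  (M, N) @ (N, K) -> (M, K)
grad_weight = x.t() @ grad_output   # input_t @ grad_output: (K, M) @ (M, N) -> (K, N)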
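
A minimal usage sketch for the updated entrypoint (not part of the diff; the output filename and recipe choice here are assumptions), exercising the new recipe and fusion-modeling arguments:

run(
    "float8_roofline_results.csv",
    do_benchmarks=True,
    shape_gen_name="pow2",
    float8_recipe_name="rowwise",   # or mx_recipe_name=..., but not both
    enable_fusion_modeling=False,
)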