import pandas as pd
import sympy
import torch
+import torch.nn as nn
import torch.utils.benchmark as benchmark
import tqdm
from torch.profiler import ProfilerActivity, profile
)

from torchao.float8 import (
+    Float8LinearConfig,
    convert_to_float8_training,
)
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.testing.float8.roofline_utils import (
    get_float8_mem_sympy,
    get_gemm_time_sympy,
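The new imports enable a second conversion path (MX formats) alongside the existing float8 one. A minimal sketch of how the two paths are applied to a module, mirroring the calls made further down in this file; the module shape and the MX recipe name are illustrative assumptions, not values taken from this script:

```python
import copy

import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = nn.Sequential(nn.Linear(2048, 4096, bias=False)).cuda().bfloat16()

# float8 path: build a config from a recipe name, then swap nn.Linear modules
fp8_config = Float8LinearConfig.from_recipe_name("rowwise")
m_fp8 = convert_to_float8_training(copy.deepcopy(m), config=fp8_config)

# MX path: the swap happens in place ("mxfp8_emulated" is an assumed recipe name)
mx_config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
m_mx = copy.deepcopy(m)
swap_linear_with_mx_linear(m_mx, config=mx_config)
```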
@@ -93,17 +97,19 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
    return measurement.mean


-def get_gpu_kernel_time(m, x):
+def get_gpu_kernel_time(m, x, grad_output):
    # warm up
    for _ in range(2):
-        m(x).sum().backward()
+        y = m(x)
+        y.backward(grad_output)

    # capture a profiling run
    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    n_iter = 5
    with profile(activities=activities) as prof:
        for _ in range(n_iter):
-            m(x).sum().backward()
+            y = m(x)
+            y.backward(grad_output)
            torch.cuda.synchronize()
    # get the gpu kernel time and aggregate it
    num_leaf_tensors = 1 + len(list(m.parameters()))
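`get_gpu_kernel_time` now takes an explicit `grad_output` and calls `y.backward(grad_output)` instead of reducing with `.sum()`, so the backward gemms see a realistically shaped gradient. A minimal standalone sketch of that fwd+bwd timing pattern using `torch.utils.benchmark`; the shapes are illustrative:

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

M, K, N = 1024, 2048, 4096
m = nn.Linear(K, N, bias=False).cuda().bfloat16()
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
grad_output = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)

def fwd_bwd():
    y = m(x)
    y.backward(grad_output)

# warm up, then time the combined forward + backward
for _ in range(2):
    fwd_bwd()
t = benchmark.Timer(stmt="fwd_bwd()", globals={"fwd_bwd": fwd_bwd})
print(t.blocked_autorange().mean)
```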
@@ -114,10 +120,28 @@ def get_gpu_kernel_time(m, x):
    return total_time_s


-def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
+def get_gemm_times(
+    gemm_role: str,
+    M: int,
+    K: int,
+    N: int,
+    fast_accum: bool,
+    bf16_memory_formats: str,
+    float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str],
+    cache_filename=None,
+):
+    assert gemm_role in ("output", "grad_input", "grad_weight"), "unsupported"
+    assert bf16_memory_formats in (
+        "row_major:col_major",
+        "row_major:row_major",
+        "col_major:row_major",
+    ), "unsupported"
+
    # Note: this is definitely not the best way to build a cache,
    # but it will do for now.
    if cache_filename is not None:
+        assert False, "TODO retest this for new arguments"
        if os.path.isfile(cache_filename):
            # cache already exists, use it
            with open(cache_filename, "r") as f:
@@ -127,30 +151,48 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
            cache = dict()
    else:
        cache = dict()
-    key = f"{M},{K},{N},{fast_accum}"
+    key = f"{M},{K},{N},{fast_accum},{bf16_memory_formats}"
    if key in cache:
        return cache[key]

    device = torch.device("cuda")

    # bf16 time
    x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
+
+    if bf16_memory_formats == "row_major:col_major":
+        w_bf16 = w_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+
    bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

    # f8 time
-    d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
-    A = torch.zeros(M, K, device=device, dtype=d1)
-    B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
-    scale_a = torch.tensor([1.0], device=device)
-    scale_b = torch.tensor([1.0], device=device)
-
-    def do_matmul(A, B):
-        return torch._scaled_mm(
-            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-        )
+    if float8_recipe_name == "rowwise_with_gw_hp" and gemm_role == "grad_weight":
+        f8_time_s = bf16_time_s
+    else:
+        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
+        A = torch.zeros(M, K, device=device, dtype=d1)
+        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        if float8_recipe_name == "tensorwise":
+            scale_a = torch.tensor([1.0], device=device)
+            scale_b = torch.tensor([1.0], device=device)
+        elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"):
+            scale_a = torch.ones(M, 1, device=device)
+            scale_b = torch.ones(1, N, device=device)
+        else:
+            assert False, "TODO add mx gemm here"
+
+        def do_matmul(A, B):
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+            )

-    f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+        f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

    # save to cache if needed
    if cache_filename is not None:
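The rewritten f8 branch selects the scale granularity from the recipe: tensorwise uses one scale per operand, while the rowwise recipes use per-row scales for `A` and per-column scales for `B`. A standalone sketch of the two `torch._scaled_mm` call shapes, assuming a recent PyTorch on a GPU with float8 support; sizes are illustrative:

```python
import torch

M, K, N = 4096, 4096, 4096
device = torch.device("cuda")
A = torch.zeros(M, K, device=device, dtype=torch.float8_e4m3fn)
# the second operand must be column-major for float8 gemms
B = torch.zeros(K, N, device=device, dtype=torch.float8_e4m3fn).t().contiguous().t()

# tensorwise: a single scale per operand
out_tw = torch._scaled_mm(
    A, B,
    torch.tensor([1.0], device=device),
    torch.tensor([1.0], device=device),
    out_dtype=torch.bfloat16,
)

# rowwise: one scale per row of A, one per column of B
out_rw = torch._scaled_mm(
    A, B,
    torch.ones(M, 1, device=device),
    torch.ones(1, N, device=device),
    out_dtype=torch.bfloat16,
)
```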
@@ -164,33 +206,52 @@ def do_matmul(A, B):
def run(
    outfile: str,
    do_benchmarks: bool = True,
-    shape_gen_name: str = "square",
+    shape_gen_name: str = "pow2",
    gemm_cache_filename: Optional[str] = None,
    n_limit: Optional[int] = None,
+    float8_recipe_name: Optional[str] = None,
+    mx_recipe_name: Optional[str] = None,
+    enable_fusion_modeling: bool = False,
):
    """
    Args:
    * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `square`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
    * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
    * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
    """

+    assert not (
+        (float8_recipe_name is not None) and (mx_recipe_name is not None)
+    ), "unsupported"
+    if float8_recipe_name is None and mx_recipe_name is None:
+        float8_recipe_name = "tensorwise"
+
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"do_benchmarks: {do_benchmarks}")
    print(f"shape_gen_name: {shape_gen_name}")
+    print(f"float8_recipe_name: {float8_recipe_name}")
+    print(f"mx_recipe_name: {mx_recipe_name}")
+    print(f"enable_fusion_modeling: {enable_fusion_modeling}")

    M, K, N = sympy.symbols("M K N")

-    fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
+    fp8_ovhd_time_sympy = get_float8_mem_sympy(
        M,
        K,
        N,
+        float8_recipe_name,
+        mx_recipe_name,
+        enable_fusion_modeling,
+    )
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None, None)
+    fp8_gemm_time_sympy = get_gemm_time_sympy(
+        M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name
    )
-
-    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
    print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-    fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
    print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
    print()

    headers = [
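`get_gemm_time_sympy` and `get_float8_mem_sympy` return symbolic expressions in `M`, `K`, `N` that are evaluated later via `.subs(...)`. A rough sketch of the roofline idea behind them, not the torchao helpers themselves; the peak compute and bandwidth numbers are illustrative placeholders:

```python
import sympy

M, K, N = sympy.symbols("M K N")
PEAK_BF16_FLOPS_PER_S = 989e12   # assumed peak bf16 compute, FLOP/s
PEAK_MEM_BW_BYTES_PER_S = 3.35e12  # assumed peak memory bandwidth, bytes/s

gemm_flops = 2 * M * K * N
gemm_bytes = 2 * (M * K + K * N + M * N)  # bf16 = 2 bytes per element
gemm_time_s = sympy.Max(
    gemm_flops / PEAK_BF16_FLOPS_PER_S,    # compute-bound term
    gemm_bytes / PEAK_MEM_BW_BYTES_PER_S,  # memory-bound term
)

# evaluate for one shape, mirroring the .subs(...) calls later in this file
print(float(gemm_time_s.subs(M, 4096).subs(K, 4096).subs(N, 4096)))
```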
@@ -217,6 +278,9 @@ def run(
        # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
        # we don't break them out and don't have a roofline for them.
        "b_fp8_e2e_spdp",
+        # how well benchmarked gemms match roofline predicted gemms
+        "rb_bf16_gemm_ratio",
+        "rb_fp8_gemm_ratio",
    ]
    results = []

@@ -237,43 +301,96 @@ def run(

        # if enabled, also measured observed gemm time
        b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        rb_bf16_gemm_ratio = -1
+        rb_fp8_gemm_ratio = -1
+
        if do_benchmarks:
+            # TODO(future): make the bf16 gemm times exactly match the e2e
+            # benchmarks, there is a slight deviation, probably related to gemm
+            # operand memory formats/transpositions below not exactly matching
+            # what PyTorch core is doing for `torch.mm`
+            # input @ weight_t = output
            bf16_g1, f8_g1 = get_gemm_times(
-                M_val, K_val, N_val, True, gemm_cache_filename
+                "output",
+                M_val,
+                K_val,
+                N_val,
+                True,
+                "row_major:col_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
            )
+            # grad_output @ weight = grad_input
            bf16_g2, f8_g2 = get_gemm_times(
-                M_val, N_val, K_val, False, gemm_cache_filename
+                "grad_input",
+                M_val,
+                N_val,
+                K_val,
+                False,
+                "row_major:row_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
            )
+            # input_t @ grad_output = grad_weight
            bf16_g3, f8_g3 = get_gemm_times(
-                K_val, M_val, N_val, False, gemm_cache_filename
+                "grad_weight",
+                K_val,
+                M_val,
+                N_val,
+                False,
+                "col_major:row_major",
+                float8_recipe_name,
+                mx_recipe_name,
+                gemm_cache_filename,
            )
            b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
            b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
+            rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s

        # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
        r_fp8_ovhd_time_s = float(
-            fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
        )

        b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
        if do_benchmarks:
            # create the model
-            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            if enable_fusion_modeling:
+                m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            else:
+                m_orig = (
+                    nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
+                )
            x = torch.randn(
                M_val, K_val, dtype=torch.bfloat16, device="cuda"
            ).requires_grad_()

+ # get the gradient of the right shape
372
+ grad_output = torch .randn (N_val , K_val , dtype = torch .bfloat16 , device = "cuda" )
373
+
266
374
# get the bf16 gpu kernel time
267
375
torch ._dynamo .reset ()
268
376
m_bf16 = torch .compile (copy .deepcopy (m_orig ))
269
- b_bf16_e2e_time_s = get_gpu_kernel_time (m_bf16 , x )
377
+ b_bf16_e2e_time_s = get_gpu_kernel_time (m_bf16 , x , grad_output )
270
378
271
379
# get the float8 dynamic scaling gpu kernel time
272
380
273
381
torch ._dynamo .reset ()
274
- m_fp8_dyn = convert_to_float8_training (copy .deepcopy (m_orig ))
382
+ if float8_recipe_name is not None :
383
+ config = Float8LinearConfig .from_recipe_name (float8_recipe_name )
384
+ m_fp8_dyn = convert_to_float8_training (
385
+ copy .deepcopy (m_orig ), config = config
386
+ )
387
+ else :
388
+ assert mx_recipe_name is not None
389
+ config = MXLinearConfig .from_recipe_name (mx_recipe_name )
390
+ m_fp8_dyn = copy .deepcopy (m_orig )
391
+ swap_linear_with_mx_linear (m_fp8_dyn , config = config )
275
392
m_fp8_dyn = torch .compile (m_fp8_dyn )
276
- b_fp8_e2e_time_s = get_gpu_kernel_time (m_fp8_dyn , x )
393
+ b_fp8_e2e_time_s = get_gpu_kernel_time (m_fp8_dyn , x , grad_output )
277
394
278
395
results .append (
279
396
[
@@ -295,6 +412,9 @@ def run(
                b_bf16_e2e_time_s,
                b_fp8_e2e_time_s,
                b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
+                # gemm ratios
+                rb_bf16_gemm_ratio,
+                rb_fp8_gemm_ratio,
            ]
        )

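For reference, the `(M, K, N)` triples passed to the three `get_gemm_times` calls above correspond to the three gemms in a linear layer's forward + backward. A small standalone shape check with illustrative sizes (the `grad_weight` computed this way is the transpose of `nn.Linear`'s `(out_features, in_features)` weight layout):

```python
import torch

M_val, K_val, N_val = 1024, 2048, 4096
x = torch.randn(M_val, K_val)   # input
w = torch.randn(N_val, K_val)   # nn.Linear stores weight as (out_features, in_features)
go = torch.randn(M_val, N_val)  # gradient of the output

out = x @ w.t()      # "output":      (M, K) @ (K, N) -> (M, N)
gi = go @ w          # "grad_input":  (M, N) @ (N, K) -> (M, K)
gw = x.t() @ go      # "grad_weight": (K, M) @ (M, N) -> (K, N)

assert out.shape == (M_val, N_val)
assert gi.shape == (M_val, K_val)
assert gw.shape == (K_val, N_val)
```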