 import pandas as pd
 import sympy
 import torch
+import torch.nn as nn
 import torch.utils.benchmark as benchmark
 import tqdm
 from torch.profiler import ProfilerActivity, profile
@@ -57,8 +58,11 @@
 )

 from torchao.float8 import (
+    Float8LinearConfig,
     convert_to_float8_training,
 )
+from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.testing.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
@@ -93,17 +97,19 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
     return measurement.mean


-def get_gpu_kernel_time(m, x):
+def get_gpu_kernel_time(m, x, grad_output):
     # warm up
     for _ in range(2):
-        m(x).sum().backward()
+        y = m(x)
+        y.backward(grad_output)

     # capture a profiling run
     activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
     n_iter = 5
     with profile(activities=activities) as prof:
         for _ in range(n_iter):
-            m(x).sum().backward()
+            y = m(x)
+            y.backward(grad_output)
             torch.cuda.synchronize()
     # get the gpu kernel time and aggregate it
     num_leaf_tensors = 1 + len(list(m.parameters()))
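The helper that reduces the captured trace to a single number is defined elsewhere in this file. As a rough standalone approximation (a sketch assuming a recent PyTorch where `key_averages()` entries expose `self_device_time_total`; the function name here is illustrative), one can sum the self GPU time of all profiled events and average over `n_iter`:

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile


def total_gpu_kernel_time_s(prof, n_iter: int) -> float:
    # self_device_time_total is in microseconds; using the "self" time avoids
    # double counting parent ops that wrap the actual CUDA kernels
    total_us = sum(e.self_device_time_total for e in prof.key_averages())
    return total_us / 1e6 / n_iter


if torch.cuda.is_available():
    m = nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
    x = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    grad_output = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
    n_iter = 5
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(n_iter):
            y = m(x)
            y.backward(grad_output)
            torch.cuda.synchronize()
    print(f"~{total_gpu_kernel_time_s(prof, n_iter):.6f} s per iteration")
```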
@@ -114,7 +120,20 @@ def get_gpu_kernel_time(m, x):
     return total_time_s


-def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
+def get_gemm_times(
+    M,
+    K,
+    N,
+    fast_accum,
+    bf16_memory_formats,
+    cache_filename=None,
+):
+    assert bf16_memory_formats in (
+        "row_major:col_major",
+        "row_major:row_major",
+        "col_major:row_major",
+    ), "unsupported"
+
     # Note: this is definitely not the best way to build a cache,
     # but it will do for now.
     if cache_filename is not None:
@@ -127,15 +146,22 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
             cache = dict()
     else:
         cache = dict()
-    key = f"{M},{K},{N},{fast_accum}"
+    key = f"{M},{K},{N},{fast_accum},{bf16_memory_formats}"
     if key in cache:
         return cache[key]

     device = torch.device("cuda")

     # bf16 time
     x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device)
-    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t()
+    w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device)
+
+    if bf16_memory_formats == "row_major:col_major":
+        w_bf16 = w_bf16.t().contiguous().t()
+    elif bf16_memory_formats == "col_major:row_major":
+        x_bf16 = x_bf16.t().contiguous().t()
+
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

     # f8 time
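For reference, the `.t().contiguous().t()` pattern used above keeps a tensor's logical shape and values but makes its storage column-major, so `torch.mm` is benchmarked against the same operand layout the corresponding training gemm sees. A minimal stride check (a standalone sketch):

```python
import torch

a = torch.randn(4, 8)          # row-major storage: strides (8, 1)
b = a.t().contiguous().t()     # same shape and values, column-major storage: strides (1, 4)

assert a.shape == b.shape and torch.equal(a, b)
print(a.stride(), b.stride())  # (8, 1) (1, 4)
```

The `"row_major:row_major"` case needs no transform, which is why only two branches are required.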
@@ -164,33 +190,52 @@ def do_matmul(A, B):
 def run(
     outfile: str,
     do_benchmarks: bool = True,
-    shape_gen_name: str = "square",
+    shape_gen_name: str = "pow2",
     gemm_cache_filename: Optional[str] = None,
     n_limit: Optional[int] = None,
+    float8_recipe_name: Optional[str] = None,
+    mx_recipe_name: Optional[str] = None,
+    enable_fusion_modeling: bool = False,
 ):
     """
     Args:
     * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked
-    * `shape_gen_name`: `llama`, `square`, or `sweep`
+    * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep`
     * `gemm_cache_filename (optional)`: file to cache gemm benchmark results
     * `n_limit (optional)`: if specified, only runs `n_limit` iterations
+    * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the fusion of float8 overhead
     """

+    assert not (
+        (float8_recipe_name is not None) and (mx_recipe_name is not None)
+    ), "unsupported"
+    if float8_recipe_name is None and mx_recipe_name is None:
+        float8_recipe_name = "tensorwise"
+
+    print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"do_benchmarks: {do_benchmarks}")
     print(f"shape_gen_name: {shape_gen_name}")
+    print(f"float8_recipe_name: {float8_recipe_name}")
+    print(f"mx_recipe_name: {mx_recipe_name}")
+    print(f"enable_fusion_modeling: {enable_fusion_modeling}")

     M, K, N = sympy.symbols("M K N")

-    fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy(
+    fp8_ovhd_time_sympy = get_float8_mem_sympy(
         M,
         K,
         N,
+        float8_recipe_name,
+        mx_recipe_name,
+        enable_fusion_modeling,
+    )
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None)
+    fp8_gemm_time_sympy = get_gemm_time_sympy(
+        M, K, N, torch.float8_e4m3fn, mx_recipe_name
     )
-
-    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16)
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
-    fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
+    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
     print()

     headers = [
@@ -217,6 +262,9 @@ def run(
         # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple
         # we don't break them out and don't have a roofline for them.
         "b_fp8_e2e_spdp",
+        # how well benchmarked gemms match roofline predicted gemms
+        "rb_bf16_gemm_ratio",
+        "rb_fp8_gemm_ratio",
     ]
     results = []

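With the new arguments in place, a sweep over power-of-two shapes might be driven as follows (a sketch: the output path is illustrative, `"tensorwise"` is the default float8 recipe selected above, and `mx_recipe_name` would be passed instead of `float8_recipe_name`, never both, to exercise the MX path):

```python
# roofline vs. observed benchmark for the default tensorwise float8 recipe
run(
    outfile="float8_roofline_results.csv",  # illustrative path
    do_benchmarks=True,
    shape_gen_name="pow2",
    float8_recipe_name="tensorwise",
    enable_fusion_modeling=False,
)
```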
@@ -237,43 +285,72 @@

         # if enabled, also measure observed gemm time
         b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
+        rb_bf16_gemm_ratio = -1
+        rb_fp8_gemm_ratio = -1
+
         if do_benchmarks:
+            # TODO(future): make the bf16 gemm times exactly match the e2e
+            # benchmarks, there is a slight deviation, probably related to gemm
+            # operand memory formats/transpositions below not exactly matching
+            # what PyTorch core is doing for `torch.mm`
+            # input @ weight_t = output
             bf16_g1, f8_g1 = get_gemm_times(
-                M_val, K_val, N_val, True, gemm_cache_filename
+                M_val, K_val, N_val, True, "row_major:col_major", gemm_cache_filename
             )
+            # grad_output @ weight = grad_input
             bf16_g2, f8_g2 = get_gemm_times(
-                M_val, N_val, K_val, False, gemm_cache_filename
+                M_val, N_val, K_val, False, "row_major:row_major", gemm_cache_filename
             )
+            # input_t @ grad_output = grad_weight
             bf16_g3, f8_g3 = get_gemm_times(
-                K_val, M_val, N_val, False, gemm_cache_filename
+                K_val, M_val, N_val, False, "col_major:row_major", gemm_cache_filename
             )
             b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3
             b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s
+            rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s

         # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
         r_fp8_ovhd_time_s = float(
-            fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )

         b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0
         if do_benchmarks:
             # create the model
-            m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            if enable_fusion_modeling:
+                m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16()
+            else:
+                m_orig = (
+                    nn.Sequential(nn.Linear(K_val, N_val, bias=False)).cuda().bfloat16()
+                )
             x = torch.randn(
                 M_val, K_val, dtype=torch.bfloat16, device="cuda"
             ).requires_grad_()

+            # get the gradient of the right shape
+            grad_output = torch.randn(M_val, N_val, dtype=torch.bfloat16, device="cuda")
+
             # get the bf16 gpu kernel time
             torch._dynamo.reset()
             m_bf16 = torch.compile(copy.deepcopy(m_orig))
-            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x)
+            b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, grad_output)

             # get the float8 dynamic scaling gpu kernel time

             torch._dynamo.reset()
-            m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
+            if float8_recipe_name is not None:
+                config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
+                m_fp8_dyn = convert_to_float8_training(
+                    copy.deepcopy(m_orig), config=config
+                )
+            else:
+                assert mx_recipe_name is not None
+                config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+                m_fp8_dyn = copy.deepcopy(m_orig)
+                swap_linear_with_mx_linear(m_fp8_dyn, config=config)
             m_fp8_dyn = torch.compile(m_fp8_dyn)
-            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x)
+            b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)

         results.append(
             [
@@ -295,6 +372,9 @@ def run(
                 b_bf16_e2e_time_s,
                 b_fp8_e2e_time_s,
                 b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20),
+                # gemm ratios
+                rb_bf16_gemm_ratio,
+                rb_fp8_gemm_ratio,
             ]
         )

382
0 commit comments