    ScalingType,
    CastConfig,
)
+from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName


class LNLinearSigmoid(torch.nn.Module):
@@ -129,6 +130,8 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None):
        else:
            # cache does not exist yet, create it
            cache = dict()
+    else:
+        cache = dict()
    key = f"{M},{K},{N},{fast_accum}"
    if key in cache:
        return cache[key]
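
For context, the cache handled here is a flat JSON dict keyed by the f"{M},{K},{N},{fast_accum}" string above; after this change each value holds three timings (bf16, fp8 tensorwise, fp8 axiswise), written a few lines below. A minimal sketch of the round trip, with a hypothetical path and placeholder numbers, not measurements:

    import json

    # Placeholder entry: [bf16_time_s, f8_time_s, f8_axs_time_s] in seconds (made up).
    cache = {"4096,4096,4096,True": [1.0e-3, 5.0e-4, 5.5e-4]}

    with open("/tmp/gemm_times_cache.json", "w") as f:  # hypothetical path
        json.dump(cache, f)

    with open("/tmp/gemm_times_cache.json") as f:
        cache = json.load(f)

    bf16_time_s, f8_time_s, f8_axs_time_s = cache["4096,4096,4096,True"]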
@@ -153,13 +156,18 @@ def do_matmul(A, B):
        )
    f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

+    scale_a = torch.ones(M, 1, device=device)
+    scale_b = torch.ones(1, N, device=device)
+    fast_accum = True  # for axiswise
+    f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+
    # save to cache if needed
    if cache_filename is not None:
-        cache[key] = [bf16_time_s, f8_time_s]
+        cache[key] = [bf16_time_s, f8_time_s, f8_axs_time_s]
        with open(cache_filename, 'w') as f:
            json.dump(cache, f)

-    return bf16_time_s, f8_time_s
+    return bf16_time_s, f8_time_s, f8_axs_time_s

def run(
    outfile: str,
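
Above, the axiswise variant replaces the tensorwise scalar scales with scale_a of shape (M, 1) and scale_b of shape (1, N) before timing the matmul again. Roughly, such a rowwise-scaled fp8 gemm looks like the sketch below; this is a minimal, hypothetical example (it assumes a recent PyTorch with rowwise-scale support in torch._scaled_mm and a float8-capable GPU, with illustrative shapes), not the benchmark's exact do_matmul helper:

    import torch

    M, K, N = 1024, 2048, 4096
    device = "cuda"

    # fp8 operands: A row-major, B column-major, as required by torch._scaled_mm.
    A = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
    B = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()

    # Axiswise (rowwise) scales: one scale per row of A, one per column of B.
    scale_a = torch.ones(M, 1, device=device)
    scale_b = torch.ones(1, N, device=device)

    out = torch._scaled_mm(
        A, B, scale_a, scale_b,
        out_dtype=torch.bfloat16,
        use_fast_accum=True,  # matches fast_accum = True above
    )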
@@ -231,13 +239,15 @@ def run(
    headers = [
        'fwd_M', 'fwd_K', 'fwd_N',
        # gemm microbenchmarks
-        'bf16_gemm_s', 'fp8_gemm_s',
+        'bf16_gemm_s', 'fp8_gemm_s', 'fp8_axs_gemm_time_s',
        # roofline memory overhead estimates
        'fp8_oh_dyn_limit', 'fp8_oh_dyn_nolimit',
        'fp8_oh_del_limit', 'fp8_oh_del_nolimit',
        # actual e2e measurements
-        'bf16_e2e_s', 'fp8_dyn_e2e_s', 'fp8_del_e2e_s',
-        'fp8_dyn_speedup', 'fp8_del_speedup',
+        'bf16_s', 'fp8_dyn_s', 'fp8_del_s', 'fp8_dyn_axs_s',
+        # 'fp8_lw_s',
+        'fp8_dyn_sp', 'fp8_del_sp', 'fp8_dyn_axs_sp',
+        # 'fp8_lw_sp',
    ]
    results = []
@@ -248,15 +258,18 @@ def run(
            break

        if gemm_time_strategy == "benchmarks":
-            bf16_g1, f8_g1 = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
-            bf16_g2, f8_g2 = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
-            bf16_g3, f8_g3 = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
+            bf16_g1, f8_g1, f8_g1_axs = get_gemm_times(M_val, K_val, N_val, True, gemm_cache_filename)
+            bf16_g2, f8_g2, f8_g2_axs = get_gemm_times(M_val, N_val, K_val, False, gemm_cache_filename)
+            bf16_g3, f8_g3, f8_g3_axs = get_gemm_times(K_val, M_val, N_val, False, gemm_cache_filename)
            bf16_time_val = bf16_g1 + bf16_g2 + bf16_g3
            fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3
+            fp8_axs_gemm_time_s = f8_g1_axs + f8_g2_axs + f8_g3_axs
        else:
            assert gemm_time_strategy == "roofline", "unsupported"
            bf16_time_val = bf16_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
            fp8_gemm_time_s = fp8_gemm_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+            # for now, assume axiswise gemm is similar to tensorwise
+            fp8_axs_gemm_time_s = fp8_gemm_time_s

        fp8_mem_time_dyn_limit_s = \
            fp8_mem_time_sympy_dyn_limit.subs(M, M_val).subs(K, K_val).subs(N, N_val)
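
As a sanity check on why the three gemm times are simply summed above: a linear layer's forward plus backward is three gemms of equal FLOP count, one per get_gemm_times call site (my reading of the three shape triples, not stated in the diff). A short worked sketch with illustrative shapes:

    # Each (m, k) x (k, n) gemm costs 2 * m * k * n FLOPs.
    M_val, K_val, N_val = 4096, 4096, 4096

    flops_fwd = 2 * M_val * K_val * N_val          # output: (M, K) x (K, N)
    flops_grad_input = 2 * M_val * N_val * K_val   # grad_input: (M, N) x (N, K)
    flops_grad_weight = 2 * K_val * M_val * N_val  # grad_weight: (K, M) x (M, N)

    total = flops_fwd + flops_grad_input + flops_grad_weight
    assert total == 6 * M_val * K_val * N_val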
@@ -291,23 +304,43 @@ def run(
            cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
            cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
        )
-        m_fp8_del = convert_to_float8_training(m_orig)
+        m_fp8_del = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
        m_fp8_del = torch.compile(m_fp8_del)
        fp8_del_time_actual_s = get_gpu_kernel_time(m_fp8_del, x)

+        # get the float8 dynamic axiswise scaling gpu kernel time
+        torch._dynamo.reset()
+        config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
+        m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
+        m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
+        fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
+
+        # get the lw recipe scaling gpu kernel time
+        # TODO(future PR): enable below once basic performance issues
+        # are fixed
+        # torch._dynamo.reset()
+        # config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
+        # m_fp8_lw = convert_to_float8_training(m_orig, config=config)
+        # m_fp8_lw = torch.compile(m_fp8_lw)
+        # fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
+
        results.append([
            M_val, K_val, N_val,
            # gemm microbenchmarks
-            bf16_time_val, fp8_gemm_time_s,
+            bf16_time_val, fp8_gemm_time_s, fp8_axs_gemm_time_s,
            # roofline overhead estimates
            fp8_mem_time_dyn_limit_s,
            fp8_mem_time_dyn_nolimit_s,
            fp8_mem_time_del_limit_s,
            fp8_mem_time_del_nolimit_s,
            # e2e numbers
            bf16_time_actual_s, fp8_dyn_time_actual_s, fp8_del_time_actual_s,
+            fp8_dyn_axs_time_actual_s,
+            # fp8_lw_time_actual_s,
            bf16_time_actual_s / fp8_dyn_time_actual_s,
            bf16_time_actual_s / fp8_del_time_actual_s,
+            bf16_time_actual_s / fp8_dyn_axs_time_actual_s,
+            # bf16_time_actual_s / fp8_lw_time_actual_s,
        ])

    df = pd.DataFrame(results, columns=headers)
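
For reference, the new axiswise measurement added above reduces to the pattern below: build a config from the ALL_AXISWISE recipe, swap the linears, compile, and run. This is a minimal, illustrative sketch using the torchao.float8 APIs imported in this diff; the toy model and shapes are assumptions, not the benchmark's LNLinearSigmoid setup.

    import copy
    import torch
    from torchao.float8 import convert_to_float8_training
    from torchao.float8.config import recipe_name_to_linear_config, Float8LinearRecipeName

    # Toy model standing in for the benchmark's m_orig.
    m_orig = torch.nn.Sequential(torch.nn.Linear(2048, 4096, bias=False)).cuda().bfloat16()
    x = torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16)

    # Dynamic axiswise (rowwise) scaling for input, weight, and grad_output.
    config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
    m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
    m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)

    y = m_fp8_dyn_axs(x)
    y.sum().backward()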