@@ -121,15 +121,17 @@ def get_gpu_kernel_time(m, x, grad_output):


 def get_gemm_times(
-    M,
-    K,
-    N,
-    fast_accum,
-    bf16_memory_formats,
-    float8_recipe_name,
-    mx_recipe_name,
+    gemm_role: str,
+    M: int,
+    K: int,
+    N: int,
+    fast_accum: bool,
+    bf16_memory_formats: str,
+    float8_recipe_name: Optional[str],
+    mx_recipe_name: Optional[str],
     cache_filename=None,
 ):
+    assert gemm_role in ("output", "grad_input", "grad_weight"), "unsupported"
     assert bf16_memory_formats in (
         "row_major:col_major",
         "row_major:row_major",
@@ -139,6 +141,7 @@ def get_gemm_times(
     # Note: this is definitely not the best way to build a cache,
     # but it will do for now.
     if cache_filename is not None:
+        assert False, "TODO retest this for new arguments"
         if os.path.isfile(cache_filename):
             # cache already exists, use it
             with open(cache_filename, "r") as f:
@@ -169,24 +172,27 @@ def get_gemm_times(
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

     # f8 time
-    d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
-    A = torch.zeros(M, K, device=device, dtype=d1)
-    B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
-    if float8_recipe_name == "tensorwise":
-        scale_a = torch.tensor([1.0], device=device)
-        scale_b = torch.tensor([1.0], device=device)
-    elif float8_recipe_name == "rowwise":
-        scale_a = torch.ones(M, 1, device=device)
-        scale_b = torch.ones(1, N, device=device)
+    if float8_recipe_name == "rowwise_with_gw_hp" and gemm_role == "grad_weight":
+        f8_time_s = bf16_time_s
     else:
-        assert False, "TODO add mx gemm here"
+        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16
+        A = torch.zeros(M, K, device=device, dtype=d1)
+        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        if float8_recipe_name == "tensorwise":
+            scale_a = torch.tensor([1.0], device=device)
+            scale_b = torch.tensor([1.0], device=device)
+        elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"):
+            scale_a = torch.ones(M, 1, device=device)
+            scale_b = torch.ones(1, N, device=device)
+        else:
+            assert False, "TODO add mx gemm here"

-    def do_matmul(A, B):
-        return torch._scaled_mm(
-            A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-        )
+        def do_matmul(A, B):
+            return torch._scaled_mm(
+                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+            )

-    f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+        f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

     # save to cache if needed
     if cache_filename is not None:
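
Note: for reference, here is a minimal standalone sketch of the scaled gemm that `do_matmul` times in the hunk above, following the rowwise branch. The shapes, the `"cuda"` device string, and the availability of a float8-capable GPU are assumptions for illustration only; the operand layout and scale shapes mirror the diff.

```python
import torch

# Illustrative shapes; assumes a CUDA GPU with float8 support.
M, K, N = 4096, 4096, 4096
device = "cuda"

# Same operand layout as in the diff: A row-major, B column-major.
A = torch.zeros(M, K, device=device, dtype=torch.float8_e4m3fn)
B = torch.zeros(K, N, device=device, dtype=torch.float8_e4m3fn).t().contiguous().t()

# Rowwise scales: one scale per row of A, one per column of B (fp32 by default).
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)

out = torch._scaled_mm(
    A, B, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
)
```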
@@ -239,9 +245,9 @@ def run(
         mx_recipe_name,
         enable_fusion_modeling,
     )
-    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None)
+    bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None, None)
     fp8_gemm_time_sympy = get_gemm_time_sympy(
-        M, K, N, torch.float8_e4m3fn, mx_recipe_name
+        M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name
     )
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
@@ -305,6 +311,7 @@ def run(
             # what PyTorch core is doing for `torch.mm`
             # input @ weight_t = output
             bf16_g1, f8_g1 = get_gemm_times(
+                "output",
                 M_val,
                 K_val,
                 N_val,
@@ -316,6 +323,7 @@ def run(
             )
             # grad_output @ weight = grad_input
             bf16_g2, f8_g2 = get_gemm_times(
+                "grad_input",
                 M_val,
                 N_val,
                 K_val,
@@ -327,6 +335,7 @@ def run(
             )
             # input_t @ grad_output = grad_weight
             bf16_g3, f8_g3 = get_gemm_times(
+                "grad_weight",
                 K_val,
                 M_val,
                 N_val,
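
The three call sites above pass the new leading `gemm_role` argument for the output, grad_input, and grad_weight gemms of a linear layer. A hypothetical call with the full new signature might look like the sketch below; the concrete M/K/N values are illustrative, the remaining arguments follow the signature introduced in the first hunk, and the expectation about `rowwise_with_gw_hp` follows from the early-return branch added above.

```python
# Hypothetical example call (values are illustrative, not from the PR):
bf16_s, f8_s = get_gemm_times(
    "grad_weight",           # gemm_role
    4096,                    # M of this gemm (the layer's K_val: input_t @ grad_output)
    8192,                    # K of this gemm (the layer's M_val)
    4096,                    # N of this gemm (the layer's N_val)
    True,                    # fast_accum
    "row_major:col_major",   # bf16_memory_formats
    "rowwise_with_gw_hp",    # float8_recipe_name
    None,                    # mx_recipe_name
)
# Under rowwise_with_gw_hp the grad_weight gemm stays in bf16,
# so f8_s is reported as equal to bf16_s.
```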