@@ -306,8 +306,9 @@ def main(
             "fwd",
             "cast_only",
             "cast_with_to_blocked",
+            "cast_only_dim0_dim1",
         )
-    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`"
+    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`, `cast_only_dim0_dim1`"
     if mode_filter == "cast_only":
         assert experiment_filter == "lowp", "unsupported"
 
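The new mode plugs into the same CLI surface as the existing ones. A hypothetical invocation for reference (the script path and argument style are assumptions, not taken from this diff):

    python benchmarks/float8/profile_lowp_training.py --experiment_filter lowp --mode_filter cast_only_dim0_dim1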
@@ -395,6 +396,23 @@ def cast_with_to_blocked(x_hp):
         scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
         return x_mx._data, scale_blocked
 
+    # cast x_hp to MX along dim0 and dim1 (dim1 via a transposed contiguous copy); used by the cast_only_dim0_dim1 mode
+    def cast_only_dim0_dim1(x_hp):
+        x_hp_t_c = x_hp.t().contiguous()
+        x_mx_dim0 = MXTensor.to_mx(
+            x_hp,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        x_mx_dim1 = MXTensor.to_mx(
+            x_hp_t_c,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        return x_mx_dim0, x_mx_dim1
+
     print("m_ref", m_ref)
     print("m_lowp", m_lowp)
     print("input_tensor.shape", input_tensor.shape)
@@ -423,6 +441,11 @@ def lowp_forw_backward_wrapper(x):
         elif mode_filter == "cast_with_to_blocked":
             _input_tensor_mx, scale = cast_with_to_blocked(input_tensor)
             return
+        elif mode_filter == "cast_only_dim0_dim1":
+            _input_tensor_mx_dim0, _input_tensor_mx_dim1 = cast_only_dim0_dim1(
+                input_tensor,
+            )
+            return
 
         if enable_activation_checkpointing:
             out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
@@ -437,6 +460,7 @@ def lowp_forw_backward_wrapper(x):
         m_lowp = torch.compile(m_lowp, fullgraph=True)
         to_mx_func = torch.compile(to_mx_func, fullgraph=True)
         cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)
+        cast_only_dim0_dim1 = torch.compile(cast_only_dim0_dim1, fullgraph=True)
 
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
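Like the other cast helpers, the new function is wrapped in torch.compile with fullgraph=True, which raises on any graph break instead of silently falling back to eager, keeping the profiled region to compiled kernels only. A minimal sketch of the same pattern with a hypothetical stand-in function (not the benchmark's):

    import torch

    def cast_pair(x):
        # stand-in for cast_only_dim0_dim1: two casts Inductor can codegen together
        return x.to(torch.float8_e4m3fn), x.t().contiguous().to(torch.float8_e4m3fn)

    # fullgraph=True fails fast on graph breaks rather than splitting the graph
    cast_pair = torch.compile(cast_pair, fullgraph=True)
    out_dim0, out_dim1 = cast_pair(torch.randn(256, 256))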