
Commit 11e3d23

add a benchmark for casting a tensor to MX across dim0 and dim1
Summary: casting a tensor to MX across dim0 and dim1 is useful for training; this extracts it into a benchmark so we can optimize it.

Test Plan:

```
TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_only_dim0_dim1

// output: https://gist.github.com/vkuzo/a4e13bac7fc8ca3af10bfd5483b85b33
// currently we see two kernels, one per dim
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: d8745ba
ghstack-comment-id: 2686344197
Pull Request resolved: #1787
1 parent 8d110bf commit 11e3d23

File tree

1 file changed: +25 −1 lines changed

benchmarks/float8/profile_lowp_training.py

Lines changed: 25 additions & 1 deletion
```diff
@@ -306,8 +306,9 @@ def main(
             "fwd",
             "cast_only",
             "cast_with_to_blocked",
+            "cast_only_dim0_dim1",
         )
-    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`"
+    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`, `cast_only_dim0_dim1`"
     if mode_filter == "cast_only":
         assert experiment_filter == "lowp", "unsupported"
 
@@ -395,6 +396,23 @@ def cast_with_to_blocked(x_hp):
         scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
         return x_mx._data, scale_blocked
 
+    # this function is used for cast_only_dim0_dim1
+    def cast_only_dim0_dim1(x_hp):
+        x_hp_t_c = x_hp.t().contiguous()
+        x_mx_dim0 = MXTensor.to_mx(
+            x_hp,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        x_mx_dim1 = MXTensor.to_mx(
+            x_hp_t_c,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        return x_mx_dim0, x_mx_dim1
+
     print("m_ref", m_ref)
     print("m_lowp", m_lowp)
     print("input_tensor.shape", input_tensor.shape)
@@ -423,6 +441,11 @@ def lowp_forw_backward_wrapper(x):
         elif mode_filter == "cast_with_to_blocked":
             _input_tensor_mx, scale = cast_with_to_blocked(input_tensor)
             return
+        elif mode_filter == "cast_only_dim0_dim1":
+            _input_tensor_mx_dim0, _input_tensor_mx_dim1 = cast_only_dim0_dim1(
+                input_tensor,
+            )
+            return
 
     if enable_activation_checkpointing:
         out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
@@ -437,6 +460,7 @@ def lowp_forw_backward_wrapper(x):
         m_lowp = torch.compile(m_lowp, fullgraph=True)
         to_mx_func = torch.compile(to_mx_func, fullgraph=True)
         cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)
+        cast_only_dim0_dim1 = torch.compile(cast_only_dim0_dim1, fullgraph=True)
 
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
```
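
For context, a minimal standalone sketch of the pattern the new `cast_only_dim0_dim1` mode measures. The import path, element dtype, block size, and tensor shape below are illustrative assumptions, not taken from this commit:

```python
import torch

# assumed import path for the prototype MX tensor (not shown in this diff)
from torchao.prototype.mx_formats.mx_tensor import MXTensor


def cast_only_dim0_dim1(x_hp):
    # first cast: MX blocks/scales run along x_hp's contiguous dimension
    x_mx_dim0 = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, 32)
    # second cast: transpose + contiguous so the blocks run along the other
    # dimension; per the test plan above, this currently compiles to a
    # second, separate kernel
    x_mx_dim1 = MXTensor.to_mx(x_hp.t().contiguous(), torch.float8_e4m3fn, 32)
    return x_mx_dim0, x_mx_dim1


# compile as the benchmark does; running with TORCH_LOGS=output_code prints
# the generated Triton kernels for inspection
compiled_cast = torch.compile(cast_only_dim0_dim1, fullgraph=True)
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
x_mx_dim0, x_mx_dim1 = compiled_cast(x)
```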
