add a benchmark for casting a tensor to MX across dim0 and dim1 #1787

Merged
merged 1 commit on Feb 26, 2025
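For context on what the new cast_only_dim0_dim1 mode measures: MX casts compute one shared scale per fixed-size block of consecutive elements along the last, contiguous dimension of the tensor, so producing the "dim1" variant of the same tensor requires a transpose plus a .contiguous() copy before casting, which is exactly what the new helper in the diff below does. The following is a minimal pure-PyTorch sketch of that layout difference; the amax-based block scale and the block size of 32 are illustrative assumptions, not the actual MXTensor internals.

import torch

BLOCK_SIZE = 32  # illustrative; not taken from this PR's config

def blockwise_amax(x: torch.Tensor, block_size: int = BLOCK_SIZE) -> torch.Tensor:
    # One value per `block_size` contiguous elements of the last dimension,
    # mirroring how the benchmark reshapes scales to (m, k // block_size).
    m, k = x.shape
    return x.reshape(m, k // block_size, block_size).abs().amax(dim=-1)

x_hp = torch.randn(256, 512)

# "dim0" cast (as named in the diff): blocks follow the existing memory layout.
scales_dim0 = blockwise_amax(x_hp)                   # shape (256, 16)

# "dim1" cast: transpose and materialize a contiguous copy first, so blocks run
# along the other logical dimension; this extra copy is part of what the new
# benchmark mode is meant to capture.
scales_dim1 = blockwise_amax(x_hp.t().contiguous())  # shape (512, 8)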
benchmarks/float8/profile_lowp_training.py: 25 additions & 1 deletion
@@ -306,8 +306,9 @@ def main(
             "fwd",
             "cast_only",
             "cast_with_to_blocked",
+            "cast_only_dim0_dim1",
         )
-    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`"
+    ), "mode_filter must be one of `fwd_bwd`, `fwd`, `cast_only`, `cast_with_to_blocked`, `cast_only_dim0_dim1`"
     if mode_filter == "cast_only":
         assert experiment_filter == "lowp", "unsupported"

@@ -395,6 +396,23 @@ def cast_with_to_blocked(x_hp):
         scale_blocked = to_blocked(x_mx._scale_e8m0.reshape(m, k // config.block_size))
         return x_mx._data, scale_blocked
 
+    # this function is used for cast_only_dim0_dim1
+    def cast_only_dim0_dim1(x_hp):
+        x_hp_t_c = x_hp.t().contiguous()
+        x_mx_dim0 = MXTensor.to_mx(
+            x_hp,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        x_mx_dim1 = MXTensor.to_mx(
+            x_hp_t_c,
+            config.elem_dtype,
+            config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
+        )
+        return x_mx_dim0, x_mx_dim1
+
     print("m_ref", m_ref)
     print("m_lowp", m_lowp)
     print("input_tensor.shape", input_tensor.shape)
@@ -423,6 +441,11 @@ def lowp_forw_backward_wrapper(x):
         elif mode_filter == "cast_with_to_blocked":
             _input_tensor_mx, scale = cast_with_to_blocked(input_tensor)
             return
+        elif mode_filter == "cast_only_dim0_dim1":
+            _input_tensor_mx_dim0, _input_tensor_mx_dim1 = cast_only_dim0_dim1(
+                input_tensor,
+            )
+            return
 
         if enable_activation_checkpointing:
             out = checkpoint(m_lowp, x, use_reentrant=False, context_fn=context_fn)
@@ -437,6 +460,7 @@ def lowp_forw_backward_wrapper(x):
         m_lowp = torch.compile(m_lowp, fullgraph=True)
         to_mx_func = torch.compile(to_mx_func, fullgraph=True)
         cast_with_to_blocked = torch.compile(cast_with_to_blocked, fullgraph=True)
+        cast_only_dim0_dim1 = torch.compile(cast_only_dim0_dim1, fullgraph=True)
 
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
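As a usage sketch, the pair of casts produced by cast_only_dim0_dim1 can be reproduced and inspected outside the benchmark. This is hedged: the MXTensor import path and the torch.float8_e4m3fn element dtype are assumptions based on torchao's mx_formats prototype rather than anything this diff pins down, and gemm_kernel_choice is left at its default here.

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor  # assumed import path

m, k, block_size = 8192, 4096, 32
x_hp = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)

# dim0 cast: scales cover blocks along the last dimension of x_hp.
x_mx_dim0 = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, block_size)
print(x_mx_dim0._scale_e8m0.reshape(m, k // block_size).shape)  # (8192, 128)

# dim1 cast: transpose + contiguous copy first, as cast_only_dim0_dim1 does.
x_mx_dim1 = MXTensor.to_mx(x_hp.t().contiguous(), torch.float8_e4m3fn, block_size)
print(x_mx_dim1._scale_e8m0.reshape(k, m // block_size).shape)  # (4096, 256)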