mx: triton kernel to cast to mx and write in col-major #1932

Merged: 13 commits, merged on Mar 26, 2025
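
This PR adds triton_to_mxfp8_dim1, a Triton kernel that casts a bf16 tensor to MX fp8 along dim1 and writes the cast data in column-major layout. A minimal usage sketch based on the benchmark and test changes below; the dtype assertions come from the benchmark, everything else (shapes, setup) is illustrative rather than verified against the kernel:

import torch
from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1

# Requires triton and a CUDA GPU with compute capability 8.9+ (see the test skip markers below).
x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
# Cast along dim1 in blocks of 32 elements; per the benchmark assertions the
# data comes back as float8_e4m3fn and the scales as float8_e8m0fnu.
data, scale = triton_to_mxfp8_dim1(x, inner_block_size=32)
assert data.dtype == torch.float8_e4m3fn
assert scale.dtype == torch.float8_e8m0fnu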
Changes from all commits

46 changes: 45 additions & 1 deletion benchmarks/mx_formats/cast_bench.py
@@ -5,6 +5,9 @@
import triton
from torch._inductor.utils import do_bench_using_profiling

from torchao.prototype.mx_formats.custom_cast import (
    triton_to_mxfp8_dim1,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

torch.manual_seed(0)
@@ -49,6 +52,12 @@ def to_mx_dim0_reference(x_hp, block_size):
    return data_d0, scale_d0


def to_mx_dim1_reference(x_hp, block_size):
    # Cast along dim1 by transposing, letting to_mx scale blocks of
    # `block_size` elements along the last dim, then transposing the data
    # back so it keeps the original shape with a column-major layout.
    x_hp = x_hp.t().contiguous()
    scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
    return data_d1.t(), scale_d1


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
"""Thin wrapper around do_bench_using_profiling"""
no_args = lambda: func(*args, **kwargs)
@@ -67,7 +76,7 @@ def run(
print(f"torch version: {torch.__version__}")
print(f"triton version: {triton.__version__}")
print(f"mode: {mode}")
assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx")
assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

@@ -144,6 +153,41 @@ def run(
        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
        bps = (bytes_r + bytes_w) / (time_us / 1e6)

    elif mode == "dim1_mx":
        to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
        y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)

        for _ in range(2):
            __ = to_mx_dim1_reference_c(x, BLOCK_SIZE)
        time_us = benchmark_cuda_function_in_microseconds(
            lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE),
            x,
            BLOCK_SIZE,
        )

        assert y_d1.dtype == torch.float8_e4m3fn
        assert s_d1.dtype == torch.uint8
        bytes_r = x.numel() * bytes_per_el_bf16
        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
        bps = (bytes_r + bytes_w) / (time_us / 1e6)

    elif mode == "dim1_mx_triton":
        y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)

        for _ in range(2):
            __ = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
        time_us = benchmark_cuda_function_in_microseconds(
            lambda x, b: triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
            x,
            BLOCK_SIZE,
        )

        assert y_d1.dtype == torch.float8_e4m3fn
        assert s_d1.dtype == torch.float8_e8m0fnu
        bytes_r = x.numel() * bytes_per_el_bf16
        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
        bps = (bytes_r + bytes_w) / (time_us / 1e6)

    else:
        raise AssertionError(f"unknown mode {mode}")
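
To put the bps metric in perspective: for M = K = 16384, the bf16 input is about 0.54 GB to read, and the fp8 data plus one scale byte per 32 elements is about 0.28 GB to write, so a cast taking 1000 µs would report roughly 0.81 GB / 1e-3 s ≈ 810 GB/s. (Illustrative arithmetic from the formula above; the shapes are an example, not necessarily the script's defaults.)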

25 changes: 24 additions & 1 deletion test/prototype/mx_formats/test_custom_cast.py
@@ -29,6 +29,8 @@
    triton_f4_to_bf16,
    triton_f6_e2m3_to_bf16,
    triton_f6_e3m2_to_bf16,
    triton_to_mxfp8_dim1,
    triton_to_mxfp8_dim1_reference,
    unpack_uint4,
)
from torchao.prototype.mx_formats.fp_format_spec import (
@@ -42,7 +44,11 @@
    sem_vals_to_f32,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_8,
    is_sm_at_least_89,
    is_sm_at_least_100,
)

torch.manual_seed(0)

Expand Down Expand Up @@ -444,3 +450,20 @@ def test_fp6_e3m2_pack_unpack():
        torch.float32
    )
    assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_89(),
    reason="float8 in triton requires CUDA capability 8.9 or greater",
)
@pytest.mark.parametrize("M", (256, 2048))
@pytest.mark.parametrize("K", (256, 2048))
# @pytest.mark.parametrize("M", (256,))
# @pytest.mark.parametrize("K", (256,))
def test_triton_mxfp8_dim1(M, K):
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
    x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
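
Assuming a standard pytest setup, the new test can be run on its own with:

pytest test/prototype/mx_formats/test_custom_cast.py -k test_triton_mxfp8_dim1

(This needs triton and an sm89+ CUDA GPU per the skip markers above.)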