
Commit 0e03655

[wip] triton kernel to cast to mx and write in col-major
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: 4c77bd3
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
Parent commit: ab3792e

File tree: 3 files changed, +558 -1 lines changed
(Only the benchmark and test diffs are reproduced below; the remaining changed file is torchao/prototype/mx_formats/custom_cast.py, which both of them import from.)
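
For context before the diffs: "cast to mx" in the commit title refers to MX-style block quantization, where each group of 32 contiguous elements shares one power-of-two (e8m0) scale and the payload is fp8 e4m3. Below is a minimal sketch of the idea, assuming a simplified floor(log2) scale choice; torchao's to_mx handles rounding and saturation more carefully, and naive_mx_block_cast is a hypothetical helper, not part of this commit:

```python
# Minimal MX block-cast sketch (hypothetical helper, NOT torchao's to_mx).
# Each block of 32 elements shares one power-of-two scale; the "- 8" shifts
# the block max toward the top of the e4m3 range (largest e4m3 exponent is 8).
import torch

def naive_mx_block_cast(x: torch.Tensor, block_size: int = 32):
    blocks = x.float().reshape(-1, block_size)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=2**-126)
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 8)
    data = (blocks / scale).to(torch.float8_e4m3fn)  # 1 byte per element
    return data.reshape(x.shape), scale.reshape(-1)  # 1 scale per block

data, scale = naive_mx_block_cast(torch.randn(4, 64, dtype=torch.bfloat16))
```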

benchmarks/mx_formats/cast_bench.py

Lines changed: 45 additions & 1 deletion
@@ -5,6 +5,9 @@
 import triton
 from torch._inductor.utils import do_bench_using_profiling

+from torchao.prototype.mx_formats.custom_cast import (
+    to_mxfp8_dim1,
+)
 from torchao.prototype.mx_formats.mx_tensor import to_mx

 torch.manual_seed(0)
@@ -49,6 +52,12 @@ def to_mx_dim0_reference(x_hp, block_size):
     return data_d0, scale_d0


+def to_mx_dim1_reference(x_hp, block_size):
+    x_hp = x_hp.t().contiguous()
+    scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    return data_d1.t(), scale_d1
+
+
 def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
     """Thin wrapper around do_bench_using_profiling"""
     no_args = lambda: func(*args, **kwargs)
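
The dim1 reference above works by transposition: make the original dim0 contiguous, reuse the same row-wise to_mx cast the dim0 path uses, then transpose the fp8 data back. The returned (M, K) tensor is therefore a view over column-major storage, which is the layout the commit's triton kernel is meant to write directly. A small stride check illustrating just the layout (shapes hypothetical):

```python
# Layout-only sketch of the transpose trick in to_mx_dim1_reference;
# no MX math here, just strides.
import torch

M, K = 1024, 2048
x = torch.randn(M, K, dtype=torch.bfloat16)
x_t = x.t().contiguous()  # (K, M): original dim0 becomes the fast axis
out_view = x_t.t()        # (M, K) view, storage unchanged
assert out_view.stride() == (1, M)  # column-major: "write in col-major"
```
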
@@ -67,7 +76,7 @@ def run(
     print(f"torch version: {torch.__version__}")
     print(f"triton version: {triton.__version__}")
     print(f"mode: {mode}")
-    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx")
+    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")

     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
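
Assuming run() is exposed to the command line the way other torchao benchmark scripts do it (only M, K, and mode are visible in this diff; the call below is hypothetical), the two new modes would be exercised like:

```python
# Hypothetical invocation; the import path assumes the working directory is
# benchmarks/mx_formats, and the M/K values are made up.
from cast_bench import run

run(M=16384, K=16384, mode="dim1_mx")         # torch.compile'd reference
run(M=16384, K=16384, mode="dim1_mx_triton")  # handwritten triton kernel
```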

@@ -144,6 +153,41 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

+    elif mode == "dim1_mx":
+        to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
+        y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.uint8
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim1_mx_triton":
+        y_d1, s_d1 = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     else:
         raise AssertionError(f"unknown mode {mode}")
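
Both new branches use the same bandwidth accounting: each input element is read once as bf16 (2 bytes), and each output element, fp8 data plus a one-byte scale (uint8 or e8m0, both one byte), is written once; bps is then total bytes over measured seconds. A worked instance with made-up numbers:

```python
# Worked instance of the bps math above; shape and latency are hypothetical.
M, K, BLOCK_SIZE = 16384, 16384, 32
bytes_per_el_bf16, bytes_per_el_fp8 = 2, 1

bytes_r = M * K * bytes_per_el_bf16                         # 536_870_912
bytes_w = (M * K + M * K // BLOCK_SIZE) * bytes_per_el_fp8  # 276_824_064
time_us = 500.0                                             # pretend latency
bps = (bytes_r + bytes_w) / (time_us / 1e6)                 # ~1.63e12 B/s
```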

test/prototype/mx_formats/test_custom_cast.py

Lines changed: 16 additions & 0 deletions
@@ -26,6 +26,9 @@
     get_bits,
     pack_uint4,
     pack_uint6,
+    # TODO(before land): better name?
+    to_mxfp8_dim1,
+    to_mxfp8_dim1_reference,
     triton_f4_to_bf16,
     triton_f6_e2m3_to_bf16,
     triton_f6_e3m2_to_bf16,
@@ -444,3 +447,16 @@ def test_fp6_e3m2_pack_unpack():
         torch.float32
     )
     assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
+
+
+# TODO(before land): skip before sm89
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+def test_triton_mxfp8_dim1():
+    M, K = 1024, 2048
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x_mx_ref, x_s_ref = to_mxfp8_dim1_reference(x, block_size=32)
+    x_mx_t, x_s_t = to_mxfp8_dim1(x, inner_block_size=32)
+    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
+    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
+    print("done")
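
Note the tolerances: rtol=0, atol=0 demands that the triton kernel match the reference bit for bit, for both the fp8 data and the scales. On the scale dtype itself: the benchmark's reference path asserts uint8 scales while the triton path asserts float8_e8m0fnu; both are one byte, and the uint8 value is just the raw biased exponent. A small illustration, assuming a PyTorch build that ships torch.float8_e8m0fnu (2.7+); the values are made up:

```python
# e8m0 is 8 exponent bits and no mantissa: each byte encodes a power of two
# with bias 127, so raw uint8 bits and float8_e8m0fnu are interchangeable views.
import torch

raw = torch.tensor([127, 128, 126], dtype=torch.uint8)  # 2**0, 2**1, 2**-1
scales = raw.view(torch.float8_e8m0fnu)                 # reinterpret the bits
print(scales.to(torch.float32))                         # tensor([1., 2., 0.5])
```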
