Commit cb08249
[wip] triton kernel to cast to mx across dim0 and dim1
Summary:

Test Plan:

```
python torchao/prototype/mx_formats/mx_dim0_dim1_cast.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 7235553
ghstack-comment-id: 2714865161
Pull Request resolved: #1869
1 parent be09c1d commit cb08249

File tree: 3 files changed, +691 -0 lines changed
Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Starting with https://github.com/vkuzo/pytorch_scripts/blob/main/mx_cast_poc/20250305_mx_dim0_dim1_cast.py
and making it nice.
"""

from typing import Callable

import fire
import torch
import triton
from torch._inductor.utils import do_bench_using_profiling

from torchao.prototype.mx_formats.custom_cast import (
    to_mxfp8_across_dim0_and_dim1,
    to_mxfp8_across_dim0_and_dim1_reference,
    to_mxfp8_dim0_reference,
    to_mxfp8_dim1,
    to_mxfp8_dim1_reference,
)
from torchao.testing.float8.roofline_utils import get_specs

torch.manual_seed(0)

bytes_per_el_bf16 = 2
bytes_per_el_fp8 = 1


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3


def compute_error(x, y):
    # SQNR in dB: 20 * log10(||x|| / ||x - y||)
    Ps = torch.linalg.norm(x)
    Pn = torch.linalg.norm(x - y)
    return 20 * torch.log10(Ps / Pn)


def run(
    M: int = 4096,
    K: int = 2048,
    BLOCK_SIZE: int = 32,
    check_accuracy: bool = True,
    mode: str = "dim0_and_dim1",
):
    print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"torch version: {torch.__version__}")
    print(f"triton version: {triton.__version__}")
    print(f"mode: {mode}")
    assert mode in ("dim0_and_dim1", "dim1"), "unsupported mode"

    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

    to_mxfp8_across_dim0_and_dim1_reference_c = torch.compile(
        to_mxfp8_across_dim0_and_dim1_reference
    )
    to_mxfp8_dim0_reference_c = torch.compile(to_mxfp8_dim0_reference)
    to_mxfp8_dim1_reference_c = torch.compile(to_mxfp8_dim1_reference)

    # reference implementation (plain PyTorch + torch.compile)
    if mode == "dim0_and_dim1":
        # TODO remove the mode here?
        x_d0, x_d1, scale_e8m0_d0, scale_e8m0_d1 = (
            to_mxfp8_across_dim0_and_dim1_reference_c(x, BLOCK_SIZE)
        )
    else:  # dim1
        x_d0, scale_e8m0_d0 = to_mxfp8_dim0_reference_c(x, BLOCK_SIZE)
        x_d1, scale_e8m0_d1 = to_mxfp8_dim1_reference_c(x, BLOCK_SIZE)

    # dequantize and check SQNR vs the original bf16 tensor
    x_d0, x_d1 = x_d0.bfloat16(), x_d1.bfloat16()
    scale_fp_d0 = scale_e8m0_d0.float()
    scale_fp_d1 = scale_e8m0_d1.float()
    x_d0_and_back = (x_d0.reshape(-1, BLOCK_SIZE) * scale_fp_d0).reshape(x_d0.shape)
    x_d1_and_back = (
        (x_d1.t().reshape(-1, BLOCK_SIZE) * scale_fp_d1).reshape(x_d1.t().shape).t()
    )

    sqnr_bf16_vs_dim0_ref = compute_error(x, x_d0_and_back)
    sqnr_bf16_vs_dim1_ref = compute_error(x, x_d1_and_back)
    print(
        f"bf16 vs normalized reference sqnrs: dim0 {sqnr_bf16_vs_dim0_ref}, dim1 {sqnr_bf16_vs_dim1_ref}"
    )
    assert (
        sqnr_bf16_vs_dim0_ref > 28 and sqnr_bf16_vs_dim1_ref > 28
    ), "reference mx numerics are incorrect"

    # triton kernel for dim1 only
    x_d1_only_t, scale_e8m0_d1_only_t = to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE)

    # triton kernel for dim0 and dim1
    x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t = to_mxfp8_across_dim0_and_dim1(
        x, tile_size=BLOCK_SIZE
    )
    x_d0_t, x_d1_t, x_d1_only_t = (
        x_d0_t.bfloat16(),
        x_d1_t.bfloat16(),
        x_d1_only_t.bfloat16(),
    )

    # ensure bitwise equivalence of outputs with reference
    if check_accuracy:
        torch.testing.assert_close(x_d0, x_d0_t, atol=0, rtol=0)
        torch.testing.assert_close(x_d1, x_d1_t, atol=0, rtol=0)
        torch.testing.assert_close(scale_e8m0_d0, scale_e8m0_d0_t, atol=0, rtol=0)
        torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_t, atol=0, rtol=0)
        torch.testing.assert_close(x_d1, x_d1_only_t, atol=0, rtol=0)
        torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_only_t, atol=0, rtol=0)
        print("normalized reference vs normalized triton are bitwise equivalent")
    else:
        print("accuracy checking skipped")

    # now, measure performance

    # warm up, then benchmark the compiled dim0 + dim1 reference
    for _ in range(2):
        __ = to_mxfp8_across_dim0_and_dim1_reference_c(x, BLOCK_SIZE)
    time_ref_dim0_dim1_compile_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_across_dim0_and_dim1_reference_c(x, b), x, BLOCK_SIZE
    )

    # warm up, then benchmark the compiled dim0-only reference
    for _ in range(2):
        __ = to_mxfp8_dim0_reference_c(x, BLOCK_SIZE)
    time_ref_dim0_compile_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_dim0_reference_c(x, b), x, BLOCK_SIZE
    )

    # warm up, then benchmark the compiled dim1-only reference
    for _ in range(2):
        __ = to_mxfp8_dim1_reference_c(x, BLOCK_SIZE)
    time_ref_dim1_compile_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_dim1_reference_c(x, b), x, BLOCK_SIZE
    )

    # warm up, then benchmark the triton dim1-only kernel
    for _ in range(2):
        __ = to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE)
    time_triton_dim1_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_dim1(x, row_tile_size=b),
        x,
        BLOCK_SIZE,
    )

    # warm up, then benchmark the triton dim0 + dim1 kernel
    for _ in range(2):
        __ = to_mxfp8_across_dim0_and_dim1(x, tile_size=BLOCK_SIZE)
    time_triton_dim0_dim1_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_across_dim0_and_dim1(x, tile_size=b),
        x,
        BLOCK_SIZE,
    )

    # calculate memory bandwidth
    peak_mem_bw = get_specs()["peak_mem_bw_bytes_sec"]

    # dim0 or dim1 kernel: read the bf16 input, write the fp8 data and e8m0 scales
    dim0_bytes_read = x.numel() * bytes_per_el_bf16
    dim0_bytes_written = (x_d0_t.numel() + scale_e8m0_d0_t.numel()) * bytes_per_el_fp8
    dim0_bytes_rw = dim0_bytes_read + dim0_bytes_written
    ref_dim0_bps = dim0_bytes_rw / (time_ref_dim0_compile_us / 1e6)
    ref_dim1_bps = dim0_bytes_rw / (time_ref_dim1_compile_us / 1e6)
    triton_dim1_bps = dim0_bytes_rw / (time_triton_dim1_us / 1e6)

    # triton dim0_dim1 kernel: read the bf16 input once, write both fp8 outputs and both scales
    triton_dim0_dim1_bytes_read = x.numel() * bytes_per_el_bf16
    triton_dim0_dim1_bytes_written = (
        sum(t.numel() for t in (x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t))
        * bytes_per_el_fp8
    )
    triton_dim0_dim1_bps = (
        triton_dim0_dim1_bytes_read + triton_dim0_dim1_bytes_written
    ) / (time_triton_dim0_dim1_us / 1e6)
    triton_dim0_dim1_pct_peak_mem = triton_dim0_dim1_bps / peak_mem_bw

    print("time_ref_dim0_dim1_compile_us", time_ref_dim0_dim1_compile_us)
    print("time_ref_dim0_compile_us", time_ref_dim0_compile_us)
    print("time_ref_dim1_compile_us", time_ref_dim1_compile_us)
    print("time_triton_dim1_us", time_triton_dim1_us)
    print("time_triton_dim0_dim1_us", time_triton_dim0_dim1_us)
    print("ref_dim0_mem_bw_gbps", ref_dim0_bps / 1e9)
    print("ref_dim1_mem_bw_gbps", ref_dim1_bps / 1e9)
    print("triton_dim1_mem_bw_gbps", triton_dim1_bps / 1e9)
    print("triton_dim0_dim1_mem_bw_gbps", triton_dim0_dim1_bps / 1e9)
    # Note: as of 2025-03-11, inductor code for adding 1.0 to a large bf16 tensor
    # can achieve around 50-70% of B200 peak mem bw
    print("triton_dim0_dim1_pct_peak_mem", triton_dim0_dim1_pct_peak_mem)
    print("dim0_dim1 speedup", time_ref_dim0_dim1_compile_us / time_triton_dim0_dim1_us)


if __name__ == "__main__":
    fire.Fire(run)
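The benchmark above treats torchao's `to_mxfp8_*` helpers as black boxes. As a rough orientation only, and assuming the helpers follow the usual MX recipe (one power-of-two e8m0 scale per 32-element block, data stored as float8_e4m3fn), a dim0 cast looks roughly like the sketch below; `mxfp8_dim0_sketch` and its constants are hypothetical names, not torchao APIs.

```python
# Hypothetical sketch, not the torchao implementation: scale each contiguous
# 32-element block by a power of two chosen so the scaled block fits in
# float8_e4m3fn, and keep one scale per block.
import torch

F8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3fn
E8M0_MIN_EXP = -127  # smallest exponent an e8m0 scale can encode


def mxfp8_dim0_sketch(x: torch.Tensor, block_size: int = 32):
    orig_shape = x.shape
    blocks = x.reshape(-1, block_size).float()
    amax = blocks.abs().amax(dim=1, keepdim=True)
    # one power-of-two scale per block, sized so amax / scale <= 448
    scale_exp = torch.ceil(torch.log2(amax / F8_E4M3_MAX)).clamp(min=E8M0_MIN_EXP)
    scale = torch.exp2(scale_exp)
    data_fp8 = (blocks / scale).to(torch.float8_e4m3fn).reshape(orig_shape)
    # the real helpers return the scale encoded as e8m0 (one byte per block);
    # an fp32 power of two is returned here for readability
    return data_fp8, scale
```

A dim1 cast is the same operation applied down columns, which is why the reference dequantization above transposes before reshaping into blocks; the idea behind the fused kernel, as modeled in the bandwidth math above, is to produce both layouts while reading the bf16 input only once.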

benchmarks/mx_formats/mx_dim1_cast.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Starting with https://github.com/vkuzo/pytorch_scripts/blob/main/mx_cast_poc/20250305_mx_dim0_dim1_cast.py
and making it nice.
"""

from typing import Callable

import fire
import torch
import triton
from torch._inductor.utils import do_bench_using_profiling

from torchao.prototype.mx_formats.custom_cast import (
    to_mxfp8_dim1,
)

torch.manual_seed(0)

bytes_per_el_bf16 = 2
bytes_per_el_fp8 = 1


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3


def run(
    M: int = 4096,
    K: int = 2048,
    BLOCK_SIZE: int = 32,
):
    print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"torch version: {torch.__version__}")
    print(f"triton version: {triton.__version__}")

    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

    x_d1_only_t, scale_e8m0_d1_only_t = to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE)

    # warm up, then benchmark the triton dim1 cast
    for _ in range(2):
        __ = to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE)
    time_triton_dim1_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_dim1(x, row_tile_size=b),
        x,
        BLOCK_SIZE,
    )

    # model the kernel as one read of the bf16 input plus one write of the
    # fp8 data and e8m0 scales
    dim1_bytes_read = x.numel() * bytes_per_el_bf16
    dim1_bytes_written = (
        x_d1_only_t.numel() + scale_e8m0_d1_only_t.numel()
    ) * bytes_per_el_fp8
    dim1_bytes_rw = dim1_bytes_read + dim1_bytes_written
    triton_dim1_bps = dim1_bytes_rw / (time_triton_dim1_us / 1e6)

    print("time_triton_dim1_us", time_triton_dim1_us)
    print("triton_dim1_mem_bw_gbps", triton_dim1_bps / 1e9)


if __name__ == "__main__":
    fire.Fire(run)
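Both scripts convert the measured time into an achieved bandwidth by counting one read of the bf16 input and one write of the fp8 data plus e8m0 scales. A back-of-the-envelope for the default shape, assuming for illustration a peak memory bandwidth of 3.35e12 bytes/sec (roughly H100-class; substitute your GPU's figure or the value returned by `get_specs()`):

```python
# Worked example of the bandwidth model for M=4096, K=2048, BLOCK_SIZE=32.
# The 3.35e12 bytes/sec peak is an illustrative assumption, not a measurement.
M, K, BLOCK_SIZE = 4096, 2048, 32
bytes_read = M * K * 2                                 # bf16 input, ~16.8 MB
bytes_written = M * K * 1 + (M * K // BLOCK_SIZE) * 1  # fp8 data + e8m0 scales, ~8.7 MB
total_bytes = bytes_read + bytes_written               # ~25.4 MB
assumed_peak_bw = 3.35e12                              # bytes/sec
min_time_us = total_bytes / assumed_peak_bw * 1e6      # ~7.6 us at 100% of peak
print(f"{total_bytes / 1e6:.1f} MB moved, lower bound {min_time_us:.1f} us")
```

Dividing `total_bytes` by the measured kernel time gives the `*_mem_bw_gbps` values the scripts print, and the `pct_peak_mem` figure is that bandwidth divided by the peak reported by `get_specs()`.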
