
Commit beb91de

[wip] triton kernel to cast to mx across dim0 and dim1
Summary:

Test Plan:

```
python torchao/prototype/mx_formats/mx_dim0_dim1_cast.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 6d14bec
ghstack-comment-id: 2714865161
Pull Request resolved: #1869
1 parent ddb7f83 commit beb91de
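
For orientation, here is a rough sketch of how the new casts are invoked, with function names and keyword arguments (inner_block_size, tile_size) taken from the benchmark code added in this commit. The Triton kernels themselves are not part of the hunks shown below, and since this is a WIP stack the signatures may still change (the two benchmark files use different keyword names, inner_block_size vs row_tile_size, for to_mxfp8_dim1).

```python
# Illustrative only: these calls mirror the benchmark code added in this commit.
import torch

from torchao.prototype.mx_formats.custom_cast import (
    to_mxfp8_across_dim0_and_dim1,
    to_mxfp8_dim1,
)

x = torch.randn(4096, 2048, dtype=torch.bfloat16, device="cuda")

# cast across dim1 only: mxfp8 data plus one e8m0 scale per block of 32 elements
data_d1, scale_d1 = to_mxfp8_dim1(x, inner_block_size=32)

# cast across dim0 and dim1 in a single pass over the input
data_d0, data_d1, scale_d0, scale_d1 = to_mxfp8_across_dim0_and_dim1(x, tile_size=32)
```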

File tree

3 files changed: +743, -1 lines changed

benchmarks/mx_formats/cast_bench.py

Lines changed: 45 additions & 1 deletion
Lines changed: 45 additions & 1 deletion

```diff
@@ -5,6 +5,9 @@
 import triton
 from torch._inductor.utils import do_bench_using_profiling
 
+from torchao.prototype.mx_formats.custom_cast import (
+    to_mxfp8_dim1,
+)
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 
 torch.manual_seed(0)
@@ -49,6 +52,12 @@ def to_mx_dim0_reference(x_hp, block_size):
     return data_d0, scale_d0
 
 
+def to_mx_dim1_reference(x_hp, block_size):
+    x_hp = x_hp.t().contiguous()
+    scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    return data_d1.t(), scale_d1
+
+
 def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
     """Thin wrapper around do_bench_using_profiling"""
     no_args = lambda: func(*args, **kwargs)
@@ -67,7 +76,7 @@ def run(
     print(f"torch version: {torch.__version__}")
     print(f"triton version: {triton.__version__}")
     print(f"mode: {mode}")
-    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx")
+    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")
 
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
 
@@ -144,6 +153,41 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
+    elif mode == "dim1_mx":
+        to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
+        y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.uint8
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim1_mx_triton":
+        y_d1, s_d1 = to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mxfp8_dim1(x, row_tile_size=BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     else:
         raise AssertionError(f"unknown mode {mode}")
 
```
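To make explicit what to_mx_dim1_reference in the hunk above computes: an MX dim1 cast groups block_size consecutive elements down each column of the row-major input, derives one power-of-two scale per group, and stores the scaled values in float8_e4m3fn. The transpose-then-cast trick exists because the reference to_mx only scales along the last (contiguous) dimension; the strided column access is exactly what makes a dedicated Triton kernel attractive. Below is a simplified, self-contained sketch of that layout; the helper name is hypothetical and the scale derivation and rounding details differ from torchao's to_mx.

```python
import torch

E4M3_MAX = 448.0  # largest normal value representable in float8_e4m3fn


def mxfp8_dim1_cast_sketch(x_hp: torch.Tensor, block_size: int = 32):
    """Hypothetical helper, not torchao API: cast across dim1 by grouping
    `block_size` elements down each column and sharing one power-of-two
    scale per group."""
    M, K = x_hp.shape
    assert M % block_size == 0
    # transpose so each scaling block is contiguous along the last dim,
    # mirroring the x_hp.t().contiguous() trick in to_mx_dim1_reference
    x_t = x_hp.t().contiguous().reshape(K, M // block_size, block_size).float()

    # one power-of-two scale per block, chosen so the block amax lands near E4M3_MAX
    amax = x_t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = torch.exp2(torch.floor(torch.log2(amax / E4M3_MAX)))

    data = (x_t / scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)

    # return data in the original (M, K) orientation, one scale per block
    return data.reshape(K, M).t(), scale.reshape(K, M // block_size)
```

Note that the real implementations store the scale in a compact e8m0 encoding (the benchmark asserts torch.uint8 for the compiled reference path and torch.float8_e8m0fnu for the Triton path), whereas this sketch keeps it as a plain float power of two.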
Lines changed: 206 additions & 0 deletions (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

import fire
import pandas as pd
import torch
import triton
from torch._inductor.utils import do_bench_using_profiling

from torchao.prototype.mx_formats.custom_cast import (
    to_mxfp8_across_dim0_and_dim1,
    to_mxfp8_across_dim0_and_dim1_reference,
    to_mxfp8_dim0_reference,
    to_mxfp8_dim1,
    to_mxfp8_dim1_reference,
)
from torchao.quantization.utils import compute_error
from torchao.testing.float8.roofline_utils import get_specs

torch.manual_seed(0)

bytes_per_el_bf16 = 2
bytes_per_el_fp8 = 1


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3


def run(
    M: int = 4096,
    K: int = 2048,
    BLOCK_SIZE: int = 32,
    check_accuracy: bool = True,
):
    print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"torch version: {torch.__version__}")
    print(f"triton version: {triton.__version__}")

    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

    to_mxfp8_across_dim0_and_dim1_reference_c = torch.compile(
        to_mxfp8_across_dim0_and_dim1_reference
    )
    to_mxfp8_dim0_reference_c = torch.compile(to_mxfp8_dim0_reference)
    to_mxfp8_dim1_reference_c = torch.compile(to_mxfp8_dim1_reference)

    # reference implementation (plain PyTorch + torch.compile)
    x_d0, x_d1, scale_e8m0_d0, scale_e8m0_d1 = (
        to_mxfp8_across_dim0_and_dim1_reference_c(x, BLOCK_SIZE)
    )

    # verify reference dim0_dim1 matches dim0 and dim1 separately
    x_d0_separate, scale_e8m0_d0_separate = to_mxfp8_dim0_reference_c(x, BLOCK_SIZE)
    x_d1_separate, scale_e8m0_d1_separate = to_mxfp8_dim1_reference_c(x, BLOCK_SIZE)
    torch.testing.assert_close(x_d0, x_d0_separate, atol=0, rtol=0)
    torch.testing.assert_close(x_d1, x_d1_separate, atol=0, rtol=0)
    torch.testing.assert_close(scale_e8m0_d0, scale_e8m0_d0_separate, atol=0, rtol=0)
    torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_separate, atol=0, rtol=0)

    x_d0, x_d1 = x_d0.bfloat16(), x_d1.bfloat16()
    scale_fp_d0 = scale_e8m0_d0.float()
    scale_fp_d1 = scale_e8m0_d1.float()
    x_d0_and_back = (x_d0.reshape(-1, BLOCK_SIZE) * scale_fp_d0).reshape(x_d0.shape)
    x_d1_and_back = (
        (x_d1.t().reshape(-1, BLOCK_SIZE) * scale_fp_d1).reshape(x_d1.t().shape).t()
    )

    sqnr_bf16_vs_dim0_ref = compute_error(x, x_d0_and_back)
    sqnr_bf16_vs_dim1_ref = compute_error(x, x_d1_and_back)
    print(
        f"bf16 vs normalized reference sqnrs: dim0 {sqnr_bf16_vs_dim0_ref}, dim1 {sqnr_bf16_vs_dim1_ref}"
    )
    assert (
        sqnr_bf16_vs_dim0_ref > 28 and sqnr_bf16_vs_dim1_ref > 28
    ), "reference mx numerics are incorrect"

    # triton kernel for dim1 only
    x_d1_only_t, scale_e8m0_d1_only_t = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)

    # triton kernel for dim0 and dim1
    x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t = to_mxfp8_across_dim0_and_dim1(
        x, tile_size=BLOCK_SIZE
    )
    x_d0_t, x_d1_t, x_d1_only_t = (
        x_d0_t.bfloat16(),
        x_d1_t.bfloat16(),
        x_d1_only_t.bfloat16(),
    )

    # ensure bitwise equivalency of outputs with reference
    if check_accuracy:
        torch.testing.assert_close(x_d0, x_d0_t, atol=0, rtol=0)
        torch.testing.assert_close(x_d1, x_d1_t, atol=0, rtol=0)
        torch.testing.assert_close(scale_e8m0_d0, scale_e8m0_d0_t, atol=0, rtol=0)
        torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_t, atol=0, rtol=0)
        print("reference vs triton dim0_dim1 are bitwise equivalent")
        torch.testing.assert_close(x_d1, x_d1_only_t, atol=0, rtol=0)
        torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_only_t, atol=0, rtol=0)
        print("reference vs triton dim1 are bitwise equivalent")
    else:
        print("accuracy checking skipped")

    # now, measure performance

    # define a speed-of-light torch.compile kernel to get a sense of
    # achievable mem bandwidth
    def add_one(x):
        x = x + 1
        return x

    add_one_c = torch.compile(add_one)

    for _ in range(2):
        __ = add_one_c(x)
    time_add_one_compile_us = benchmark_cuda_function_in_microseconds(
        lambda x: add_one(x), x
    )

    for _ in range(2):
        __ = to_mxfp8_across_dim0_and_dim1_reference_c(x, BLOCK_SIZE)
    time_ref_dim0_dim1_compile_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_across_dim0_and_dim1_reference_c(x, b), x, BLOCK_SIZE
    )

    for _ in range(2):
        __ = to_mxfp8_dim0_reference_c(x, BLOCK_SIZE)
    time_ref_dim0_compile_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_dim0_reference_c(x, b), x, BLOCK_SIZE
    )

    for _ in range(2):
        __ = to_mxfp8_dim1_reference_c(x, BLOCK_SIZE)
    time_ref_dim1_compile_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_dim1_reference_c(x, b), x, BLOCK_SIZE
    )

    for _ in range(2):
        __ = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
    time_triton_dim1_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
        x,
        BLOCK_SIZE,
    )

    # warm up
    for _ in range(2):
        __ = to_mxfp8_across_dim0_and_dim1(x, tile_size=BLOCK_SIZE)
    time_triton_dim0_dim1_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_across_dim0_and_dim1(x, tile_size=BLOCK_SIZE),
        x,
        BLOCK_SIZE,
    )

    # calculate memory bandwidth
    peak_mem_bw = get_specs()["peak_mem_bw_bytes_sec"]

    # add_one kernel
    add_one_bps = x.numel() * bytes_per_el_bf16 * 2 / (time_add_one_compile_us / 1e6)

    # dim0 or dim1 kernel
    dim0_bytes_read = x.numel() * bytes_per_el_bf16
    dim0_bytes_written = (x_d0_t.numel() + scale_e8m0_d0_t.numel()) * bytes_per_el_fp8
    dim0_bytes_rw = dim0_bytes_read + dim0_bytes_written
    ref_dim0_bps = dim0_bytes_rw / (time_ref_dim0_compile_us / 1e6)
    ref_dim1_bps = dim0_bytes_rw / (time_ref_dim1_compile_us / 1e6)
    triton_dim1_bps = dim0_bytes_rw / (time_triton_dim1_us / 1e6)

    # triton dim0_dim1 kernel
    triton_dim0_dim1_bytes_read = x.numel() * bytes_per_el_bf16
    triton_dim0_dim1_bytes_written = (
        sum(x.numel() for x in (x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t))
        * bytes_per_el_fp8
    )
    triton_dim0_dim1_bps = (
        triton_dim0_dim1_bytes_read + triton_dim0_dim1_bytes_written
    ) / (time_triton_dim0_dim1_us / 1e6)

    results = [
        ["add_one", time_add_one_compile_us, add_one_bps / 1e9],
        ["compile_dim0", time_ref_dim0_compile_us, ref_dim0_bps / 1e9],
        ["compile_dim1", time_ref_dim1_compile_us, ref_dim1_bps / 1e9],
        [
            "compile_dim0_dim1",
            time_ref_dim0_dim1_compile_us,
            triton_dim0_dim1_bps / 1e9,
        ],
        ["triton_dim1", time_triton_dim1_us, triton_dim1_bps / 1e9],
        ["triton_dim0_dim1", time_triton_dim0_dim1_us, triton_dim0_dim1_bps / 1e9],
    ]
    df = pd.DataFrame(results, columns=["experiment", "time_us", "mem_bw_gbps"])
    df["mem_bw_pct_peak"] = df["mem_bw_gbps"] * 1e9 / peak_mem_bw
    print("\n", df)


if __name__ == "__main__":
    fire.Fire(run)
```
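
The bandwidth accounting in the script reduces to simple byte counting. As a worked example for its default shapes, covering the dim0- or dim1-only kernels (the fused dim0_dim1 kernel additionally writes the second data tensor and second scale tensor); the 20 us kernel time below is a made-up number purely to show the unit conversion:

```python
# Worked example of the byte accounting used above, for the script's defaults
# (M=4096, K=2048, BLOCK_SIZE=32). time_us is hypothetical.
M, K, BLOCK_SIZE = 4096, 2048, 32
bytes_per_el_bf16, bytes_per_el_fp8 = 2, 1

bytes_read = M * K * bytes_per_el_bf16                       # 16,777,216 bytes of bf16 read
bytes_data_out = M * K * bytes_per_el_fp8                    # 8,388,608 bytes of fp8 data written
bytes_scale_out = (M * K // BLOCK_SIZE) * bytes_per_el_fp8   # 262,144 bytes of e8m0 scales written
bytes_rw = bytes_read + bytes_data_out + bytes_scale_out     # 25,427,968 bytes moved in total

time_us = 20.0                                               # hypothetical kernel time
mem_bw_gbps = bytes_rw / (time_us / 1e6) / 1e9               # ~1271 GB/s
print(f"{mem_bw_gbps:.0f} GB/s")
```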
