Commit ae231dd

[wip] triton kernel to cast to mx across dim0 and dim1
Summary:

Test Plan:

```
python torchao/prototype/mx_formats/mx_dim0_dim1_cast.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 846ae8c
ghstack-comment-id: 2714865161

Pull Request resolved: #1869
1 parent d3b481b commit ae231dd

File tree

torchao/prototype/mx_formats/mx_dim0_dim1_cast.py
torchao/prototype/mx_formats/custom_cast.py

2 files changed: +371 -0 lines changed

torchao/prototype/mx_formats/mx_dim0_dim1_cast.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Starting with https://github.com/vkuzo/pytorch_scripts/blob/main/mx_cast_poc/20250305_mx_dim0_dim1_cast.py
and making it nice.
"""

from typing import Callable

import fire
import torch
import triton
from torch._inductor.utils import do_bench_using_profiling

from torchao.prototype.mx_formats.custom_cast import (
    to_mxfp8_across_dim0_and_dim1,
    to_mxfp8_across_dim0_and_dim1_reference,
)

torch.manual_seed(0)


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3


def compute_error(x, y):
    Ps = torch.linalg.norm(x)
    Pn = torch.linalg.norm(x - y)
    return 20 * torch.log10(Ps / Pn)


def run(
    M: int = 4096,
    K: int = 2048,
    BLOCK_SIZE: int = 32,
):
    print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"torch version: {torch.__version__}")
    print(f"triton version: {triton.__version__}")

    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

    to_mxfp8_across_dim0_and_dim1_reference_c = torch.compile(
        to_mxfp8_across_dim0_and_dim1_reference
    )

    # reference implementation (plain PyTorch + torch.compile)
    x_d0, x_d1, scale_e8m0_d0, scale_e8m0_d1 = (
        to_mxfp8_across_dim0_and_dim1_reference_c(x, BLOCK_SIZE)
    )
    x_d0, x_d1 = x_d0.bfloat16(), x_d1.bfloat16()
    scale_fp_d0 = scale_e8m0_d0.float()
    scale_fp_d1 = scale_e8m0_d1.float()
    x_d0_and_back = (x_d0.reshape(-1, BLOCK_SIZE) * scale_fp_d0).reshape(x_d0.shape)
    x_d1_and_back = (
        (x_d1.t().reshape(-1, BLOCK_SIZE) * scale_fp_d1).reshape(x_d1.t().shape).t()
    )

    sqnr_bf16_vs_dim0_ref = compute_error(x, x_d0_and_back)
    sqnr_bf16_vs_dim1_ref = compute_error(x, x_d1_and_back)
    print(
        f"bf16 vs normalized reference sqnrs: dim0 {sqnr_bf16_vs_dim0_ref}, dim1 {sqnr_bf16_vs_dim1_ref}"
    )
    assert (
        sqnr_bf16_vs_dim0_ref > 28 and sqnr_bf16_vs_dim1_ref > 28
    ), "reference mx numerics are incorrect"

    # basic triton kernel
    x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t = to_mxfp8_across_dim0_and_dim1(
        x, tile_size=BLOCK_SIZE
    )
    x_d0_t, x_d1_t = x_d0_t.bfloat16(), x_d1_t.bfloat16()

    # ensure bitwise equivalency of outputs with reference
    torch.testing.assert_close(x_d0, x_d0_t, atol=0, rtol=0)
    torch.testing.assert_close(x_d1, x_d1_t, atol=0, rtol=0)
    torch.testing.assert_close(scale_e8m0_d0, scale_e8m0_d0_t, atol=0, rtol=0)
    torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_t, atol=0, rtol=0)
    print("normalized reference vs normalized triton are bitwise equivalent")

    # now, measure performance

    # warm up
    for _ in range(2):
        __ = to_mxfp8_across_dim0_and_dim1_reference_c(x, BLOCK_SIZE)
    time_reference_compile_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_across_dim0_and_dim1_reference_c(x, b), x, BLOCK_SIZE
    )

    # warm up
    for _ in range(2):
        __ = to_mxfp8_across_dim0_and_dim1(x, tile_size=BLOCK_SIZE)
    time_triton_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: to_mxfp8_across_dim0_and_dim1(x, tile_size=BLOCK_SIZE),
        x,
        BLOCK_SIZE,
    )

    # calculate bytes read/written
    bytes_per_el_bf16 = 2
    bytes_per_el_fp8 = 1
    triton_bytes_read = x.numel() * bytes_per_el_bf16
    triton_bytes_written = (
        sum(x.numel() for x in (x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t))
        * bytes_per_el_fp8
    )
    triton_achieved_mem_bw_gbps = (triton_bytes_read + triton_bytes_written) / (
        time_triton_us / 1e6
    )
    # TODO(future PR): read 8.0 TB/s number from roofline_utils.py instead of hardcoding
    triton_pct_peak_mem_bw = triton_achieved_mem_bw_gbps / 8.0e12

    print("time_reference_compile_us", time_reference_compile_us)
    print("time_triton_us", time_triton_us)
    print("triton_achieved_mem_bw_gbps", triton_achieved_mem_bw_gbps)
    # Note: as of 2025-03-11, inductor code for adding 1.0 to a large bf16 tensor
    # can achieve around 50-70% of B200 peak mem bw
    print("triton_pct_peak_mem_bw", triton_pct_peak_mem_bw)
    print("speedup", time_reference_compile_us / time_triton_us)


if __name__ == "__main__":
    fire.Fire(run)
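For reference, here is a minimal sketch (not part of this commit) of the per-block mxfp8 quantize/dequantize round trip that the SQNR check in the benchmark above validates, written in plain PyTorch for a single block of 32 values; the helper name below is hypothetical.

```
import torch

def quantize_block_mxfp8(block: torch.Tensor):
    # shared power-of-two scale per block, mirroring the kernel: floor(log2(amax)) - 8
    amax = block.abs().max()
    scale = torch.exp2(torch.floor(torch.log2(amax + 1e-10)) - 8)
    # cast the scaled block to float8_e4m3fn
    return (block / scale).to(torch.float8_e4m3fn), scale

block = torch.randn(32) * 1000
data_fp8, scale = quantize_block_mxfp8(block)
block_back = data_fp8.float() * scale  # dequantize; compare to `block` via SQNR
```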

torchao/prototype/mx_formats/custom_cast.py

Lines changed: 239 additions & 0 deletions
@@ -4,6 +4,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import numpy as np
import torch
from torch.utils._triton import has_triton
@@ -1080,3 +1082,240 @@ def _(uint8_data):
    def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
        # Dummy placeholder op for torch < 2.4
        raise AssertionError("fp6 packing unsupported without torch >= 2.4")


if TORCH_VERSION_AT_LEAST_2_4 and has_triton():
    import triton
    import triton.language as tl

    @triton.jit
    def _triton_calculate_scale(x, axis):
        # We use a small epsilon to avoid division by zero
        epsilon = 1e-10

        # TODO(before land): reuse the constants below instead of hardcoding
        target_max_pow2 = 8
        e8m0_exponent_bias = 127

        # Find the maximum absolute value for each row
        max_abs = tl.max(x, axis=axis)

        scale_e8m0_unbiased = tl.floor(tl.log2(max_abs + epsilon)) - target_max_pow2

        # Clamp to exponents that can be represented in e8m0
        scale_e8m0_unbiased = tl.clamp(
            scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias
        )

        # Create the biased e8m0 representation and cast it to 8 bits
        scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
        scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)

        # TODO(future PR): add NaN handling here

        # For now, calculate the scale in floating point.
        # TODO(future) audit if there is a need to bit shift exponents instead.
        scale_fp = tl.exp2(scale_e8m0_unbiased.to(tl.float32))

        return scale_fp, scale_e8m0_biased

    @triton.jit
    def to_mxfp8_across_dim0_and_dim1_kernel(
        x_ptr,  # pointer to input tensor
        output_row_major_ptr,  # pointer to row-major output tensor (row-normalized)
        output_col_major_ptr,  # pointer to column-major output tensor (column-normalized)
        row_scale_ptr,  # pointer to store row-wise maximum absolute values
        col_scale_ptr,  # pointer to store column-wise maximum absolute values
        n_rows,  # number of rows in the tensor
        n_cols,  # number of columns in the tensor
        TILE_SIZE: tl.constexpr,  # tile size as a compile-time constant
    ):
        """
        credit: mostly Claude, some Vasiliy
        """

        # Get program ID
        pid_row = tl.program_id(0)
        pid_col = tl.program_id(1)

        # Calculate starting row and column for this tile
        start_row = pid_row * TILE_SIZE
        start_col = pid_col * TILE_SIZE

        # Create offsets for the block
        row_offsets = tl.arange(0, TILE_SIZE)
        col_offsets = tl.arange(0, TILE_SIZE)

        # Compute global row/col positions
        rows = start_row + row_offsets[:, None]  # Convert to 2D for proper broadcasting
        cols = start_col + col_offsets[None, :]

        # Create masks for out-of-bounds accesses
        row_mask = rows < n_rows
        col_mask = cols < n_cols
        mask = row_mask & col_mask

        # Compute memory offsets for row-major layout (rows, cols)
        row_major_offsets = (rows * n_cols + cols).to(tl.int32)

        # Compute memory offsets for column-major layout (cols, rows)
        col_major_offsets = (cols * n_rows + rows).to(tl.int32)

        # Load the entire block in a single operation
        x_block = tl.load(x_ptr + row_major_offsets, mask=mask)

        # ----------------------------------------------------
        # Row-wise normalization
        # ----------------------------------------------------
        # Calculate the absolute values of elements in the block
        x_block_abs = tl.abs(x_block)

        # Find the maximum absolute value for each row
        row_scale, row_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=1)

        # Normalize each row by its maximum absolute value
        # Broadcasting row_scale to match x_block's shape
        row_normalized = x_block / row_scale[:, None]

        # quant to float8
        # TODO(this PR): clamp?
        row_normalized = row_normalized.to(tl.float8e4nv)

        # ----------------------------------------------------
        # Column-wise normalization
        # ----------------------------------------------------
        # Find the maximum absolute value for each column
        col_scale, col_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=0)

        # Normalize each column by its maximum absolute value
        # Broadcasting col_scale to match x_block's shape
        col_normalized = x_block / col_scale[None, :]

        # quant to float8
        # TODO(this PR): clamp?
        col_normalized = col_normalized.to(tl.float8e4nv)

        # Store the row-normalized result in row-major format
        tl.store(output_row_major_ptr + row_major_offsets, row_normalized, mask=mask)

        # Store the column-normalized result in column-major format
        tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask)

        # Create 1D ranges for storing row and column max values
        row_indices = start_row + tl.arange(0, TILE_SIZE)
        col_indices = start_col + tl.arange(0, TILE_SIZE)

        # Create masks for valid rows and columns
        row_mask = row_indices < n_rows
        col_mask = col_indices < n_cols

        # Vasiliy - deviating from Claude here for much simpler code
        row_scale_start_ptr = row_scale_ptr + (pid_row * n_cols) + pid_col
        row_scale_indices = tl.arange(0, TILE_SIZE) * (n_cols // TILE_SIZE)
        # TODO(future): mask
        tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)

        # Vasiliy - deviating from Claude here for much simpler code
        col_scale_start_ptr = col_scale_ptr + (pid_col * n_rows) + pid_row
        col_scale_indices = tl.arange(0, TILE_SIZE) * (n_rows // TILE_SIZE)
        # TODO(future): mask
        tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)

    def to_mxfp8_across_dim0_and_dim1(x, tile_size=32):
        """
        This is a single fused triton kernel to cast `x` to MX across dim0 and dim1.
        This is useful for MX training with the mxfp8 recipe family.

        The kernel loads data in 2d tiles, and performs the necessary casting across both
        dim0 and dim1 for each tile.

        Note that for now, there is only one level of tiling (32 for MX). In the future,
        we expect that adding an outer tile (of size up to 128 on B200s) can provide a
        further speedup.

        Input:
        * `x` - input tensor, in row major memory layout
        * `tile_size` - size of tiles to normalize across, default is 32 for MX recipes

        Output:
        * `output_row_major`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0
        * `output_col_major`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim1
        * `row_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0
        * `col_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim1
        """
        assert x.is_contiguous(), "`x` must be contiguous"
        # Get tensor shape
        n_rows, n_cols = x.shape

        # Create output tensors (both row-major and column-major)
        output_row_major = torch.empty_like(x, dtype=torch.float8_e4m3fn)
        output_col_major = torch.empty(
            (n_cols, n_rows), dtype=torch.float8_e4m3fn, device=x.device
        )

        # Create tensors for row-wise and column-wise maximum absolute values
        row_scale = torch.empty(
            n_rows, n_cols // tile_size, dtype=torch.uint8, device=x.device
        )
        col_scale = torch.empty(
            n_cols, n_rows // tile_size, dtype=torch.uint8, device=x.device
        )

        # Calculate grid dimensions based on tile size
        grid_rows = triton.cdiv(n_rows, tile_size)
        grid_cols = triton.cdiv(n_cols, tile_size)

        # Launch the kernel
        to_mxfp8_across_dim0_and_dim1_kernel[(grid_rows, grid_cols)](
            x_ptr=x,
            output_row_major_ptr=output_row_major,
            output_col_major_ptr=output_col_major,
            row_scale_ptr=row_scale,
            col_scale_ptr=col_scale,
            n_rows=n_rows,
            n_cols=n_cols,
            TILE_SIZE=tile_size,
        )

        return (
            output_row_major,
            output_col_major.t(),
            row_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
            col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
        )

    def to_mxfp8_across_dim0_and_dim1_reference(
        x_hp: torch.Tensor, block_size
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        A reference version of `to_mxfp8_across_dim0_and_dim1`.
        """
        from torchao.prototype.mx_formats.mx_tensor import to_mx

        # cast across dim0
        scale_e8m0_dim0, x_hp_d0_normalized = to_mx(
            x_hp, torch.float8_e4m3fn, block_size
        )
        scale_e8m0_dim0 = scale_e8m0_dim0.unsqueeze(1).view(torch.float8_e8m0fnu)
        # cast across dim1
        x_hp_d1 = x_hp.t().contiguous()
        scale_e8m0_dim1, x_hp_d1_normalized = to_mx(
            x_hp_d1, torch.float8_e4m3fn, block_size
        )
        scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu)
        return (
            x_hp_d0_normalized,
            x_hp_d1_normalized.t(),
            scale_e8m0_dim0,
            scale_e8m0_dim1,
        )

else:

    def to_mxfp8_across_dim0_and_dim1(x, tile_size=32):
        raise AssertionError("needs torch version 2.4+ and triton")

    def to_mxfp8_across_dim0_and_dim1_reference(
        x_hp: torch.Tensor, block_size
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        raise AssertionError("needs torch version 2.4+ and triton")

0 commit comments
