
Commit 8801d1f

[wip] triton kernel to cast to mx and write in col-major
Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 8486913
ghstack-comment-id: 2743450537
Pull Request resolved: #1932
1 parent ab3792e commit 8801d1f

3 files changed: +325 / -3 lines


benchmarks/mx_formats/cast_bench.py

Lines changed: 45 additions & 1 deletion
@@ -5,6 +5,9 @@
 import triton
 from torch._inductor.utils import do_bench_using_profiling
 
+from torchao.prototype.mx_formats.custom_cast import (
+    to_mxfp8_dim1,
+)
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 
 torch.manual_seed(0)
@@ -49,6 +52,12 @@ def to_mx_dim0_reference(x_hp, block_size):
     return data_d0, scale_d0
 
 
+def to_mx_dim1_reference(x_hp, block_size):
+    x_hp = x_hp.t().contiguous()
+    scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    return data_d1.t(), scale_d1
+
+
 def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
     """Thin wrapper around do_bench_using_profiling"""
     no_args = lambda: func(*args, **kwargs)
@@ -67,7 +76,7 @@ def run(
     print(f"torch version: {torch.__version__}")
     print(f"triton version: {triton.__version__}")
     print(f"mode: {mode}")
-    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx")
+    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")
 
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
 
@@ -144,6 +153,41 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
+    elif mode == "dim1_mx":
+        to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
+        y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.uint8
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim1_mx_triton":
+        y_d1, s_d1 = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     else:
        raise AssertionError(f"unknown mode {mode}")
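Note: the new `dim1_mx` and `dim1_mx_triton` modes report achieved memory bandwidth as `(bytes_r + bytes_w) / seconds`, counting one bf16 read per input element and one fp8 byte written per output element plus one byte per scale. A minimal sketch of that arithmetic, using a hypothetical 4096 x 4096 input and a made-up timing (neither value comes from this commit):

M, K, BLOCK_SIZE = 4096, 4096, 32
bytes_per_el_bf16, bytes_per_el_fp8 = 2, 1

bytes_r = M * K * bytes_per_el_bf16                                            # 32 MiB read
bytes_w = M * K * bytes_per_el_fp8 + (M * K // BLOCK_SIZE) * bytes_per_el_fp8  # 16 MiB data + 0.5 MiB scales

time_us = 500.0  # hypothetical kernel time
bps = (bytes_r + bytes_w) / (time_us / 1e6)  # ~1.0e11 bytes/s, i.e. roughly 100 GB/s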

test/prototype/mx_formats/test_custom_cast.py

Lines changed: 23 additions & 1 deletion
@@ -26,6 +26,9 @@
     get_bits,
     pack_uint4,
     pack_uint6,
+    # TODO(before land): better name?
+    to_mxfp8_dim1,
+    to_mxfp8_dim1_reference,
     triton_f4_to_bf16,
     triton_f6_e2m3_to_bf16,
     triton_f6_e3m2_to_bf16,
@@ -42,7 +45,11 @@
     sem_vals_to_f32,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_8,
+    is_sm_at_least_89,
+    is_sm_at_least_100,
+)
 
 torch.manual_seed(0)
 
@@ -444,3 +451,18 @@ def test_fp6_e3m2_pack_unpack():
         torch.float32
     )
     assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
+
+
+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+@pytest.mark.parametrize("M", (256, 2048))
+@pytest.mark.parametrize("K", (256, 2048))
+def test_triton_mxfp8_dim1(M, K):
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x_mx_ref, x_s_ref = to_mxfp8_dim1_reference(x, block_size=32)
+    x_mx_t, x_s_t = to_mxfp8_dim1(x, inner_block_size=32)
+    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
+    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
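The test above expects bitwise equality (`rtol=0, atol=0`) between the Triton kernel and the reference implementation. A usage sketch of the new function, with dtypes and shapes inferred from the `to_mxfp8_dim1` wrapper in `custom_cast.py` below (the sizes are only an illustration, matching the test parametrization):

import torch
from torchao.prototype.mx_formats.custom_cast import to_mxfp8_dim1

M, K = 256, 2048
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

data_d1, scale_d1 = to_mxfp8_dim1(x, inner_block_size=32)

# fp8 data: same logical shape as x, but column-major in memory
assert data_d1.dtype == torch.float8_e4m3fn
assert data_d1.shape == (M, K)

# one e8m0 scale per 32-element block, flattened to a column vector
assert scale_d1.dtype == torch.float8_e8m0fnu
assert scale_d1.shape == (M * K // 32, 1)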

torchao/prototype/mx_formats/custom_cast.py

Lines changed: 257 additions & 1 deletion
@@ -4,6 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Tuple
+
 import numpy as np
 import torch
 from torch.utils._triton import has_triton
@@ -12,7 +14,7 @@
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_8
 
 # TODO(future): if needed, make the below work on previous PyTorch versions,
 # just need to hunt down the previous location of `libdevice`. An assert
@@ -1080,3 +1082,257 @@ def _(uint8_data):
     def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
         # Dummy placeholder op for torch < 2.4
         raise AssertionError("fp6 packing unsupported without torch >= 2.4")
+
+
+if TORCH_VERSION_AT_LEAST_2_8 and has_triton():
+    import triton
+    import triton.language as tl
+
+    @triton.jit
+    def _triton_calculate_scale(x, axis):
+        # We use a small epsilon to avoid division by zero
+        epsilon = 1e-10
+
+        # TODO(before land): reuse the constants below instead of hardcoding
+        target_max_pow2 = 8
+        e8m0_exponent_bias = 127
+        # bf16_mbits = 7
+        # bf16_exp_bias = 127
+
+        # Find the maximum absolute value for each row
+        max_abs = tl.max(x, axis=axis)
+
+        # TODO(future): rewrite as bit shifts, see https://github.com/pytorch/ao/pull/1908/files
+        scale_e8m0_unbiased = tl.floor(tl.log2(max_abs + epsilon)) - target_max_pow2
+        # max_abs = max_abs + epsilon
+        # max_abs = max_abs.to(tl.bfloat16)
+        # max_abs_int16 = max_abs.to(tl.int16, bitcast=True)
+        # extracted_pow2 = ((max_abs_int16 >> bf16_mbits) & 0b11111111) - bf16_exp_bias
+        # extracted_pow2 = extracted_pow2 - target_max_pow2
+        # scale_e8m0_unbiased = extracted_pow2.to(tl.bfloat16)
+
+        # Clamp to exponents that can be represented in e8m0
+        scale_e8m0_unbiased = tl.clamp(
+            scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias
+        )
+
+        # Create the biased e8m0 representation and cast it to 8 bits
+        scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
+        scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)
+
+        # TODO(future PR): add NaN handling here
+
+        # For now, calculate the scale in floating point.
+        # TODO(future): rewrite as bit shifts, see https://github.com/pytorch/ao/pull/1910/files
+        scale_fp = tl.exp2(scale_e8m0_unbiased.to(tl.float32))
+
+        return scale_fp, scale_e8m0_biased
+
+    def _get_mxfp8_dim1_kernel_autotune_configs():
+        results = []
+        for ROW_TILE_SIZE in (64, 128):
+            for COL_TILE_SIZE in (64, 128):
+                for num_warps in (1, 2, 4):
+                    config = triton.Config(
+                        {
+                            "ROW_TILE_SIZE": ROW_TILE_SIZE,
+                            "COL_TILE_SIZE": COL_TILE_SIZE,
+                        },
+                        num_warps=num_warps,
+                    )
+                    results.append(config)
+        return results
+
+    @triton.autotune(
+        configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+        key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+    )
+    @triton.jit
+    def to_mxfp8_dim1_kernel(
+        x_ptr,  # pointer to input tensor
+        output_col_major_ptr,  # pointer to column-major output tensor (column-normalized)
+        col_scale_ptr,  # pointer to store column-wise maximum absolute values
+        n_rows,  # number of rows in the tensor
+        n_cols,  # number of columns in the tensor
+        ROW_TILE_SIZE: tl.constexpr,  # can be autotuned
+        COL_TILE_SIZE: tl.constexpr,  # can be autotuned
+        INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+    ):
+        # TODO(future): better name
+        BLOCKS_PER_ROW_TILE: tl.constexpr = ROW_TILE_SIZE // INNER_BLOCK_SIZE
+
+        # Get program ID
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * ROW_TILE_SIZE
+        start_col = pid_col * COL_TILE_SIZE
+
+        # Create offsets for the block
+        row_offsets = tl.arange(0, ROW_TILE_SIZE)
+        col_offsets = tl.arange(0, COL_TILE_SIZE)
+
+        # Compute global row/col positions
+        rows = start_row + row_offsets[:, None]  # Convert to 2D for proper broadcasting
+        cols = start_col + col_offsets[None, :]
+
+        # Create masks for out-of-bounds accesses
+        row_mask = rows < n_rows
+        col_mask = cols < n_cols
+        mask = row_mask & col_mask
+
+        # Compute memory offsets for row-major layout (rows, cols)
+        row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+
+        # Compute memory offsets for column-major layout (cols, rows)
+        col_major_offsets = (cols * n_rows + rows).to(tl.int32)
+
+        # Load the entire block in a single operation
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+        x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
+
+        # Transpose dim0 and dim1
+        # shape: (COL_TILE_SIZE, ROW_TILE_SIZE)
+        x_block_t = tl.trans(x_block)
+
+        # Reshape to inner tile size
+        # TODO: make this generic and nice
+        # inner_block_size = 32
+        # shape: (COL_TILE_SIZE, ROW_TILE_SIZE) -> (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, INNER_BLOCK_SIZE)
+        x_block_t_r = x_block_t.reshape(
+            COL_TILE_SIZE * BLOCKS_PER_ROW_TILE, INNER_BLOCK_SIZE
+        )
+
+        # Calculate the absolute values of elements in the block
+        x_block_abs_t_r = tl.abs(x_block_t_r)
+
+        # Find the maximum absolute value for each column
+        # shape: (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,)
+        col_scale_r, col_scale_e8m0_r = _triton_calculate_scale(x_block_abs_t_r, axis=1)
+
+        # Divide each column by scale
+        # Broadcasting col_scale to match x_block's shape
+        # x_block_t shape (COL_TILE_SIZE, ROW_TILE_SIZE)
+        # x_block_t_r shape (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, INNER_BLOCK_SIZE)
+        # col_scale shape (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) -> (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, 1)
+        col_normalized_t_r = x_block_t_r / col_scale_r[:, None]
+
+        # Reshape back to original tile size
+        col_normalized_t = tl.reshape(col_normalized_t_r, COL_TILE_SIZE, ROW_TILE_SIZE)
+
+        # Undo the transpose
+        col_normalized = tl.trans(col_normalized_t)
+
+        # Quantize to float8
+        col_normalized = col_normalized.to(tl.float8e4nv)
+
+        # Store the column-normalized result in column-major format
+        tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask)
+
+        # reshape col_scale_e8m0_r to col_scale_e8m0
+        # shape: (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) -> (COL_TILE_SIZE, ROW_TILE_SIZE // INNER_BLOCK_SIZE,)
+        # col_scale_e8m0 = col_scale_e8m0_r.reshape(COL_TILE_SIZE, ROW_TILE_SIZE // INNER_BLOCK_SIZE)
+        col_scale_e8m0 = col_scale_e8m0_r.reshape(COL_TILE_SIZE * BLOCKS_PER_ROW_TILE)
+
+        col_scale_start_offsets = (
+            (pid_col * COL_TILE_SIZE * (n_rows // ROW_TILE_SIZE))
+            * BLOCKS_PER_ROW_TILE  # number of blocks seen so far
+            + pid_row * BLOCKS_PER_ROW_TILE  # increment ROW_TILE_SIZE
+        )
+
+        col_scale_start_ptr = col_scale_ptr + col_scale_start_offsets
+
+        # calculate col_scale_indices, this is a bit convoluted
+        # start with a sequential index [0, COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE]
+        # from example: [0, 1, 2, 3, 4, 5, 6, 7]
+        col_scale_indices = tl.arange(0, COL_TILE_SIZE * BLOCKS_PER_ROW_TILE)
+
+        # needs better name
+        jump_vals_per_col = (n_rows - ROW_TILE_SIZE) // INNER_BLOCK_SIZE
+
+        # example transformation (specifics depend on tile sizes):
+        # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
+        col_scale_indices = col_scale_indices + (
+            tl.floor(col_scale_indices / BLOCKS_PER_ROW_TILE) * jump_vals_per_col
+        ).to(tl.int32)
+
+        # TODO(future): mask
+        tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
+
+    def to_mxfp8_dim1(x, inner_block_size=32):
+        """
+        Input:
+        * `x` - input tensor, in row major memory layout
+        * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
+
+        Output:
+        * `output_col_major`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim1
+        * `col_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim1
+        """
+        assert x.is_contiguous(), "`x` must be contiguous"
+        assert x.dtype == torch.bfloat16
+        assert inner_block_size <= 32
+
+        # Get tensor shape
+        n_rows, n_cols = x.shape
+
+        # Create output tensors
+        output_col_major = torch.empty(
+            (n_cols, n_rows), dtype=torch.float8_e4m3fn, device=x.device
+        )
+
+        # Create scale tensors
+        col_scale = torch.empty(
+            n_cols, n_rows // inner_block_size, dtype=torch.uint8, device=x.device
+        )
+
+        # Calculate grid dimensions based on tile size
+        grid = lambda META: (
+            triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
+            triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
+        )
+
+        # Launch the kernel
+        to_mxfp8_dim1_kernel[grid](
+            x_ptr=x,
+            output_col_major_ptr=output_col_major,
+            col_scale_ptr=col_scale,
+            n_rows=n_rows,
+            n_cols=n_cols,
+            INNER_BLOCK_SIZE=inner_block_size,
+        )
+
+        return (
+            output_col_major.t(),
+            col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
+        )
+
+    def to_mxfp8_dim1_reference(
+        x_hp: torch.Tensor, block_size
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        A reference version of `to_mxfp8_dim1`.
+        """
+        from torchao.prototype.mx_formats.mx_tensor import to_mx
+
+        # cast across dim1
+        x_hp_d1 = x_hp.t().contiguous()
+        scale_e8m0_dim1, x_hp_d1_normalized = to_mx(
+            x_hp_d1, torch.float8_e4m3fn, block_size
+        )
+        scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu)
+        return (
+            x_hp_d1_normalized.t(),
+            scale_e8m0_dim1,
+        )
+
+else:
+
+    def to_mxfp8_across_dim0_and_dim1(x, tile_size=32):
+        raise AssertionError("needs torch version 2.8+ and triton")
+
+    def scale_dim0_dim1_reference(
+        x_hp: torch.Tensor, block_size
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        raise AssertionError("needs torch version 2.8+ and triton")

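The scale-index arithmetic at the end of the kernel decides where each tile's e8m0 bytes land in the flattened scale tensor when a column's blocks are split across multiple row tiles. A standalone sketch of just that index transformation, with hypothetical sizes chosen to reproduce the example in the kernel comments (ROW_TILE_SIZE=64, INNER_BLOCK_SIZE=32, n_rows=128, and a tile holding 8 scale values):

import torch

BLOCKS_PER_ROW_TILE = 64 // 32        # ROW_TILE_SIZE // INNER_BLOCK_SIZE = 2
jump_vals_per_col = (128 - 64) // 32  # (n_rows - ROW_TILE_SIZE) // INNER_BLOCK_SIZE = 2

col_scale_indices = torch.arange(8)   # [0, 1, 2, 3, 4, 5, 6, 7]
col_scale_indices = col_scale_indices + (
    torch.div(col_scale_indices, BLOCKS_PER_ROW_TILE, rounding_mode="floor")
    * jump_vals_per_col
)
print(col_scale_indices.tolist())     # [0, 1, 4, 5, 8, 9, 12, 13]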