Skip to content

Commit abf4c7f

Browse files
committed
[wip] triton kernel to cast to mx across dim0 and dim1
Summary: Test Plan: ``` python torchao/prototype/mx_formats/mx_dim0_dim1_cast.py ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: c7ddc8e ghstack-comment-id: 2714865161 Pull Request resolved: #1869
1 parent d3b481b commit abf4c7f

File tree

1 file changed

+366
-0
lines changed

1 file changed

+366
-0
lines changed
Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Starting with https://github.com/vkuzo/pytorch_scripts/blob/main/mx_cast_poc/20250305_mx_dim0_dim1_cast.py
9+
and making it nice.
10+
"""
11+
12+
from typing import Callable, Tuple
13+
14+
import fire
15+
import torch
16+
import triton
17+
import triton.language as tl
18+
from torch._inductor.utils import do_bench_using_profiling
19+
20+
from torchao.prototype.mx_formats.constants import (
21+
E8M0_EXPONENT_BIAS,
22+
F8E4M3_MAX_POW2,
23+
F32_MIN_NORMAL,
24+
)
25+
26+
torch.manual_seed(0)
27+
28+
29+
def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
30+
"""Thin wrapper around do_bench_using_profiling"""
31+
no_args = lambda: func(*args, **kwargs)
32+
time = do_bench_using_profiling(no_args)
33+
return time * 1e3
34+
35+
36+
def compute_error(x, y):
37+
Ps = torch.linalg.norm(x)
38+
Pn = torch.linalg.norm(x - y)
39+
return 20 * torch.log10(Ps / Pn)
40+
41+
42+
def get_scale_reference(x_hp):
43+
# TODO(before land): reuse code with mx_tensor.py::to_mx
44+
# TODO(future PR): test block of all-zeros
45+
# TODO(future PR): support other rounding modes (currently only supports floor)
46+
47+
# TODO(future PR): support other dtypes
48+
target_max_pow2 = F8E4M3_MAX_POW2
49+
50+
epsilon = 1e-10
51+
max_abs = torch.amax(x_hp, dim=1).unsqueeze(1)
52+
53+
scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + epsilon)) - target_max_pow2
54+
55+
# Clamp to exponents that can be represented in e8m0
56+
scale_e8m0_unbiased = torch.clamp(
57+
scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS
58+
)
59+
60+
# Create the biased e8m0 representation and cast it to 8 bits
61+
scale_e8m0_biased = scale_e8m0_unbiased + E8M0_EXPONENT_BIAS
62+
scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
63+
64+
# TODO(future PR): add NaN handling here
65+
66+
# For now, calculate the scale in floating point.
67+
# TODO(future) audit if there is a need to bit shift exponents instead.
68+
scale_fp = torch.pow(
69+
torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
70+
scale_e8m0_unbiased,
71+
)
72+
73+
# Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
74+
# float32 denormal range. For now, manually adjust the fp scale. This is
75+
# relevant if all of the incoming block values are zeroes.
76+
# See https://github.com/pytorch/pytorch/issues/125557 for details.
77+
# Note: it would be more correct to set the minimum to 2**-127, but this
78+
# does not work in triton either as it looks like subnormal value handling
79+
# has some gaps. So, for now just set to the minimum normal value.
80+
scale_fp = torch.clamp(scale_fp, min=F32_MIN_NORMAL)
81+
82+
return scale_fp, scale_e8m0_biased.view(torch.float8_e8m0fnu)
83+
84+
85+
def scale_dim0_reference(x_hp, block_size) -> Tuple[torch.Tensor, torch.Tensor]:
86+
x_hp_block = x_hp.reshape(-1, block_size)
87+
x_hp_block_abs = x_hp_block.abs()
88+
scale_fp, scale_e8m0 = get_scale_reference(x_hp_block_abs)
89+
x_hp_block_normalized = x_hp_block / scale_fp
90+
x_hp_normalized = x_hp_block_normalized.reshape(x_hp.shape)
91+
return x_hp_normalized.to(x_hp.dtype), scale_e8m0
92+
93+
94+
def scale_dim0_dim1_reference(
95+
x_hp: torch.Tensor, block_size
96+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
97+
# normalize across dim0
98+
x_hp_d0_normalized, scale_e8m0_dim0 = scale_dim0_reference(x_hp, block_size)
99+
# normalize across dim1
100+
x_hp_d1 = x_hp.t().contiguous()
101+
x_hp_d1_normalized, scale_e8m0_dim1 = scale_dim0_reference(x_hp_d1, block_size)
102+
return x_hp_d0_normalized, x_hp_d1_normalized.t(), scale_e8m0_dim0, scale_e8m0_dim1
103+
104+
105+
@triton.jit
106+
def _triton_calculate_scale(x, axis):
107+
# We use a small epsilon to avoid division by zero
108+
epsilon = 1e-10
109+
110+
# TODO(before land): reuse the constants below instead of hardcoding
111+
target_max_pow2 = 8
112+
e8m0_exponent_bias = 127
113+
114+
# Find the maximum absolute value for each row
115+
max_abs = tl.max(x, axis=axis)
116+
117+
scale_e8m0_unbiased = tl.floor(tl.log2(max_abs + epsilon)) - target_max_pow2
118+
119+
# Clamp to exponents that can be represented in e8m0
120+
scale_e8m0_unbiased = tl.clamp(
121+
scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias
122+
)
123+
124+
# Create the biased e8m0 representation and cast it to 8 bits
125+
scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
126+
scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)
127+
128+
# TODO(future PR): add NaN handling here
129+
130+
# For now, calculate the scale in floating point.
131+
# TODO(future) audit if there is a need to bit shift exponents instead.
132+
scale_fp = tl.exp2(scale_e8m0_unbiased.to(tl.float32))
133+
134+
return scale_fp, scale_e8m0_biased
135+
136+
137+
@triton.jit
138+
def normalization_kernel(
139+
x_ptr, # pointer to input tensor
140+
output_row_major_ptr, # pointer to row-major output tensor (row-normalized)
141+
output_col_major_ptr, # pointer to column-major output tensor (column-normalized)
142+
row_scale_ptr, # pointer to store row-wise maximum absolute values
143+
col_scale_ptr, # pointer to store column-wise maximum absolute values
144+
n_rows, # number of rows in the tensor
145+
n_cols, # number of columns in the tensor
146+
TILE_SIZE: tl.constexpr, # tile size as a compile-time constant
147+
):
148+
"""
149+
credit: mostly Claude, some Vasiliy
150+
"""
151+
152+
# Get program ID
153+
pid_row = tl.program_id(0)
154+
pid_col = tl.program_id(1)
155+
156+
# Calculate starting row and column for this tile
157+
start_row = pid_row * TILE_SIZE
158+
start_col = pid_col * TILE_SIZE
159+
160+
# Create offsets for the block
161+
row_offsets = tl.arange(0, TILE_SIZE)
162+
col_offsets = tl.arange(0, TILE_SIZE)
163+
164+
# Compute global row/col positions
165+
rows = start_row + row_offsets[:, None] # Convert to 2D for proper broadcasting
166+
cols = start_col + col_offsets[None, :]
167+
168+
# Create masks for out-of-bounds accesses
169+
row_mask = rows < n_rows
170+
col_mask = cols < n_cols
171+
mask = row_mask & col_mask
172+
173+
# Compute memory offsets for row-major layout (rows, cols)
174+
row_major_offsets = (rows * n_cols + cols).to(tl.int32)
175+
176+
# Compute memory offsets for column-major layout (cols, rows)
177+
col_major_offsets = (cols * n_rows + rows).to(tl.int32)
178+
179+
# Load the entire block in a single operation
180+
x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
181+
182+
# ----------------------------------------------------
183+
# Row-wise normalization
184+
# ----------------------------------------------------
185+
# Calculate the absolute values of elements in the block
186+
x_block_abs = tl.abs(x_block)
187+
188+
# Find the maximum absolute value for each row
189+
row_scale, row_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=1)
190+
191+
# Normalize each row by its maximum absolute value
192+
# Broadcasting row_scale to match x_block's shape
193+
row_normalized = x_block / row_scale[:, None]
194+
195+
# ----------------------------------------------------
196+
# Column-wise normalization
197+
# ----------------------------------------------------
198+
# Find the maximum absolute value for each column
199+
col_scale, col_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=0)
200+
201+
# Normalize each column by its maximum absolute value
202+
# Broadcasting col_scale to match x_block's shape
203+
col_normalized = x_block / col_scale[None, :]
204+
205+
# Store the row-normalized result in row-major format
206+
tl.store(output_row_major_ptr + row_major_offsets, row_normalized, mask=mask)
207+
208+
# Store the column-normalized result in column-major format
209+
tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask)
210+
211+
# Create 1D ranges for storing row and column max values
212+
row_indices = start_row + tl.arange(0, TILE_SIZE)
213+
col_indices = start_col + tl.arange(0, TILE_SIZE)
214+
215+
# Create masks for valid rows and columns
216+
row_mask = row_indices < n_rows
217+
col_mask = col_indices < n_cols
218+
219+
# Vasiliy - deviating from Claude here for much simpler code
220+
row_scale_start_ptr = row_scale_ptr + (pid_row * n_cols) + pid_col
221+
row_scale_indices = tl.arange(0, TILE_SIZE) * (n_cols // TILE_SIZE)
222+
# TODO(future): mask
223+
tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
224+
225+
# Vasiliy - deviating from Claude here for much simpler code
226+
col_scale_start_ptr = col_scale_ptr + (pid_col * n_rows) + pid_row
227+
col_scale_indices = tl.arange(0, TILE_SIZE) * (n_rows // TILE_SIZE)
228+
# TODO(future): mask
229+
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
230+
231+
232+
# Function to launch the kernel
233+
def normalize_tiled(x, tile_size=32):
234+
# Get tensor shape
235+
n_rows, n_cols = x.shape
236+
237+
# Create output tensors (both row-major and column-major)
238+
output_row_major = torch.empty_like(x)
239+
output_col_major = torch.empty((n_cols, n_rows), dtype=x.dtype, device=x.device)
240+
241+
# Create tensors for row-wise and column-wise maximum absolute values
242+
row_scale = torch.empty(
243+
n_rows, n_cols // tile_size, dtype=torch.uint8, device=x.device
244+
)
245+
col_scale = torch.empty(
246+
n_cols, n_rows // tile_size, dtype=torch.uint8, device=x.device
247+
)
248+
249+
# Calculate grid dimensions based on tile size
250+
grid_rows = triton.cdiv(n_rows, tile_size)
251+
grid_cols = triton.cdiv(n_cols, tile_size)
252+
253+
# Launch the kernel
254+
normalization_kernel[(grid_rows, grid_cols)](
255+
x_ptr=x,
256+
output_row_major_ptr=output_row_major,
257+
output_col_major_ptr=output_col_major,
258+
row_scale_ptr=row_scale,
259+
col_scale_ptr=col_scale,
260+
n_rows=n_rows,
261+
n_cols=n_cols,
262+
TILE_SIZE=tile_size,
263+
)
264+
265+
return (
266+
output_row_major,
267+
output_col_major.t(),
268+
row_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
269+
col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
270+
)
271+
272+
273+
def run(
274+
M: int = 4096,
275+
K: int = 2048,
276+
BLOCK_SIZE: int = 32,
277+
):
278+
print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
279+
print(f"GPU: {torch.cuda.get_device_name(0)}")
280+
print(f"torch version: {torch.__version__}")
281+
print(f"triton version: {triton.__version__}")
282+
283+
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
284+
285+
scale_dim0_dim1_c = torch.compile(scale_dim0_dim1_reference)
286+
287+
# reference implementation (plain PyTorch + torch.compile)
288+
x_d0, x_d1, scale_e8m0_d0, scale_e8m0_d1 = scale_dim0_dim1_c(x, BLOCK_SIZE)
289+
scale_fp_d0 = scale_e8m0_d0.float()
290+
scale_fp_d1 = scale_e8m0_d1.float()
291+
x_d0_and_back = (x_d0.reshape(-1, BLOCK_SIZE) * scale_fp_d0).reshape(x_d0.shape)
292+
x_d1_and_back = (
293+
(x_d1.t().reshape(-1, BLOCK_SIZE) * scale_fp_d1).reshape(x_d1.t().shape).t()
294+
)
295+
296+
sqnr_bf16_vs_dim0_ref = compute_error(x, x_d0_and_back)
297+
sqnr_bf16_vs_dim1_ref = compute_error(x, x_d1_and_back)
298+
print(
299+
f"bf16 vs normalized reference sqnrs: dim0 {sqnr_bf16_vs_dim0_ref}, dim1 {sqnr_bf16_vs_dim1_ref}"
300+
)
301+
assert (
302+
sqnr_bf16_vs_dim0_ref > 50 and sqnr_bf16_vs_dim1_ref > 50
303+
), "reference normlization numerics are incorrect"
304+
305+
# basic triton kernel
306+
x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t = normalize_tiled(x, tile_size=BLOCK_SIZE)
307+
308+
# ensure bitwise equivalency of outputs with reference
309+
torch.testing.assert_close(x_d0, x_d0_t, atol=0, rtol=0)
310+
torch.testing.assert_close(x_d1, x_d1_t, atol=0, rtol=0)
311+
torch.testing.assert_close(scale_e8m0_d0, scale_e8m0_d0_t, atol=0, rtol=0)
312+
torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_t, atol=0, rtol=0)
313+
print("normalized reference vs normalized triton are bitwise equivalent")
314+
315+
if False:
316+
# for debugging
317+
sqnr_x_d0_ref_vs_t = compute_error(x_d0, x_d0_t)
318+
print("sqnr_x_d0_t", sqnr_x_d0_ref_vs_t)
319+
sqnr_scale_e8m0_d0_vs_t = compute_error(scale_e8m0_d0, scale_e8m0_d0_t)
320+
print("sqnr_scale_e8m0_d0_t", sqnr_scale_e8m0_d0_vs_t)
321+
sqnr_x_d1_ref_vs_t = compute_error(x_d1, x_d1_t)
322+
print("sqnr_x_d1_t", sqnr_x_d1_ref_vs_t)
323+
sqnr_scale_e8m0_d1_vs_t = compute_error(scale_e8m0_d1, scale_e8m0_d1_t)
324+
print("sqnr_scale_e8m0_d1_t", sqnr_scale_e8m0_d1_vs_t)
325+
326+
# now, measure performance
327+
328+
# warm up
329+
for _ in range(2):
330+
__ = scale_dim0_dim1_reference(x, BLOCK_SIZE)
331+
time_reference_compile_us = benchmark_cuda_function_in_microseconds(
332+
lambda x, b: scale_dim0_dim1_c(x, b), x, BLOCK_SIZE
333+
)
334+
335+
# warm up
336+
for _ in range(2):
337+
__ = normalize_tiled(x, tile_size=BLOCK_SIZE)
338+
time_triton_us = benchmark_cuda_function_in_microseconds(
339+
lambda x, b: normalize_tiled(x, tile_size=BLOCK_SIZE), x, BLOCK_SIZE
340+
)
341+
342+
# calculate bytes read/written
343+
bytes_per_el_bf16 = 2
344+
bytes_per_el_fp8 = 1
345+
triton_bytes_read = x.numel() * bytes_per_el_bf16
346+
triton_bytes_written = (
347+
sum(x.numel() for x in (x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t))
348+
* bytes_per_el_fp8
349+
)
350+
triton_achieved_mem_bw_gbps = (triton_bytes_read + triton_bytes_written) / (
351+
time_triton_us / 1e6
352+
)
353+
# TODO read 8.0 TB/s number from roofline_utils.py instead of hardcoding
354+
triton_pct_peak_mem_bw = triton_achieved_mem_bw_gbps / 8.0e12
355+
356+
print("time_reference_compile_us", time_reference_compile_us)
357+
print("time_triton_us", time_triton_us)
358+
print("triton_achieved_mem_bw_gbps", triton_achieved_mem_bw_gbps)
359+
# Note: as of 2025-03-11, inductor code for adding 1.0 to a large bf16 tensor
360+
# can achieve around 50-70% of B200 peak mem bw
361+
print("triton_pct_peak_mem_bw", triton_pct_peak_mem_bw)
362+
print("speedup", time_reference_compile_us / time_triton_us)
363+
364+
365+
if __name__ == "__main__":
366+
fire.Fire(run)

0 commit comments

Comments
 (0)