Skip to content

Commit 9c8b303

Browse files
committed
[wip] triton kernel to cast to mx across dim0 and dim1
Summary: Test Plan: ``` python torchao/prototype/mx_formats/mx_dim0_dim1_cast.py ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 322d509 ghstack-comment-id: 2714865161 Pull Request resolved: #1869
1 parent d3b481b commit 9c8b303

File tree

1 file changed

+366
-0
lines changed

1 file changed

+366
-0
lines changed
Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Starting with https://github.com/vkuzo/pytorch_scripts/blob/main/mx_cast_poc/20250305_mx_dim0_dim1_cast.py
9+
and making it nice.
10+
"""
11+
12+
from typing import Callable, Tuple
13+
14+
import fire
15+
import torch
16+
import triton
17+
import triton.language as tl
18+
from torch._inductor.utils import do_bench_using_profiling
19+
20+
from torchao.prototype.mx_formats.constants import (
21+
E8M0_EXPONENT_BIAS,
22+
F8E4M3_MAX_POW2,
23+
F32_MIN_NORMAL,
24+
)
25+
26+
torch.manual_seed(0)
27+
28+
29+
def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
30+
"""Thin wrapper around do_bench_using_profiling"""
31+
no_args = lambda: func(*args, **kwargs)
32+
time = do_bench_using_profiling(no_args)
33+
return time * 1e3
34+
35+
36+
def compute_error(x, y):
37+
Ps = torch.linalg.norm(x)
38+
Pn = torch.linalg.norm(x - y)
39+
return 20 * torch.log10(Ps / Pn)
40+
41+
42+
def get_scale_reference(x_hp):
43+
# TODO(before land): reuse code with mx_tensor.py::to_mx
44+
# TODO(future PR): test block of all-zeros
45+
# TODO(future PR): support other rounding modes (currently only supports floor)
46+
47+
# TODO(future PR): support other dtypes
48+
target_max_pow2 = F8E4M3_MAX_POW2
49+
50+
epsilon = 1e-10
51+
max_abs = torch.amax(x_hp, dim=1).unsqueeze(1)
52+
53+
scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + epsilon)) - target_max_pow2
54+
55+
# Clamp to exponents that can be represented in e8m0
56+
scale_e8m0_unbiased = torch.clamp(
57+
scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS
58+
)
59+
60+
# Create the biased e8m0 representation and cast it to 8 bits
61+
scale_e8m0_biased = scale_e8m0_unbiased + E8M0_EXPONENT_BIAS
62+
scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)
63+
64+
# TODO(future PR): add NaN handling here
65+
66+
# For now, calculate the scale in floating point.
67+
# TODO(future) audit if there is a need to bit shift exponents instead.
68+
scale_fp = torch.pow(
69+
torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
70+
scale_e8m0_unbiased,
71+
)
72+
73+
# Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
74+
# float32 denormal range. For now, manually adjust the fp scale. This is
75+
# relevant if all of the incoming block values are zeroes.
76+
# See https://github.com/pytorch/pytorch/issues/125557 for details.
77+
# Note: it would be more correct to set the minimum to 2**-127, but this
78+
# does not work in triton either as it looks like subnormal value handling
79+
# has some gaps. So, for now just set to the minimum normal value.
80+
scale_fp = torch.clamp(scale_fp, min=F32_MIN_NORMAL)
81+
82+
# return max_abs
83+
return scale_fp
84+
85+
86+
def scale_dim0_reference(x_hp, block_size) -> Tuple[torch.Tensor, torch.Tensor]:
87+
x_hp_block = x_hp.reshape(-1, block_size)
88+
x_hp_block_abs = x_hp_block.abs()
89+
scale = get_scale_reference(x_hp_block_abs)
90+
x_hp_block_normalized = x_hp_block / scale
91+
x_hp_normalized = x_hp_block_normalized.reshape(x_hp.shape)
92+
return x_hp_normalized.to(x_hp.dtype), scale
93+
94+
95+
def scale_dim0_dim1_reference(
96+
x_hp: torch.Tensor, block_size
97+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
98+
# normalize across dim0
99+
x_hp_d0_normalized, amax_dim0 = scale_dim0_reference(x_hp, block_size)
100+
# normalize across dim1
101+
x_hp_d1 = x_hp.t().contiguous()
102+
x_hp_d1_normalized, amax_dim1 = scale_dim0_reference(x_hp_d1, block_size)
103+
return x_hp_d0_normalized, x_hp_d1_normalized.t(), amax_dim0, amax_dim1
104+
105+
106+
@triton.jit
107+
def _triton_calculate_scale(x, axis):
108+
# We use a small epsilon to avoid division by zero
109+
epsilon = 1e-10
110+
111+
# TODO(before land): reuse the constants below instead of hardcoding
112+
target_max_pow2 = 8
113+
e8m0_exponent_bias = 127
114+
115+
# Find the maximum absolute value for each row
116+
max_abs = tl.max(x, axis=axis)
117+
118+
scale_e8m0_unbiased = tl.floor(tl.log2(max_abs + epsilon)) - target_max_pow2
119+
120+
# Clamp to exponents that can be represented in e8m0
121+
scale_e8m0_unbiased = tl.clamp(
122+
scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias
123+
)
124+
125+
# Create the biased e8m0 representation and cast it to 8 bits
126+
scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
127+
scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)
128+
129+
# TODO(future PR): add NaN handling here
130+
131+
# For now, calculate the scale in floating point.
132+
# TODO(future) audit if there is a need to bit shift exponents instead.
133+
scale_fp = tl.exp2(scale_e8m0_unbiased.to(tl.float32))
134+
135+
return scale_fp
136+
137+
138+
@triton.jit
139+
def normalization_kernel(
140+
x_ptr, # pointer to input tensor
141+
output_row_major_ptr, # pointer to row-major output tensor (row-normalized)
142+
output_col_major_ptr, # pointer to column-major output tensor (column-normalized)
143+
row_scale_ptr, # pointer to store row-wise maximum absolute values
144+
col_scale_ptr, # pointer to store column-wise maximum absolute values
145+
n_rows, # number of rows in the tensor
146+
n_cols, # number of columns in the tensor
147+
TILE_SIZE: tl.constexpr, # tile size as a compile-time constant
148+
):
149+
"""
150+
credit: mostly Claude, some Vasiliy
151+
"""
152+
153+
# Get program ID
154+
pid_row = tl.program_id(0)
155+
pid_col = tl.program_id(1)
156+
157+
# Calculate starting row and column for this tile
158+
start_row = pid_row * TILE_SIZE
159+
start_col = pid_col * TILE_SIZE
160+
161+
# Create offsets for the block
162+
row_offsets = tl.arange(0, TILE_SIZE)
163+
col_offsets = tl.arange(0, TILE_SIZE)
164+
165+
# Compute global row/col positions
166+
rows = start_row + row_offsets[:, None] # Convert to 2D for proper broadcasting
167+
cols = start_col + col_offsets[None, :]
168+
169+
# Create masks for out-of-bounds accesses
170+
row_mask = rows < n_rows
171+
col_mask = cols < n_cols
172+
mask = row_mask & col_mask
173+
174+
# Compute memory offsets for row-major layout (rows, cols)
175+
row_major_offsets = (rows * n_cols + cols).to(tl.int32)
176+
177+
# Compute memory offsets for column-major layout (cols, rows)
178+
col_major_offsets = (cols * n_rows + rows).to(tl.int32)
179+
180+
# Load the entire block in a single operation
181+
x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
182+
183+
# ----------------------------------------------------
184+
# Row-wise normalization
185+
# ----------------------------------------------------
186+
# Calculate the absolute values of elements in the block
187+
x_block_abs = tl.abs(x_block)
188+
189+
# Find the maximum absolute value for each row
190+
row_scale = _triton_calculate_scale(x_block_abs, axis=1)
191+
192+
# Normalize each row by its maximum absolute value
193+
# Broadcasting row_scale to match x_block's shape
194+
row_normalized = x_block / row_scale[:, None]
195+
196+
# ----------------------------------------------------
197+
# Column-wise normalization
198+
# ----------------------------------------------------
199+
# Find the maximum absolute value for each column
200+
col_scale = _triton_calculate_scale(x_block_abs, axis=0)
201+
202+
# Normalize each column by its maximum absolute value
203+
# Broadcasting col_scale to match x_block's shape
204+
col_normalized = x_block / col_scale[None, :]
205+
206+
# Store the row-normalized result in row-major format
207+
tl.store(output_row_major_ptr + row_major_offsets, row_normalized, mask=mask)
208+
209+
# Store the column-normalized result in column-major format
210+
tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask)
211+
212+
# Create 1D ranges for storing row and column max values
213+
row_indices = start_row + tl.arange(0, TILE_SIZE)
214+
col_indices = start_col + tl.arange(0, TILE_SIZE)
215+
216+
# Create masks for valid rows and columns
217+
row_mask = row_indices < n_rows
218+
col_mask = col_indices < n_cols
219+
220+
# Vasiliy - deviating from Claude here for much simpler code
221+
row_scale_start_ptr = row_scale_ptr + (pid_row * n_cols) + pid_col
222+
row_scale_indices = tl.arange(0, TILE_SIZE) * (n_cols // TILE_SIZE)
223+
# TODO(future): mask
224+
tl.store(row_scale_start_ptr + row_scale_indices, row_scale)
225+
226+
# Vasiliy - deviating from Claude here for much simpler code
227+
col_scale_start_ptr = col_scale_ptr + (pid_col * n_rows) + pid_row
228+
col_scale_indices = tl.arange(0, TILE_SIZE) * (n_rows // TILE_SIZE)
229+
# TODO(future): mask
230+
tl.store(col_scale_start_ptr + col_scale_indices, col_scale)
231+
232+
233+
# Function to launch the kernel
234+
def normalize_tiled(x, tile_size=32):
235+
# Get tensor shape
236+
n_rows, n_cols = x.shape
237+
238+
# Create output tensors (both row-major and column-major)
239+
output_row_major = torch.empty_like(x)
240+
output_col_major = torch.empty((n_cols, n_rows), dtype=x.dtype, device=x.device)
241+
242+
# Create tensors for row-wise and column-wise maximum absolute values
243+
row_scale = torch.empty(
244+
n_rows, n_cols // tile_size, dtype=torch.float, device=x.device
245+
)
246+
col_scale = torch.empty(
247+
n_cols, n_rows // tile_size, dtype=torch.float, device=x.device
248+
)
249+
250+
# Calculate grid dimensions based on tile size
251+
grid_rows = triton.cdiv(n_rows, tile_size)
252+
grid_cols = triton.cdiv(n_cols, tile_size)
253+
254+
# Launch the kernel
255+
normalization_kernel[(grid_rows, grid_cols)](
256+
x_ptr=x,
257+
output_row_major_ptr=output_row_major,
258+
output_col_major_ptr=output_col_major,
259+
row_scale_ptr=row_scale,
260+
col_scale_ptr=col_scale,
261+
n_rows=n_rows,
262+
n_cols=n_cols,
263+
TILE_SIZE=tile_size,
264+
)
265+
266+
return (
267+
output_row_major,
268+
output_col_major.t(),
269+
row_scale.reshape(-1, 1),
270+
col_scale.reshape(-1, 1),
271+
)
272+
273+
274+
def run(
275+
M: int = 4096,
276+
K: int = 2048,
277+
BLOCK_SIZE: int = 32,
278+
):
279+
print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
280+
print(f"GPU: {torch.cuda.get_device_name(0)}")
281+
print(f"torch version: {torch.__version__}")
282+
print(f"triton version: {triton.__version__}")
283+
284+
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
285+
286+
scale_dim0_dim1_c = torch.compile(scale_dim0_dim1_reference)
287+
288+
# reference implementation (plain PyTorch + torch.compile)
289+
x_d0, x_d1, amax_d0, amax_d1 = scale_dim0_dim1_c(x, BLOCK_SIZE)
290+
x_d0_and_back = (x_d0.reshape(-1, BLOCK_SIZE) * amax_d0).reshape(x_d0.shape)
291+
x_d1_and_back = (
292+
(x_d1.t().reshape(-1, BLOCK_SIZE) * amax_d1).reshape(x_d1.t().shape).t()
293+
)
294+
295+
sqnr_bf16_vs_dim0_ref = compute_error(x, x_d0_and_back)
296+
sqnr_bf16_vs_dim1_ref = compute_error(x, x_d1_and_back)
297+
print(
298+
f"bf16 vs normalized reference sqnrs: dim0 {sqnr_bf16_vs_dim0_ref}, dim1 {sqnr_bf16_vs_dim1_ref}"
299+
)
300+
assert (
301+
sqnr_bf16_vs_dim0_ref > 50 and sqnr_bf16_vs_dim1_ref > 50
302+
), "reference normlization numerics are incorrect"
303+
304+
# basic triton kernel
305+
x_d0_t, x_d1_t, amax_d0_t, amax_d1_t = normalize_tiled(x, tile_size=BLOCK_SIZE)
306+
307+
# ensure bitwise equivalency of outputs with reference
308+
torch.testing.assert_close(x_d0, x_d0_t, atol=0, rtol=0)
309+
torch.testing.assert_close(x_d1, x_d1_t, atol=0, rtol=0)
310+
print(1, amax_d0.dtype, 2, amax_d0_t.dtype)
311+
torch.testing.assert_close(amax_d0, amax_d0_t, atol=0, rtol=0)
312+
torch.testing.assert_close(amax_d1, amax_d1_t, atol=0, rtol=0)
313+
print("normalized reference vs normalized triton are bitwise equivalent")
314+
315+
if False:
316+
# for debugging
317+
sqnr_x_d0_ref_vs_t = compute_error(x_d0, x_d0_t)
318+
print("sqnr_x_d0_t", sqnr_x_d0_ref_vs_t)
319+
sqnr_amax_d0_vs_t = compute_error(amax_d0, amax_d0_t)
320+
print("sqnr_amax_d0_t", sqnr_amax_d0_vs_t)
321+
sqnr_x_d1_ref_vs_t = compute_error(x_d1, x_d1_t)
322+
print("sqnr_x_d1_t", sqnr_x_d1_ref_vs_t)
323+
sqnr_amax_d1_vs_t = compute_error(amax_d1, amax_d1_t)
324+
print("sqnr_amax_d1_t", sqnr_amax_d1_vs_t)
325+
326+
# now, measure performance
327+
328+
# warm up
329+
for _ in range(2):
330+
__ = scale_dim0_dim1_reference(x, BLOCK_SIZE)
331+
time_reference_compile_us = benchmark_cuda_function_in_microseconds(
332+
lambda x, b: scale_dim0_dim1_c(x, b), x, BLOCK_SIZE
333+
)
334+
335+
# warm up
336+
for _ in range(2):
337+
__ = normalize_tiled(x, tile_size=BLOCK_SIZE)
338+
time_triton_us = benchmark_cuda_function_in_microseconds(
339+
lambda x, b: normalize_tiled(x, tile_size=BLOCK_SIZE), x, BLOCK_SIZE
340+
)
341+
342+
# calculate bytes read/written
343+
bytes_per_el_bf16 = 2
344+
bytes_per_el_fp8 = 1
345+
triton_bytes_read = x.numel() * bytes_per_el_bf16
346+
triton_bytes_written = (
347+
sum(x.numel() for x in (x_d0_t, x_d1_t, amax_d0_t, amax_d1_t))
348+
* bytes_per_el_fp8
349+
)
350+
triton_achieved_mem_bw_gbps = (triton_bytes_read + triton_bytes_written) / (
351+
time_triton_us / 1e6
352+
)
353+
# TODO read 8.0 TB/s number from roofline_utils.py instead of hardcoding
354+
triton_pct_peak_mem_bw = triton_achieved_mem_bw_gbps / 8.0e12
355+
356+
print("time_reference_compile_us", time_reference_compile_us)
357+
print("time_triton_us", time_triton_us)
358+
print("triton_achieved_mem_bw_gbps", triton_achieved_mem_bw_gbps)
359+
# Note: as of 2025-03-11, inductor code for adding 1.0 to a large bf16 tensor
360+
# can achieve around 50-70% of B200 peak mem bw
361+
print("triton_pct_peak_mem_bw", triton_pct_peak_mem_bw)
362+
print("speedup", time_reference_compile_us / time_triton_us)
363+
364+
365+
if __name__ == "__main__":
366+
fire.Fire(run)

0 commit comments

Comments
 (0)