
Commit 31b1fa5

[wip] triton kernel to cast to mx across dim0 and dim1
Summary:

Test Plan:

```
python torchao/prototype/mx_formats/mx_dim0_dim1_cast.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: fb6ba5e
ghstack-comment-id: 2714865161
Pull Request resolved: #1869
1 parent d3b481b commit 31b1fa5
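For context, a minimal usage sketch of the two entry points added here (illustrative; assumes a CUDA device and that the new file is importable as `torchao.prototype.mx_formats.mx_dim0_dim1_cast`):

```
import torch

from torchao.prototype.mx_formats.mx_dim0_dim1_cast import (
    normalize_tiled,
    scale_dim0_dim1_reference,
)

x = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")

# reference path: to_mx along dim0, and along dim1 via a transpose
x_d0, x_d1, s_d0, s_d1 = scale_dim0_dim1_reference(x, 32)

# triton path: both dims handled in a single kernel launch
x_d0_t, x_d1_t, s_d0_t, s_d1_t = normalize_tiled(x, tile_size=32)
```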

1 file changed: +382 −0 lines
@@ -0,0 +1,382 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Starting with https://github.com/vkuzo/pytorch_scripts/blob/main/mx_cast_poc/20250305_mx_dim0_dim1_cast.py
and making it nice.
"""

from typing import Callable, Tuple

import fire
import torch
import triton
import triton.language as tl
from torch._inductor.utils import do_bench_using_profiling

from torchao.prototype.mx_formats.constants import (
    E8M0_EXPONENT_BIAS,
    F8E4M3_MAX_POW2,
    F32_MIN_NORMAL,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

torch.manual_seed(0)


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3


def compute_error(x, y):
    Ps = torch.linalg.norm(x)
    Pn = torch.linalg.norm(x - y)
    return 20 * torch.log10(Ps / Pn)
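
# Illustrative: compute_error returns SQNR in dB. For example, if the error
# norm is ~3% of the signal norm, SQNR = 20 * log10(1 / 0.03) ~= 30 dB, which
# is the ballpark that the `> 28` asserts in run() below check for.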


def get_scale_reference(x_hp):
    # TODO(before land): reuse code with mx_tensor.py::to_mx
    # TODO(future PR): test block of all-zeros
    # TODO(future PR): support other rounding modes (currently only supports floor)

    # TODO(future PR): support other dtypes
    target_max_pow2 = F8E4M3_MAX_POW2

    epsilon = 1e-10
    max_abs = torch.amax(x_hp, dim=1).unsqueeze(1)

    scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + epsilon)) - target_max_pow2

    # Clamp to exponents that can be represented in e8m0
    scale_e8m0_unbiased = torch.clamp(
        scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS
    )

    # Create the biased e8m0 representation and cast it to 8 bits
    scale_e8m0_biased = scale_e8m0_unbiased + E8M0_EXPONENT_BIAS
    scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)

    # TODO(future PR): add NaN handling here

    # For now, calculate the scale in floating point.
    # TODO(future) audit if there is a need to bit shift exponents instead.
    scale_fp = torch.pow(
        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
        scale_e8m0_unbiased,
    )

    # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
    # float32 denormal range. For now, manually adjust the fp scale. This is
    # relevant if all of the incoming block values are zeroes.
    # See https://github.com/pytorch/pytorch/issues/125557 for details.
    # Note: it would be more correct to set the minimum to 2**-127, but this
    # does not work in triton either as it looks like subnormal value handling
    # has some gaps. So, for now just set to the minimum normal value.
    scale_fp = torch.clamp(scale_fp, min=F32_MIN_NORMAL)

    return scale_fp, scale_e8m0_biased.view(torch.float8_e8m0fnu)
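
# Illustrative walk-through of the scale math above, assuming
# target_max_pow2 = F8E4M3_MAX_POW2 = 8: for a block whose max_abs is 448,
# log2(448) ~= 8.8 and floor gives 8, so the unbiased e8m0 exponent is
# 8 - 8 = 0, the fp scale is 2**0 = 1.0, and the stored e8m0 byte is
# 0 + 127 = 127.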


def scale_dim0_reference(x_hp, block_size) -> Tuple[torch.Tensor, torch.Tensor]:
    x_hp_block = x_hp.reshape(-1, block_size)
    x_hp_block_abs = x_hp_block.abs()
    scale_fp, scale_e8m0 = get_scale_reference(x_hp_block_abs)
    x_hp_block_normalized = x_hp_block / scale_fp
    x_hp_normalized = x_hp_block_normalized.reshape(x_hp.shape)
    return x_hp_normalized.to(x_hp.dtype), scale_e8m0


def scale_dim0_dim1_reference(
    x_hp: torch.Tensor, block_size
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # normalize across dim0
    # x_hp_d0_normalized, scale_e8m0_dim0 = scale_dim0_reference(x_hp, block_size)
    scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
    x_hp_d0_normalized = x_hp_d0_normalized.bfloat16()
    scale_e8m0_dim0 = scale_e8m0_dim0.unsqueeze(1).view(torch.float8_e8m0fnu)

    # normalize across dim1
    x_hp_d1 = x_hp.t().contiguous()
    # x_hp_d1_normalized, scale_e8m0_dim1 = scale_dim0_reference(x_hp_d1, block_size)
    scale_e8m0_dim1, x_hp_d1_normalized = to_mx(x_hp_d1, torch.float8_e4m3fn, block_size)
    x_hp_d1_normalized = x_hp_d1_normalized.bfloat16()
    scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu)
    return x_hp_d0_normalized, x_hp_d1_normalized.t(), scale_e8m0_dim0, scale_e8m0_dim1


@triton.jit
def _triton_calculate_scale(x, axis):
    # We use a small epsilon to avoid division by zero
    epsilon = 1e-10

    # TODO(before land): reuse the constants below instead of hardcoding
    target_max_pow2 = 8
    e8m0_exponent_bias = 127

    # Find the maximum absolute value along the reduction axis
    max_abs = tl.max(x, axis=axis)

    scale_e8m0_unbiased = tl.floor(tl.log2(max_abs + epsilon)) - target_max_pow2

    # Clamp to exponents that can be represented in e8m0
    scale_e8m0_unbiased = tl.clamp(
        scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias
    )

    # Create the biased e8m0 representation and cast it to 8 bits
    scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
    scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)

    # TODO(future PR): add NaN handling here

    # For now, calculate the scale in floating point.
    # TODO(future) audit if there is a need to bit shift exponents instead.
    scale_fp = tl.exp2(scale_e8m0_unbiased.to(tl.float32))

    return scale_fp, scale_e8m0_biased


@triton.jit
def normalization_kernel(
    x_ptr,  # pointer to input tensor
    output_row_major_ptr,  # pointer to row-major output tensor (row-normalized)
    output_col_major_ptr,  # pointer to column-major output tensor (column-normalized)
    row_scale_ptr,  # pointer to store row-wise e8m0 scales
    col_scale_ptr,  # pointer to store column-wise e8m0 scales
    n_rows,  # number of rows in the tensor
    n_cols,  # number of columns in the tensor
    TILE_SIZE: tl.constexpr,  # tile size as a compile-time constant
):
    """
    credit: mostly Claude, some Vasiliy
    """

    # Get program ID
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    # Calculate starting row and column for this tile
    start_row = pid_row * TILE_SIZE
    start_col = pid_col * TILE_SIZE

    # Create offsets for the block
    row_offsets = tl.arange(0, TILE_SIZE)
    col_offsets = tl.arange(0, TILE_SIZE)

    # Compute global row/col positions
    rows = start_row + row_offsets[:, None]  # Convert to 2D for proper broadcasting
    cols = start_col + col_offsets[None, :]

    # Create masks for out-of-bounds accesses
    row_mask = rows < n_rows
    col_mask = cols < n_cols
    mask = row_mask & col_mask

    # Compute memory offsets for row-major layout (rows, cols)
    row_major_offsets = (rows * n_cols + cols).to(tl.int32)

    # Compute memory offsets for column-major layout (cols, rows)
    col_major_offsets = (cols * n_rows + rows).to(tl.int32)

    # Load the entire block in a single operation
    x_block = tl.load(x_ptr + row_major_offsets, mask=mask)

    # ----------------------------------------------------
    # Row-wise normalization
    # ----------------------------------------------------
    # Calculate the absolute values of elements in the block
    x_block_abs = tl.abs(x_block)

    # Compute the scale for each row
    row_scale, row_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=1)

    # Normalize each row by its scale
    # Broadcasting row_scale to match x_block's shape
    row_normalized = x_block / row_scale[:, None]

    # fake quant to float8
    row_normalized = row_normalized.to(tl.float8e4nv)
    row_normalized = row_normalized.to(tl.bfloat16)
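
    # Note: the float8 -> bfloat16 round trip above is assumed to emulate the
    # numerics of an eventual fp8 output while keeping the stored tensor in
    # bfloat16, so it can be compared bitwise against the to_mx reference
    # (which also converts its fp8 result back to bfloat16 in
    # scale_dim0_dim1_reference).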

    # ----------------------------------------------------
    # Column-wise normalization
    # ----------------------------------------------------
    # Compute the scale for each column
    col_scale, col_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=0)

    # Normalize each column by its scale
    # Broadcasting col_scale to match x_block's shape
    col_normalized = x_block / col_scale[None, :]

    # fake quant to float8
    col_normalized = col_normalized.to(tl.float8e4nv)
    col_normalized = col_normalized.to(tl.bfloat16)

    # Store the row-normalized result in row-major format
    tl.store(output_row_major_ptr + row_major_offsets, row_normalized, mask=mask)

    # Store the column-normalized result in column-major format
    tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask)

    # Create 1D ranges for storing row and column scale values
    row_indices = start_row + tl.arange(0, TILE_SIZE)
    col_indices = start_col + tl.arange(0, TILE_SIZE)

    # Create masks for valid rows and columns
    row_mask = row_indices < n_rows
    col_mask = col_indices < n_cols

    # Vasiliy - deviating from Claude here for much simpler code
    row_scale_start_ptr = row_scale_ptr + (pid_row * n_cols) + pid_col
    row_scale_indices = tl.arange(0, TILE_SIZE) * (n_cols // TILE_SIZE)
    # TODO(future): mask
    tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)

    # Vasiliy - deviating from Claude here for much simpler code
    col_scale_start_ptr = col_scale_ptr + (pid_col * n_rows) + pid_row
    col_scale_indices = tl.arange(0, TILE_SIZE) * (n_rows // TILE_SIZE)
    # TODO(future): mask
    tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
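
    # Layout note (illustrative, assuming TILE_SIZE evenly divides n_rows and
    # n_cols): with the strides above, the e8m0 byte for global row r and
    # column-block c lands at flat offset r * (n_cols // TILE_SIZE) + c, i.e.
    # row_scale[r, c] once the buffer is viewed as
    # (n_rows, n_cols // TILE_SIZE). For example, with n_cols=64 and
    # TILE_SIZE=32, the tile at (pid_row=0, pid_col=1) writes its 32 per-row
    # scale bytes to offsets 1, 3, 5, ..., 63. The column-scale store follows
    # the same pattern with rows and columns swapped.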


# Function to launch the kernel
def normalize_tiled(x, tile_size=32):
    # Get tensor shape
    n_rows, n_cols = x.shape

    # Create output tensors (both row-major and column-major)
    output_row_major = torch.empty_like(x)
    output_col_major = torch.empty((n_cols, n_rows), dtype=x.dtype, device=x.device)

    # Create tensors for row-wise and column-wise e8m0 scales
    row_scale = torch.empty(
        n_rows, n_cols // tile_size, dtype=torch.uint8, device=x.device
    )
    col_scale = torch.empty(
        n_cols, n_rows // tile_size, dtype=torch.uint8, device=x.device
    )

    # Calculate grid dimensions based on tile size
    grid_rows = triton.cdiv(n_rows, tile_size)
    grid_cols = triton.cdiv(n_cols, tile_size)

    # Launch the kernel
    normalization_kernel[(grid_rows, grid_cols)](
        x_ptr=x,
        output_row_major_ptr=output_row_major,
        output_col_major_ptr=output_col_major,
        row_scale_ptr=row_scale,
        col_scale_ptr=col_scale,
        n_rows=n_rows,
        n_cols=n_cols,
        TILE_SIZE=tile_size,
    )

    return (
        output_row_major,
        output_col_major.t(),
        row_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
        col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu),
    )


def run(
    M: int = 4096,
    K: int = 2048,
    BLOCK_SIZE: int = 32,
):
    print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"torch version: {torch.__version__}")
    print(f"triton version: {triton.__version__}")

    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

    scale_dim0_dim1_c = torch.compile(scale_dim0_dim1_reference)

    # reference implementation (plain PyTorch + torch.compile)
    x_d0, x_d1, scale_e8m0_d0, scale_e8m0_d1 = scale_dim0_dim1_c(x, BLOCK_SIZE)
    scale_fp_d0 = scale_e8m0_d0.float()
    scale_fp_d1 = scale_e8m0_d1.float()
    x_d0_and_back = (x_d0.reshape(-1, BLOCK_SIZE) * scale_fp_d0).reshape(x_d0.shape)
    x_d1_and_back = (
        (x_d1.t().reshape(-1, BLOCK_SIZE) * scale_fp_d1).reshape(x_d1.t().shape).t()
    )

    sqnr_bf16_vs_dim0_ref = compute_error(x, x_d0_and_back)
    sqnr_bf16_vs_dim1_ref = compute_error(x, x_d1_and_back)
    print(
        f"bf16 vs normalized reference sqnrs: dim0 {sqnr_bf16_vs_dim0_ref}, dim1 {sqnr_bf16_vs_dim1_ref}"
    )
    assert (
        sqnr_bf16_vs_dim0_ref > 28 and sqnr_bf16_vs_dim1_ref > 28
    ), "reference mx numerics are incorrect"

    # basic triton kernel
    x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t = normalize_tiled(x, tile_size=BLOCK_SIZE)

    # ensure bitwise equivalency of outputs with reference
    torch.testing.assert_close(x_d0, x_d0_t, atol=0, rtol=0)
    torch.testing.assert_close(x_d1, x_d1_t, atol=0, rtol=0)
    torch.testing.assert_close(scale_e8m0_d0, scale_e8m0_d0_t, atol=0, rtol=0)
    torch.testing.assert_close(scale_e8m0_d1, scale_e8m0_d1_t, atol=0, rtol=0)
    print("normalized reference vs normalized triton are bitwise equivalent")

    if False:
        # for debugging
        sqnr_x_d0_ref_vs_t = compute_error(x_d0, x_d0_t)
        print("sqnr_x_d0_t", sqnr_x_d0_ref_vs_t)
        sqnr_scale_e8m0_d0_vs_t = compute_error(scale_e8m0_d0, scale_e8m0_d0_t)
        print("sqnr_scale_e8m0_d0_t", sqnr_scale_e8m0_d0_vs_t)
        sqnr_x_d1_ref_vs_t = compute_error(x_d1, x_d1_t)
        print("sqnr_x_d1_t", sqnr_x_d1_ref_vs_t)
        sqnr_scale_e8m0_d1_vs_t = compute_error(scale_e8m0_d1, scale_e8m0_d1_t)
        print("sqnr_scale_e8m0_d1_t", sqnr_scale_e8m0_d1_vs_t)

    # now, measure performance

    # warm up
    for _ in range(2):
        __ = scale_dim0_dim1_reference(x, BLOCK_SIZE)
    time_reference_compile_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: scale_dim0_dim1_c(x, b), x, BLOCK_SIZE
    )

    # warm up
    for _ in range(2):
        __ = normalize_tiled(x, tile_size=BLOCK_SIZE)
    time_triton_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: normalize_tiled(x, tile_size=b), x, BLOCK_SIZE
    )

    # calculate bytes read/written
    bytes_per_el_bf16 = 2
    bytes_per_el_fp8 = 1
    triton_bytes_read = x.numel() * bytes_per_el_bf16
    triton_bytes_written = (
        sum(x.numel() for x in (x_d0_t, x_d1_t, scale_e8m0_d0_t, scale_e8m0_d1_t))
        * bytes_per_el_fp8
    )
    triton_achieved_mem_bw_bps = (triton_bytes_read + triton_bytes_written) / (
        time_triton_us / 1e6
    )
    triton_achieved_mem_bw_gbps = triton_achieved_mem_bw_bps / 1e9
    # TODO read 8.0 TB/s number from roofline_utils.py instead of hardcoding
    triton_pct_peak_mem_bw = triton_achieved_mem_bw_bps / 8.0e12
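
    # Illustrative roofline arithmetic with the defaults above (M=4096, K=2048):
    # bytes read ~= 4096 * 2048 * 2 ~= 16.8e6, bytes written (modeled as fp8)
    # ~= 2 * 4096 * 2048 + 2 * 4096 * 2048 / 32 ~= 17.3e6, so ~34.1e6 bytes
    # total; at the hardcoded 8.0 TB/s peak this puts the memory-bound floor
    # at roughly 4.3 us.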

    print("time_reference_compile_us", time_reference_compile_us)
    print("time_triton_us", time_triton_us)
    print("triton_achieved_mem_bw_gbps", triton_achieved_mem_bw_gbps)
    # Note: as of 2025-03-11, inductor code for adding 1.0 to a large bf16 tensor
    # can achieve around 50-70% of B200 peak mem bw
    print("triton_pct_peak_mem_bw", triton_pct_peak_mem_bw)
    print("speedup", time_reference_compile_us / time_triton_us)


if __name__ == "__main__":
    fire.Fire(run)
