Skip to content

Commit 6a8cf6d

Browse files
committed
[wip] triton kernel to cast to mx across dim0 and dim1
Summary: Test Plan: ``` python torchao/prototype/mx_formats/mx_dim0_dim1_cast.py ``` Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 3da755e ghstack-comment-id: 2714865161 Pull Request resolved: #1869
1 parent d3b481b commit 6a8cf6d

File tree

1 file changed

+286
-0
lines changed

1 file changed

+286
-0
lines changed
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Starting with https://github.com/vkuzo/pytorch_scripts/blob/main/mx_cast_poc/20250305_mx_dim0_dim1_cast.py
9+
and making it nice.
10+
"""
11+
12+
from typing import Callable, Tuple
13+
14+
import fire
15+
import torch
16+
import triton
17+
import triton.language as tl
18+
from torch._inductor.utils import do_bench_using_profiling
19+
20+
torch.manual_seed(0)
21+
22+
23+
def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
24+
"""Thin wrapper around do_bench_using_profiling"""
25+
no_args = lambda: func(*args, **kwargs)
26+
time = do_bench_using_profiling(no_args)
27+
return time * 1e3
28+
29+
30+
def compute_error(x, y):
31+
Ps = torch.linalg.norm(x)
32+
Pn = torch.linalg.norm(x - y)
33+
return 20 * torch.log10(Ps / Pn)
34+
35+
36+
def scale_dim0_dim1_reference(
37+
x_hp: torch.Tensor, block_size
38+
) -> Tuple[torch.Tensor, torch.Tensor]:
39+
# normalize across dim0
40+
x_hp_d0_block = x_hp.reshape(-1, block_size)
41+
x_hp_d0_block_abs = x_hp_d0_block.abs()
42+
amax_dim0 = torch.amax(x_hp_d0_block_abs, dim=1).unsqueeze(1)
43+
x_hp_d0_block_normalized = x_hp_d0_block / amax_dim0
44+
x_hp_d0_normalized = x_hp_d0_block_normalized.reshape(x_hp.shape)
45+
46+
# normalize across dim1
47+
x_hp_d1 = x_hp.t().contiguous()
48+
x_hp_d1_block = x_hp_d1.reshape(-1, block_size)
49+
x_hp_d1_block_abs = x_hp_d1_block.abs()
50+
amax_dim1 = torch.amax(x_hp_d1_block_abs, dim=1).unsqueeze(1)
51+
x_hp_d1_block_normalized = x_hp_d1_block / amax_dim1
52+
x_hp_d1_normalized = x_hp_d1_block_normalized.reshape(x_hp_d1.shape)
53+
54+
return x_hp_d0_normalized, x_hp_d1_normalized.t(), amax_dim0, amax_dim1
55+
56+
57+
@triton.jit
58+
def normalization_kernel(
59+
x_ptr, # pointer to input tensor
60+
output_row_major_ptr, # pointer to row-major output tensor (row-normalized)
61+
output_col_major_ptr, # pointer to column-major output tensor (column-normalized)
62+
row_max_abs_ptr, # pointer to store row-wise maximum absolute values
63+
col_max_abs_ptr, # pointer to store column-wise maximum absolute values
64+
n_rows, # number of rows in the tensor
65+
n_cols, # number of columns in the tensor
66+
TILE_SIZE: tl.constexpr, # tile size as a compile-time constant
67+
):
68+
"""
69+
credit: mostly Claude, some Vasiliy
70+
"""
71+
72+
# Get program ID
73+
pid_row = tl.program_id(0)
74+
pid_col = tl.program_id(1)
75+
76+
# Calculate starting row and column for this tile
77+
start_row = pid_row * TILE_SIZE
78+
start_col = pid_col * TILE_SIZE
79+
80+
# Create offsets for the block
81+
row_offsets = tl.arange(0, TILE_SIZE)
82+
col_offsets = tl.arange(0, TILE_SIZE)
83+
84+
# Compute global row/col positions
85+
rows = start_row + row_offsets[:, None] # Convert to 2D for proper broadcasting
86+
cols = start_col + col_offsets[None, :]
87+
88+
# Create masks for out-of-bounds accesses
89+
row_mask = rows < n_rows
90+
col_mask = cols < n_cols
91+
mask = row_mask & col_mask
92+
93+
# Compute memory offsets for row-major layout (rows, cols)
94+
row_major_offsets = (rows * n_cols + cols).to(tl.int32)
95+
96+
# Compute memory offsets for column-major layout (cols, rows)
97+
col_major_offsets = (cols * n_rows + rows).to(tl.int32)
98+
99+
# Load the entire block in a single operation
100+
x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
101+
102+
# ----------------------------------------------------
103+
# Row-wise normalization
104+
# ----------------------------------------------------
105+
# Calculate the absolute values of elements in the block
106+
x_block_abs = tl.abs(x_block)
107+
108+
# Find the maximum absolute value for each row
109+
# We use a small epsilon to avoid division by zero
110+
epsilon = 1e-10
111+
row_max_abs = tl.max(x_block_abs, axis=1) + epsilon
112+
113+
# Normalize each row by its maximum absolute value
114+
# Broadcasting row_max_abs to match x_block's shape
115+
row_normalized = x_block / row_max_abs[:, None]
116+
117+
# ----------------------------------------------------
118+
# Column-wise normalization
119+
# ----------------------------------------------------
120+
# Find the maximum absolute value for each column
121+
col_max_abs = tl.max(x_block_abs, axis=0) + epsilon
122+
123+
# Normalize each column by its maximum absolute value
124+
# Broadcasting col_max_abs to match x_block's shape
125+
col_normalized = x_block / col_max_abs[None, :]
126+
127+
# Store the row-normalized result in row-major format
128+
tl.store(output_row_major_ptr + row_major_offsets, row_normalized, mask=mask)
129+
130+
# Store the column-normalized result in column-major format
131+
tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask)
132+
133+
# Create 1D ranges for storing row and column max values
134+
row_indices = start_row + tl.arange(0, TILE_SIZE)
135+
col_indices = start_col + tl.arange(0, TILE_SIZE)
136+
137+
# Create masks for valid rows and columns
138+
row_mask = row_indices < n_rows
139+
col_mask = col_indices < n_cols
140+
141+
# Vasiliy - deviating from Claude here for much simpler code
142+
row_scale_start_ptr = row_max_abs_ptr + (pid_row * n_cols) + pid_col
143+
row_scale_indices = tl.arange(0, TILE_SIZE) * (n_cols // TILE_SIZE)
144+
# TODO(future): mask
145+
tl.store(row_scale_start_ptr + row_scale_indices, row_max_abs)
146+
147+
# Vasiliy - deviating from Claude here for much simpler code
148+
col_scale_start_ptr = col_max_abs_ptr + (pid_col * n_rows) + pid_row
149+
col_scale_indices = tl.arange(0, TILE_SIZE) * (n_rows // TILE_SIZE)
150+
# TODO(future): mask
151+
tl.store(col_scale_start_ptr + col_scale_indices, col_max_abs)
152+
153+
154+
# Function to launch the kernel
155+
def normalize_tiled(x, tile_size=32):
156+
# Get tensor shape
157+
n_rows, n_cols = x.shape
158+
159+
# Create output tensors (both row-major and column-major)
160+
output_row_major = torch.empty_like(x)
161+
output_col_major = torch.empty((n_cols, n_rows), dtype=x.dtype, device=x.device)
162+
163+
# Create tensors for row-wise and column-wise maximum absolute values
164+
row_max_abs = torch.empty(
165+
n_rows, n_cols // tile_size, dtype=x.dtype, device=x.device
166+
)
167+
col_max_abs = torch.empty(
168+
n_cols, n_rows // tile_size, dtype=x.dtype, device=x.device
169+
)
170+
171+
# Calculate grid dimensions based on tile size
172+
grid_rows = triton.cdiv(n_rows, tile_size)
173+
grid_cols = triton.cdiv(n_cols, tile_size)
174+
175+
# Launch the kernel
176+
normalization_kernel[(grid_rows, grid_cols)](
177+
x_ptr=x,
178+
output_row_major_ptr=output_row_major,
179+
output_col_major_ptr=output_col_major,
180+
row_max_abs_ptr=row_max_abs,
181+
col_max_abs_ptr=col_max_abs,
182+
n_rows=n_rows,
183+
n_cols=n_cols,
184+
TILE_SIZE=tile_size,
185+
)
186+
187+
return (
188+
output_row_major,
189+
output_col_major.t(),
190+
row_max_abs.reshape(-1, 1),
191+
col_max_abs.reshape(-1, 1),
192+
)
193+
194+
195+
def run(
196+
M: int = 4096,
197+
K: int = 2048,
198+
BLOCK_SIZE: int = 32,
199+
):
200+
print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
201+
print(f"GPU: {torch.cuda.get_device_name(0)}")
202+
print(f"torch version: {torch.__version__}")
203+
print(f"triton version: {triton.__version__}")
204+
205+
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
206+
207+
scale_dim0_dim1_c = torch.compile(scale_dim0_dim1_reference)
208+
209+
# reference implementation (plain PyTorch + torch.compile)
210+
x_d0, x_d1, amax_d0, amax_d1 = scale_dim0_dim1_c(x, BLOCK_SIZE)
211+
x_d0_and_back = (x_d0.reshape(-1, BLOCK_SIZE) * amax_d0).reshape(x_d0.shape)
212+
x_d1_and_back = (
213+
(x_d1.t().reshape(-1, BLOCK_SIZE) * amax_d1).reshape(x_d1.t().shape).t()
214+
)
215+
216+
sqnr_bf16_vs_dim0_ref = compute_error(x, x_d0_and_back)
217+
sqnr_bf16_vs_dim1_ref = compute_error(x, x_d1_and_back)
218+
print(
219+
f"bf16 vs normalized reference sqnrs: dim0 {sqnr_bf16_vs_dim0_ref}, dim1 {sqnr_bf16_vs_dim1_ref}"
220+
)
221+
assert (
222+
sqnr_bf16_vs_dim0_ref > 50 and sqnr_bf16_vs_dim1_ref > 50
223+
), "reference normlization numerics are incorrect"
224+
225+
# basic triton kernel
226+
x_d0_t, x_d1_t, amax_d0_t, amax_d1_t = normalize_tiled(x, tile_size=BLOCK_SIZE)
227+
228+
# ensure bitwise equivalency of outputs with reference
229+
torch.testing.assert_close(x_d0, x_d0_t, atol=0, rtol=0)
230+
torch.testing.assert_close(x_d1, x_d1_t, atol=0, rtol=0)
231+
torch.testing.assert_close(amax_d0, amax_d0_t, atol=0, rtol=0)
232+
torch.testing.assert_close(amax_d1, amax_d1_t, atol=0, rtol=0)
233+
print("normalized reference vs normalized triton are bitwise equivalent")
234+
235+
if False:
236+
# for debugging
237+
sqnr_x_d0_ref_vs_t = compute_error(x_d0, x_d0_t)
238+
print("sqnr_x_d0_t", sqnr_x_d0_ref_vs_t)
239+
sqnr_amax_d0_vs_t = compute_error(amax_d0, amax_d0_t)
240+
print("sqnr_amax_d0_t", sqnr_amax_d0_vs_t)
241+
sqnr_x_d1_ref_vs_t = compute_error(x_d1, x_d1_t)
242+
print("sqnr_x_d1_t", sqnr_x_d1_ref_vs_t)
243+
sqnr_amax_d1_vs_t = compute_error(amax_d1, amax_d1_t)
244+
print("sqnr_amax_d1_t", sqnr_amax_d1_vs_t)
245+
246+
# now, measure performance
247+
248+
# warm up
249+
for _ in range(2):
250+
__ = scale_dim0_dim1_reference(x, BLOCK_SIZE)
251+
time_reference_compile_us = benchmark_cuda_function_in_microseconds(
252+
lambda x, b: scale_dim0_dim1_c(x, b), x, BLOCK_SIZE
253+
)
254+
255+
# warm up
256+
for _ in range(2):
257+
__ = normalize_tiled(x, tile_size=BLOCK_SIZE)
258+
time_triton_us = benchmark_cuda_function_in_microseconds(
259+
lambda x, b: normalize_tiled(x, tile_size=BLOCK_SIZE), x, BLOCK_SIZE
260+
)
261+
262+
# calculate bytes read/written
263+
bytes_per_el_bf16 = 2
264+
bytes_per_el_fp8 = 1
265+
triton_bytes_read = x.numel() * bytes_per_el_bf16
266+
triton_bytes_written = (
267+
sum(x.numel() for x in (x_d0_t, x_d1_t, amax_d0_t, amax_d1_t))
268+
* bytes_per_el_fp8
269+
)
270+
triton_achieved_mem_bw_gbps = (triton_bytes_read + triton_bytes_written) / (
271+
time_triton_us / 1e6
272+
)
273+
# TODO read 8.0 TB/s number from roofline_utils.py instead of hardcoding
274+
triton_pct_peak_mem_bw = triton_achieved_mem_bw_gbps / 8.0e12
275+
276+
print("time_reference_compile_us", time_reference_compile_us)
277+
print("time_triton_us", time_triton_us)
278+
print("triton_achieved_mem_bw_gbps", triton_achieved_mem_bw_gbps)
279+
# Note: as of 2025-03-11, inductor code for adding 1.0 to a large bf16 tensor
280+
# can achieve around 50-70% of B200 peak mem bw
281+
print("triton_pct_peak_mem_bw", triton_pct_peak_mem_bw)
282+
print("speedup", time_reference_compile_us / time_triton_us)
283+
284+
285+
if __name__ == "__main__":
286+
fire.Fire(run)

0 commit comments

Comments
 (0)