# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Starting with https://github.com/vkuzo/pytorch_scripts/blob/main/mx_cast_poc/20250305_mx_dim0_dim1_cast.py
and making it nice.
"""

from typing import Callable, Tuple

import fire
import torch
import triton
import triton.language as tl
from torch._inductor.utils import do_bench_using_profiling

from torchao.prototype.mx_formats.constants import (
    E8M0_EXPONENT_BIAS,
    F8E4M3_MAX_POW2,
    F32_MIN_NORMAL,
)

torch.manual_seed(0)


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    # do_bench_using_profiling reports milliseconds; convert to microseconds
    time = do_bench_using_profiling(no_args)
    return time * 1e3


def compute_error(x, y):
    """Return the SQNR (signal to quantization noise ratio) of y vs x, in dB."""
    Ps = torch.linalg.norm(x)
    Pn = torch.linalg.norm(x - y)
    return 20 * torch.log10(Ps / Pn)
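
# Worked example for compute_error (numbers illustrative, not from this
# script): if the error norm ||x - y|| is 1/1000 of the signal norm ||x||,
# then SQNR = 20 * log10(1000) = 60 dB. The checks in run() below treat
# anything above 50 dB as numerically correct.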


def get_scale_reference(x_hp):
    # TODO(before land): reuse code with mx_tensor.py::to_mx
    # TODO(future PR): test block of all-zeros
    # TODO(future PR): support other rounding modes (currently only supports floor)

    # TODO(future PR): support other dtypes
    target_max_pow2 = F8E4M3_MAX_POW2

    # Note: the caller is expected to pass in absolute values, so the amax
    # over dim1 below is the max absolute value of each block
    epsilon = 1e-10
    max_abs = torch.amax(x_hp, dim=1).unsqueeze(1)

    scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + epsilon)) - target_max_pow2

    # Clamp to exponents that can be represented in e8m0
    scale_e8m0_unbiased = torch.clamp(
        scale_e8m0_unbiased, min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS
    )

    # Create the biased e8m0 representation and cast it to 8 bits
    scale_e8m0_biased = scale_e8m0_unbiased + E8M0_EXPONENT_BIAS
    scale_e8m0_biased = scale_e8m0_biased.to(torch.uint8)

    # TODO(future PR): add NaN handling here

    # For now, calculate the scale in floating point.
    # TODO(future) audit if there is a need to bit shift exponents instead.
    scale_fp = torch.pow(
        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
        scale_e8m0_unbiased,
    )

    # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
    # float32 denormal range. For now, manually adjust the fp scale. This is
    # relevant if all of the incoming block values are zeroes.
    # See https://github.com/pytorch/pytorch/issues/125557 for details.
    # Note: it would be more correct to set the minimum to 2**-127, but this
    # does not work in triton either as it looks like subnormal value handling
    # has some gaps. So, for now just set to the minimum normal value.
    scale_fp = torch.clamp(scale_fp, min=F32_MIN_NORMAL)

    return scale_fp
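
# Worked example of the scale math above, assuming torchao's constants
# F8E4M3_MAX_POW2 == 8 and E8M0_EXPONENT_BIAS == 127 (the same values the
# triton kernel below hardcodes): for a block whose max absolute value is 1.0,
#   scale_e8m0_unbiased = floor(log2(1.0 + 1e-10)) - 8 = -8
#   scale_e8m0_biased   = -8 + 127 = 119
#   scale_fp            = 2.0 ** -8 = 0.00390625
# so the normalized block max is 1.0 / 2**-8 = 256, within the e4m3
# representable range (max 448).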


def scale_dim0_reference(x_hp, block_size) -> Tuple[torch.Tensor, torch.Tensor]:
    # reshape so that each row is a single scaling block
    x_hp_block = x_hp.reshape(-1, block_size)
    x_hp_block_abs = x_hp_block.abs()
    scale = get_scale_reference(x_hp_block_abs)
    x_hp_block_normalized = x_hp_block / scale
    x_hp_normalized = x_hp_block_normalized.reshape(x_hp.shape)
    return x_hp_normalized.to(x_hp.dtype), scale
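
# Shape example (illustrative): for x_hp of shape (4, 64) and block_size 32,
# x_hp_block has shape (8, 32), scale has shape (8, 1), and the returned
# normalized tensor is reshaped back to (4, 64).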


def scale_dim0_dim1_reference(
    x_hp: torch.Tensor, block_size
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # normalize across dim0
    x_hp_d0_normalized, amax_dim0 = scale_dim0_reference(x_hp, block_size)
    # normalize across dim1: transpose so that dim1 blocks become contiguous
    # rows, then reuse the dim0 code path
    x_hp_d1 = x_hp.t().contiguous()
    x_hp_d1_normalized, amax_dim1 = scale_dim0_reference(x_hp_d1, block_size)
    return x_hp_d0_normalized, x_hp_d1_normalized.t(), amax_dim0, amax_dim1
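
# Shape example (illustrative): for x_hp of shape (M, K) = (4096, 2048) and
# block_size 32, both amax_dim0 and amax_dim1 have shape (M * K // 32, 1) =
# (262144, 1); the dim1 outputs are transposed back to the original
# orientation before being returned.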


@triton.jit
def _triton_calculate_scale(x, axis):
    # We use a small epsilon to avoid division by zero
    epsilon = 1e-10

    # TODO(before land): reuse the constants below instead of hardcoding
    target_max_pow2 = 8
    e8m0_exponent_bias = 127

    # Find the maximum absolute value along the given axis
    max_abs = tl.max(x, axis=axis)

    scale_e8m0_unbiased = tl.floor(tl.log2(max_abs + epsilon)) - target_max_pow2

    # Clamp to exponents that can be represented in e8m0
    scale_e8m0_unbiased = tl.clamp(
        scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias
    )

    # Create the biased e8m0 representation and cast it to 8 bits
    scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
    scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)

    # TODO(future PR): add NaN handling here

    # For now, calculate the scale in floating point.
    # TODO(future) audit if there is a need to bit shift exponents instead.
    scale_fp = tl.exp2(scale_e8m0_unbiased.to(tl.float32))

    return scale_fp
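
# Note: _triton_calculate_scale mirrors get_scale_reference step for step
# (floor(log2(max_abs + eps)) - target_max_pow2, clamp, 2**x), which is what
# makes the bitwise-equivalence checks against the reference in run() below
# possible.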


@triton.jit
def normalization_kernel(
    x_ptr,  # pointer to input tensor
    output_row_major_ptr,  # pointer to row-major output tensor (row-normalized)
    output_col_major_ptr,  # pointer to column-major output tensor (column-normalized)
    row_scale_ptr,  # pointer to store row-wise scales
    col_scale_ptr,  # pointer to store column-wise scales
    n_rows,  # number of rows in the tensor
    n_cols,  # number of columns in the tensor
    TILE_SIZE: tl.constexpr,  # tile size as a compile-time constant
):
    """
    credit: mostly Claude, some Vasiliy
    """

    # Get program ID
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)

    # Calculate starting row and column for this tile
    start_row = pid_row * TILE_SIZE
    start_col = pid_col * TILE_SIZE

    # Create offsets for the block
    row_offsets = tl.arange(0, TILE_SIZE)
    col_offsets = tl.arange(0, TILE_SIZE)

    # Compute global row/col positions
    rows = start_row + row_offsets[:, None]  # Convert to 2D for proper broadcasting
    cols = start_col + col_offsets[None, :]

    # Create masks for out-of-bounds accesses
    row_mask = rows < n_rows
    col_mask = cols < n_cols
    mask = row_mask & col_mask

    # Compute memory offsets for row-major layout (rows, cols)
    row_major_offsets = (rows * n_cols + cols).to(tl.int32)

    # Compute memory offsets for column-major layout (cols, rows)
    col_major_offsets = (cols * n_rows + rows).to(tl.int32)
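
    # Offset example (illustrative): in a 2x3 tensor, the element at
    # (row=0, col=1) lives at row-major offset 0 * 3 + 1 = 1 and at
    # column-major offset 1 * 2 + 0 = 2.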

    # Load the entire block in a single operation
    x_block = tl.load(x_ptr + row_major_offsets, mask=mask)

    # ----------------------------------------------------
    # Row-wise normalization
    # ----------------------------------------------------
    # Calculate the absolute values of elements in the block
    x_block_abs = tl.abs(x_block)

    # Calculate the scale for each row
    row_scale = _triton_calculate_scale(x_block_abs, axis=1)

    # Normalize each row by its scale
    # Broadcasting row_scale to match x_block's shape
    row_normalized = x_block / row_scale[:, None]

    # ----------------------------------------------------
    # Column-wise normalization
    # ----------------------------------------------------
    # Calculate the scale for each column
    col_scale = _triton_calculate_scale(x_block_abs, axis=0)

    # Normalize each column by its scale
    # Broadcasting col_scale to match x_block's shape
    col_normalized = x_block / col_scale[None, :]

    # Store the row-normalized result in row-major format
    tl.store(output_row_major_ptr + row_major_offsets, row_normalized, mask=mask)

    # Store the column-normalized result in column-major format
    tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask)

    # Create 1D ranges for storing row and column scale values
    row_indices = start_row + tl.arange(0, TILE_SIZE)
    col_indices = start_col + tl.arange(0, TILE_SIZE)

    # Create masks for valid rows and columns
    # (currently unused; see the masking TODOs below)
    row_mask = row_indices < n_rows
    col_mask = col_indices < n_cols

    # Vasiliy - deviating from Claude here for much simpler code
    row_scale_start_ptr = row_scale_ptr + (pid_row * n_cols) + pid_col
    row_scale_indices = tl.arange(0, TILE_SIZE) * (n_cols // TILE_SIZE)
    # TODO(future): mask
    tl.store(row_scale_start_ptr + row_scale_indices, row_scale)

    # Vasiliy - deviating from Claude here for much simpler code
    col_scale_start_ptr = col_scale_ptr + (pid_col * n_rows) + pid_row
    col_scale_indices = tl.arange(0, TILE_SIZE) * (n_rows // TILE_SIZE)
    # TODO(future): mask
    tl.store(col_scale_start_ptr + col_scale_indices, col_scale)
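
# Scale-layout example for the stores above (numbers illustrative): with
# n_cols = 64 and TILE_SIZE = 32, row_scale is a flat buffer with one entry
# per (row, column-block) pair. The program at (pid_row=1, pid_col=0) starts
# at flat offset 1 * 64 = 64 and writes its TILE_SIZE row scales with stride
# n_cols // TILE_SIZE = 2, i.e. at offsets 64, 66, ..., 126.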


# Function to launch the kernel
def normalize_tiled(x, tile_size=32):
    # Get tensor shape
    n_rows, n_cols = x.shape
    # Note: the unmasked scale stores in the kernel assume that n_rows and
    # n_cols are both multiples of tile_size

    # Create output tensors (both row-major and column-major)
    output_row_major = torch.empty_like(x)
    output_col_major = torch.empty((n_cols, n_rows), dtype=x.dtype, device=x.device)

    # Create tensors for row-wise and column-wise scales
    row_scale = torch.empty(
        n_rows, n_cols // tile_size, dtype=torch.float, device=x.device
    )
    col_scale = torch.empty(
        n_cols, n_rows // tile_size, dtype=torch.float, device=x.device
    )

    # Calculate grid dimensions based on tile size
    grid_rows = triton.cdiv(n_rows, tile_size)
    grid_cols = triton.cdiv(n_cols, tile_size)

    # Launch the kernel
    normalization_kernel[(grid_rows, grid_cols)](
        x_ptr=x,
        output_row_major_ptr=output_row_major,
        output_col_major_ptr=output_col_major,
        row_scale_ptr=row_scale,
        col_scale_ptr=col_scale,
        n_rows=n_rows,
        n_cols=n_cols,
        TILE_SIZE=tile_size,
    )

    return (
        output_row_major,
        output_col_major.t(),
        row_scale.reshape(-1, 1),
        col_scale.reshape(-1, 1),
    )
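
# Note on the return values above: row_scale of shape
# (n_rows, n_cols // tile_size) is reshaped to (-1, 1) so that it lines up
# with the per-block scale layout produced by scale_dim0_reference, and the
# column-major output is transposed back to the original orientation.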


def run(
    M: int = 4096,
    K: int = 2048,
    BLOCK_SIZE: int = 32,
):
    print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"torch version: {torch.__version__}")
    print(f"triton version: {triton.__version__}")

    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

    scale_dim0_dim1_c = torch.compile(scale_dim0_dim1_reference)

    # reference implementation (plain PyTorch + torch.compile)
    x_d0, x_d1, amax_d0, amax_d1 = scale_dim0_dim1_c(x, BLOCK_SIZE)
    # undo the normalization by multiplying by the scales, then compare to x
    x_d0_and_back = (x_d0.reshape(-1, BLOCK_SIZE) * amax_d0).reshape(x_d0.shape)
    x_d1_and_back = (
        (x_d1.t().reshape(-1, BLOCK_SIZE) * amax_d1).reshape(x_d1.t().shape).t()
    )

    sqnr_bf16_vs_dim0_ref = compute_error(x, x_d0_and_back)
    sqnr_bf16_vs_dim1_ref = compute_error(x, x_d1_and_back)
    print(
        f"bf16 vs normalized reference sqnrs: dim0 {sqnr_bf16_vs_dim0_ref}, dim1 {sqnr_bf16_vs_dim1_ref}"
    )
    assert (
        sqnr_bf16_vs_dim0_ref > 50 and sqnr_bf16_vs_dim1_ref > 50
    ), "reference normalization numerics are incorrect"

    # basic triton kernel
    x_d0_t, x_d1_t, amax_d0_t, amax_d1_t = normalize_tiled(x, tile_size=BLOCK_SIZE)

    # ensure bitwise equivalency of outputs with reference
    torch.testing.assert_close(x_d0, x_d0_t, atol=0, rtol=0)
    torch.testing.assert_close(x_d1, x_d1_t, atol=0, rtol=0)
    torch.testing.assert_close(amax_d0, amax_d0_t, atol=0, rtol=0)
    torch.testing.assert_close(amax_d1, amax_d1_t, atol=0, rtol=0)
    print("normalized reference vs normalized triton are bitwise equivalent")

    if False:
        # for debugging
        sqnr_x_d0_ref_vs_t = compute_error(x_d0, x_d0_t)
        print("sqnr_x_d0_t", sqnr_x_d0_ref_vs_t)
        sqnr_amax_d0_vs_t = compute_error(amax_d0, amax_d0_t)
        print("sqnr_amax_d0_t", sqnr_amax_d0_vs_t)
        sqnr_x_d1_ref_vs_t = compute_error(x_d1, x_d1_t)
        print("sqnr_x_d1_t", sqnr_x_d1_ref_vs_t)
        sqnr_amax_d1_vs_t = compute_error(amax_d1, amax_d1_t)
        print("sqnr_amax_d1_t", sqnr_amax_d1_vs_t)

    # now, measure performance

    # warm up the compiled reference so compilation is not included in timing
    for _ in range(2):
        __ = scale_dim0_dim1_c(x, BLOCK_SIZE)
    time_reference_compile_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: scale_dim0_dim1_c(x, b), x, BLOCK_SIZE
    )

    # warm up
    for _ in range(2):
        __ = normalize_tiled(x, tile_size=BLOCK_SIZE)
    time_triton_us = benchmark_cuda_function_in_microseconds(
        lambda x, b: normalize_tiled(x, tile_size=b), x, BLOCK_SIZE
    )

    # calculate bytes read/written; outputs are counted at fp8 width to model
    # the eventual MX cast, even though this kernel currently writes bf16/fp32
    # outputs
    bytes_per_el_bf16 = 2
    bytes_per_el_fp8 = 1
    triton_bytes_read = x.numel() * bytes_per_el_bf16
    triton_bytes_written = (
        sum(x.numel() for x in (x_d0_t, x_d1_t, amax_d0_t, amax_d1_t))
        * bytes_per_el_fp8
    )
    # convert from bytes/sec to GB/s to match the variable name
    triton_achieved_mem_bw_gbps = (
        (triton_bytes_read + triton_bytes_written) / (time_triton_us / 1e6) / 1e9
    )
    # TODO read 8.0 TB/s number from roofline_utils.py instead of hardcoding
    triton_pct_peak_mem_bw = triton_achieved_mem_bw_gbps / 8.0e3
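
    # Worked example of the byte accounting above (illustrative): with the
    # default M = 4096, K = 2048, bytes read = 4096 * 2048 * 2 = 16,777,216.
    # The four outputs have 2 * (4096 * 2048) + 2 * (4096 * 2048 // 32)
    # elements, so bytes written (at the modeled fp8 width) = 17,301,504.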

    print("time_reference_compile_us", time_reference_compile_us)
    print("time_triton_us", time_triton_us)
    print("triton_achieved_mem_bw_gbps", triton_achieved_mem_bw_gbps)
    # Note: as of 2025-03-11, inductor code for adding 1.0 to a large bf16 tensor
    # can achieve around 50-70% of B200 peak mem bw
    print("triton_pct_peak_mem_bw", triton_pct_peak_mem_bw)
    print("speedup", time_reference_compile_us / time_triton_us)


if __name__ == "__main__":
    fire.Fire(run)