4 | 4 | # This source code is licensed under the license found in the
5 | 5 | # LICENSE file in the root directory of this source tree.
6 | 6 |
| 7 | +from typing import Tuple |
| 8 | + |
7 | 9 | import numpy as np
8 | 10 | import torch
9 | 11 | from torch.utils._triton import has_triton
@@ -1080,3 +1082,240 @@ def _(uint8_data):
1080 | 1082 | def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
1081 | 1083 | # Dummy placeholder op for torch < 2.4
1082 | 1084 | raise AssertionError("fp6 packing unsupported without torch >= 2.4")
| 1085 | + |
| 1086 | + |
| 1087 | +if TORCH_VERSION_AT_LEAST_2_4 and has_triton(): |
| 1088 | + import triton |
| 1089 | + import triton.language as tl |
| 1090 | + |
| 1091 | + @triton.jit |
| 1092 | + def _triton_calculate_scale(x, axis): |
| 1093 | + # We use a small epsilon to avoid taking log2 of zero |
| 1094 | + epsilon = 1e-10 |
| 1095 | + |
| 1096 | + # TODO(before land): reuse the constants below instead of hardcoding |
| 1097 | + target_max_pow2 = 8 |
| 1098 | + e8m0_exponent_bias = 127 |
| 1099 | + |
| 1100 | + # Find the maximum value along the given axis (the caller passes absolute values) |
| 1101 | + max_abs = tl.max(x, axis=axis) |
| 1102 | + |
| 1103 | + scale_e8m0_unbiased = tl.floor(tl.log2(max_abs + epsilon)) - target_max_pow2 |
| 1104 | + |
| 1105 | + # Clamp to exponents that can be represented in e8m0 |
| 1106 | + scale_e8m0_unbiased = tl.clamp( |
| 1107 | + scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias |
| 1108 | + ) |
| 1109 | + |
| 1110 | + # Create the biased e8m0 representation and cast it to 8 bits |
| 1111 | + scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias |
| 1112 | + scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8) |
| 1113 | + |
| 1114 | + # TODO(future PR): add NaN handling here |
| 1115 | + |
| 1116 | + # For now, calculate the scale in floating point. |
| 1117 | + # TODO(future) audit if there is a need to bit shift exponents instead. |
| 1118 | + scale_fp = tl.exp2(scale_e8m0_unbiased.to(tl.float32)) |
| 1119 | + |
| 1120 | + return scale_fp, scale_e8m0_biased |
| 1121 | + |
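
For intuition, here is an eager-mode sketch of the same e8m0 scale math; the standalone function name is illustrative, not part of this diff:

```python
import torch

def calculate_scale_reference(max_abs: torch.Tensor):
    # Illustrative eager-mode mirror of _triton_calculate_scale (assumed helper,
    # not in the diff). max_abs holds per-row or per-column max absolute values.
    epsilon = 1e-10  # avoids log2(0) for all-zero rows/columns
    target_max_pow2 = 8  # scaled max lands near 2^8 = 256, inside e4m3's range (max 448)
    e8m0_exponent_bias = 127

    scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + epsilon)) - target_max_pow2
    scale_e8m0_unbiased = torch.clamp(
        scale_e8m0_unbiased, -e8m0_exponent_bias, e8m0_exponent_bias
    )
    scale_e8m0_biased = (scale_e8m0_unbiased + e8m0_exponent_bias).to(torch.uint8)
    scale_fp = torch.exp2(scale_e8m0_unbiased)
    return scale_fp, scale_e8m0_biased

# Example: max_abs = 300.0 -> floor(log2(300)) - 8 = 0, so scale_fp = 1.0
# and the biased e8m0 byte is 127.
```
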
| 1122 | + @triton.jit |
| 1123 | + def to_mxfp8_across_dim0_and_dim1_kernel( |
| 1124 | + x_ptr, # pointer to input tensor |
| 1125 | + output_row_major_ptr, # pointer to row-major output tensor (row-normalized) |
| 1126 | + output_col_major_ptr, # pointer to column-major output tensor (column-normalized) |
| 1127 | + row_scale_ptr, # pointer to store row-wise e8m0 scales |
| 1128 | + col_scale_ptr, # pointer to store column-wise e8m0 scales |
| 1129 | + n_rows, # number of rows in the tensor |
| 1130 | + n_cols, # number of columns in the tensor |
| 1131 | + TILE_SIZE: tl.constexpr, # tile size as a compile-time constant |
| 1132 | + ): |
| 1133 | + """ |
| 1134 | + credit: mostly Claude, some Vasiliy |
| 1135 | + """ |
| 1136 | + |
| 1137 | + # Get program ID |
| 1138 | + pid_row = tl.program_id(0) |
| 1139 | + pid_col = tl.program_id(1) |
| 1140 | + |
| 1141 | + # Calculate starting row and column for this tile |
| 1142 | + start_row = pid_row * TILE_SIZE |
| 1143 | + start_col = pid_col * TILE_SIZE |
| 1144 | + |
| 1145 | + # Create offsets for the block |
| 1146 | + row_offsets = tl.arange(0, TILE_SIZE) |
| 1147 | + col_offsets = tl.arange(0, TILE_SIZE) |
| 1148 | + |
| 1149 | + # Compute global row/col positions |
| 1150 | + rows = start_row + row_offsets[:, None] # Convert to 2D for proper broadcasting |
| 1151 | + cols = start_col + col_offsets[None, :] |
| 1152 | + |
| 1153 | + # Create masks for out-of-bounds accesses |
| 1154 | + row_mask = rows < n_rows |
| 1155 | + col_mask = cols < n_cols |
| 1156 | + mask = row_mask & col_mask |
| 1157 | + |
| 1158 | + # Compute memory offsets for row-major layout (rows, cols) |
| 1159 | + row_major_offsets = (rows * n_cols + cols).to(tl.int32) |
| 1160 | + |
| 1161 | + # Compute memory offsets for column-major layout (cols, rows) |
| 1162 | + col_major_offsets = (cols * n_rows + rows).to(tl.int32) |
| 1163 | + |
| 1164 | + # Load the entire block in a single operation |
| 1165 | + x_block = tl.load(x_ptr + row_major_offsets, mask=mask) |
| 1166 | + |
| 1167 | + # ---------------------------------------------------- |
| 1168 | + # Row-wise normalization |
| 1169 | + # ---------------------------------------------------- |
| 1170 | + # Calculate the absolute values of elements in the block |
| 1171 | + x_block_abs = tl.abs(x_block) |
| 1172 | + |
| 1173 | + # Compute the scale for each row from its maximum absolute value |
| 1174 | + row_scale, row_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=1) |
| 1175 | + |
| 1176 | + # Normalize each row by its scale |
| 1177 | + # Broadcasting row_scale to match x_block's shape |
| 1178 | + row_normalized = x_block / row_scale[:, None] |
| 1179 | + |
| 1180 | + # quant to float8 |
| 1181 | + # TODO(this PR): clamp? |
| 1182 | + row_normalized = row_normalized.to(tl.float8e4nv) |
| 1183 | + |
| 1184 | + # ---------------------------------------------------- |
| 1185 | + # Column-wise normalization |
| 1186 | + # ---------------------------------------------------- |
| 1187 | + # Compute the scale for each column from its maximum absolute value |
| 1188 | + col_scale, col_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=0) |
| 1189 | + |
| 1190 | + # Normalize each column by its scale |
| 1191 | + # Broadcasting col_scale to match x_block's shape |
| 1192 | + col_normalized = x_block / col_scale[None, :] |
| 1193 | + |
| 1194 | + # quant to float8 |
| 1195 | + # TODO(this PR): clamp? |
| 1196 | + col_normalized = col_normalized.to(tl.float8e4nv) |
| 1197 | + |
| 1198 | + # Store the row-normalized result in row-major format |
| 1199 | + tl.store(output_row_major_ptr + row_major_offsets, row_normalized, mask=mask) |
| 1200 | + |
| 1201 | + # Store the column-normalized result in column-major format |
| 1202 | + tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask) |
| 1203 | + |
| 1204 | + # Create 1D ranges for storing the row and column scales |
| 1205 | + row_indices = start_row + tl.arange(0, TILE_SIZE) |
| 1206 | + col_indices = start_col + tl.arange(0, TILE_SIZE) |
| 1207 | + |
| 1208 | + # Create masks for valid rows and columns (currently unused; see the store TODOs below) |
| 1209 | + row_mask = row_indices < n_rows |
| 1210 | + col_mask = col_indices < n_cols |
| 1211 | + |
| 1212 | + # Vasiliy - deviating from Claude here for much simpler code |
| 1213 | + row_scale_start_ptr = row_scale_ptr + (pid_row * n_cols) + pid_col |
| 1214 | + row_scale_indices = tl.arange(0, TILE_SIZE) * (n_cols // TILE_SIZE) |
| 1215 | + # TODO(future): mask |
| 1216 | + tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0) |
| 1217 | + |
| 1218 | + # Vasiliy - deviating from Claude here for much simpler code |
| 1219 | + col_scale_start_ptr = col_scale_ptr + (pid_col * n_rows) + pid_row |
| 1220 | + col_scale_indices = tl.arange(0, TILE_SIZE) * (n_rows // TILE_SIZE) |
| 1221 | + # TODO(future): mask |
| 1222 | + tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0) |
| 1223 | + |
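
The unmasked scale stores above assume `n_rows` and `n_cols` are multiples of `TILE_SIZE` (hence the `TODO(future): mask` notes). Under that assumption, the strided offsets land exactly on the row-major layout of the `(n_rows, n_cols // TILE_SIZE)` scale tensor allocated by the wrapper below; a small standalone check of the index math (illustrative only, not part of the diff):

```python
# Pure-Python sanity check of the kernel's per-tile scale-store offsets.
TILE_SIZE, n_rows, n_cols = 32, 64, 128  # assumes divisibility by TILE_SIZE
tiles_per_row = n_cols // TILE_SIZE

for pid_row in range(n_rows // TILE_SIZE):
    for pid_col in range(tiles_per_row):
        start = pid_row * n_cols + pid_col  # row_scale_start_ptr offset
        for i in range(TILE_SIZE):
            kernel_offset = start + i * tiles_per_row  # row_scale_indices stride
            # Row-major position of the scale for global row pid_row*TILE_SIZE + i,
            # column-tile pid_col, in a (n_rows, n_cols // TILE_SIZE) tensor:
            expected = (pid_row * TILE_SIZE + i) * tiles_per_row + pid_col
            assert kernel_offset == expected
```
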
| 1224 | + def to_mxfp8_across_dim0_and_dim1(x: torch.Tensor, tile_size: int = 32): |
| 1225 | + """ |
| 1226 | + Cast `x` to mxfp8 across both dim0 and dim1, using a single fused triton kernel. |
| 1227 | + This is useful for MX training with the mxfp8 recipe family. |
| 1228 | + |
| 1229 | + The kernel loads data in 2d tiles, and performs the necessary casting across both |
| 1230 | + dim0 and dim1 for each tile. |
| 1231 | + |
| 1232 | + Note that for now, there is only one level of tiling (32 for MX). In the future, |
| 1233 | + we expect that adding an outer tile (of size up to 128 on B200s) can provide a |
| 1234 | + further speedup. |
| 1235 | + |
| 1236 | + Input: |
| 1237 | + * `x` - input tensor, in row major memory layout |
| 1238 | + * `tile_size` - size of tiles to normalize across, default is 32 for MX recipes |
| 1239 | + |
| 1240 | + Output: |
| 1241 | + * `output_row_major`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 |
| 1242 | + * `output_col_major`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim1 |
| 1243 | + * `row_scale`: the `e8m0` scales used to cast `x` to mxfp8 across dim0 |
| 1244 | + * `col_scale`: the `e8m0` scales used to cast `x` to mxfp8 across dim1 |
| 1245 | + """ |
| 1246 | + assert x.is_contiguous(), "`x` must be contiguous" |
| 1247 | + # Get tensor shape |
| 1248 | + n_rows, n_cols = x.shape |
| 1249 | + |
| 1250 | + # Create output tensors (both row-major and column-major) |
| 1251 | + output_row_major = torch.empty_like(x, dtype=torch.float8_e4m3fn) |
| 1252 | + output_col_major = torch.empty( |
| 1253 | + (n_cols, n_rows), dtype=torch.float8_e4m3fn, device=x.device |
| 1254 | + ) |
| 1255 | + |
| 1256 | + # Create tensors for the row-wise and column-wise e8m0 scales |
| 1257 | + row_scale = torch.empty( |
| 1258 | + n_rows, n_cols // tile_size, dtype=torch.uint8, device=x.device |
| 1259 | + ) |
| 1260 | + col_scale = torch.empty( |
| 1261 | + n_cols, n_rows // tile_size, dtype=torch.uint8, device=x.device |
| 1262 | + ) |
| 1263 | + |
| 1264 | + # Calculate grid dimensions based on tile size |
| 1265 | + grid_rows = triton.cdiv(n_rows, tile_size) |
| 1266 | + grid_cols = triton.cdiv(n_cols, tile_size) |
| 1267 | + |
| 1268 | + # Launch the kernel |
| 1269 | + to_mxfp8_across_dim0_and_dim1_kernel[(grid_rows, grid_cols)]( |
| 1270 | + x_ptr=x, |
| 1271 | + output_row_major_ptr=output_row_major, |
| 1272 | + output_col_major_ptr=output_col_major, |
| 1273 | + row_scale_ptr=row_scale, |
| 1274 | + col_scale_ptr=col_scale, |
| 1275 | + n_rows=n_rows, |
| 1276 | + n_cols=n_cols, |
| 1277 | + TILE_SIZE=tile_size, |
| 1278 | + ) |
| 1279 | + |
| 1280 | + return ( |
| 1281 | + output_row_major, |
| 1282 | + output_col_major.t(), |
| 1283 | + row_scale.reshape(-1, 1).view(torch.float8_e8m0fnu), |
| 1284 | + col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu), |
| 1285 | + ) |
| 1286 | + |
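
A hypothetical usage sketch (requires the torch/triton guard above to pass and a CUDA device; the shapes shown follow from the code):

```python
import torch

x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
out_d0, out_d1, scale_d0, scale_d1 = to_mxfp8_across_dim0_and_dim1(x, tile_size=32)

assert out_d0.shape == (256, 512) and out_d0.dtype == torch.float8_e4m3fn
assert out_d1.shape == (256, 512)  # output_col_major.t(), so same logical shape as x
assert scale_d0.shape == (256 * 512 // 32, 1)  # one e8m0 scale per 32-element block
assert scale_d1.shape == (256 * 512 // 32, 1)
```
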
| 1287 | + def to_mxfp8_across_dim0_and_dim1_reference( |
| 1288 | + x_hp: torch.Tensor, block_size: int |
| 1289 | + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| 1290 | + """ |
| 1291 | + A reference version of `to_mxfp8_across_dim0_and_dim1`. |
| 1292 | + """ |
| 1293 | + from torchao.prototype.mx_formats.mx_tensor import to_mx |
| 1294 | + |
| 1295 | + # cast across dim0 |
| 1296 | + scale_e8m0_dim0, x_hp_d0_normalized = to_mx( |
| 1297 | + x_hp, torch.float8_e4m3fn, block_size |
| 1298 | + ) |
| 1299 | + scale_e8m0_dim0 = scale_e8m0_dim0.unsqueeze(1).view(torch.float8_e8m0fnu) |
| 1300 | + # cast across dim1 |
| 1301 | + x_hp_d1 = x_hp.t().contiguous() |
| 1302 | + scale_e8m0_dim1, x_hp_d1_normalized = to_mx( |
| 1303 | + x_hp_d1, torch.float8_e4m3fn, block_size |
| 1304 | + ) |
| 1305 | + scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu) |
| 1306 | + return ( |
| 1307 | + x_hp_d0_normalized, |
| 1308 | + x_hp_d1_normalized.t(), |
| 1309 | + scale_e8m0_dim0, |
| 1310 | + scale_e8m0_dim1, |
| 1311 | + ) |
| 1312 | + |
| 1313 | +else: |
| 1314 | + |
| 1315 | + def to_mxfp8_across_dim0_and_dim1(x: torch.Tensor, tile_size: int = 32): |
| 1316 | + raise AssertionError("needs torch version 2.4+ and triton") |
| 1317 | + |
| 1318 | + def to_mxfp8_across_dim0_and_dim1_reference( |
| 1319 | + x_hp: torch.Tensor, block_size: int |
| 1320 | + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| 1321 | + raise AssertionError("needs torch version 2.4+ and triton") |
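
An illustrative correctness check wiring the two paths together (hypothetical test code, not part of this diff; whether the match is bitwise-exact depends on `to_mx`'s scale computation, so treat the zero tolerances as an assumption to verify):

```python
import torch

x = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")

y_d0, y_d1, s_d0, s_d1 = to_mxfp8_across_dim0_and_dim1(x, tile_size=32)
y_d0_ref, y_d1_ref, s_d0_ref, s_d1_ref = to_mxfp8_across_dim0_and_dim1_reference(
    x, block_size=32
)

# Compare fp8 payloads after upcasting, and e8m0 scales as raw bytes.
torch.testing.assert_close(y_d0.float(), y_d0_ref.float(), rtol=0, atol=0)
torch.testing.assert_close(y_d1.float(), y_d1_ref.float(), rtol=0, atol=0)
torch.testing.assert_close(
    s_d0.view(torch.uint8).reshape(-1), s_d0_ref.view(torch.uint8).reshape(-1),
    rtol=0, atol=0,
)
torch.testing.assert_close(
    s_d1.view(torch.uint8).reshape(-1), s_d1_ref.view(torch.uint8).reshape(-1),
    rtol=0, atol=0,
)
```
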