4 | 4 | # This source code is licensed under the license found in the
5 | 5 | # LICENSE file in the root directory of this source tree.
6 | 6 |
| 7 | +from typing import Tuple |
| 8 | + |
7 | 9 | import numpy as np
8 | 10 | import torch
9 | 11 | from torch.utils._triton import has_triton
12 | 14 | _f32_to_floatx_unpacked,
13 | 15 | _floatx_unpacked_to_f32,
14 | 16 | )
15 | | -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 |
| 17 | +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_8 |
16 | 18 |
17 | 19 | # TODO(future): if needed, make the below work on previous PyTorch versions,
18 | 20 | # just need to hunt down the previous location of `libdevice`. An assert
@@ -1080,3 +1082,264 @@ def _(uint8_data):
1080 | 1082 | def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
1081 | 1083 | # Dummy placeholder op for torch < 2.4
1082 | 1084 | raise AssertionError("fp6 packing unsupported without torch >= 2.4")
| 1085 | + |
| 1086 | + |
| 1087 | +if TORCH_VERSION_AT_LEAST_2_8 and has_triton(): |
| 1088 | + import triton |
| 1089 | + import triton.language as tl |
| 1090 | + |
| 1091 | + @triton.jit |
| 1092 | + def _triton_calculate_scale(x, axis): |
| 1093 | + # We use a small epsilon to avoid taking log2 of zero |
| 1094 | + epsilon = 1e-10 |
| 1095 | + |
| 1096 | + # TODO(before land): reuse the constants below instead of hardcoding |
| 1097 | + target_max_pow2 = 8 |
| 1098 | + e8m0_exponent_bias = 127 |
| 1099 | + # bf16_mbits = 7 |
| 1100 | + # bf16_exp_bias = 127 |
| 1101 | + |
| 1102 | + # Reduce along `axis` to the per-block maximum (callers pass absolute values) |
| 1103 | + max_abs = tl.max(x, axis=axis) |
| 1104 | + |
| 1105 | + # TODO(future): rewrite as bit shifts, see https://github.com/pytorch/ao/pull/1908/files |
| 1106 | + scale_e8m0_unbiased = tl.floor(tl.log2(max_abs + epsilon)) - target_max_pow2 |
| 1107 | + # max_abs = max_abs + epsilon |
| 1108 | + # max_abs = max_abs.to(tl.bfloat16) |
| 1109 | + # max_abs_int16 = max_abs.to(tl.int16, bitcast=True) |
| 1110 | + # extracted_pow2 = ((max_abs_int16 >> bf16_mbits) & 0b11111111) - bf16_exp_bias |
| 1111 | + # extracted_pow2 = extracted_pow2 - target_max_pow2 |
| 1112 | + # scale_e8m0_unbiased = extracted_pow2.to(tl.bfloat16) |
| 1113 | + |
| 1114 | + # Clamp to exponents that can be represented in e8m0 |
| 1115 | + scale_e8m0_unbiased = tl.clamp( |
| 1116 | + scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias |
| 1117 | + ) |
| 1118 | + |
| 1119 | + # Create the biased e8m0 representation and cast it to 8 bits |
| 1120 | + scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias |
| 1121 | + scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8) |
| 1122 | + |
| 1123 | + # TODO(future PR): add NaN handling here |
| 1124 | + |
| 1125 | + # For now, calculate the scale in floating point. |
| 1126 | + # TODO(future): rewrite as bit shifts, see https://github.com/pytorch/ao/pull/1910/files |
| 1127 | + scale_fp = tl.exp2(scale_e8m0_unbiased.to(tl.float32)) |
| 1128 | + |
| 1129 | + return scale_fp, scale_e8m0_biased |
| 1130 | + |
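For intuition, the scale math in `_triton_calculate_scale` can be mirrored in plain PyTorch. A minimal sketch, not part of the PR; the helper name and the standalone max-abs input are assumptions for illustration:

```python
# Plain-PyTorch mirror of the e8m0 scale math above (illustrative sketch only;
# `e8m0_scale_reference` and its standalone max-abs input are hypothetical).
import torch

def e8m0_scale_reference(max_abs: torch.Tensor):
    epsilon = 1e-10            # avoid log2(0) for all-zero blocks
    target_max_pow2 = 8        # aim the scaled block max near 2**8
    e8m0_exponent_bias = 127

    # unbiased power-of-two exponent per block, clamped to the e8m0 range
    unbiased = torch.floor(torch.log2(max_abs + epsilon)) - target_max_pow2
    unbiased = torch.clamp(unbiased, -e8m0_exponent_bias, e8m0_exponent_bias)

    # biased e8m0 byte, and the float scale the data is actually divided by
    scale_e8m0_biased = (unbiased + e8m0_exponent_bias).to(torch.uint8)
    scale_fp = torch.exp2(unbiased)
    return scale_fp, scale_e8m0_biased

# e.g. a block whose max-abs is 3.0 -> floor(log2(3)) - 8 = -7,
# so scale_fp == 2**-7 and the stored e8m0 byte is 127 - 7 == 120
```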
| 1131 | + def _get_mxfp8_dim1_kernel_autotune_configs(): |
| 1132 | + results = [] |
| 1133 | + for ROW_TILE_SIZE in (64, 128): |
| 1134 | + for COL_TILE_SIZE in (64, 128): |
| 1135 | + for num_warps in (1, 2, 4): |
| 1136 | + config = triton.Config( |
| 1137 | + { |
| 1138 | + "ROW_TILE_SIZE": ROW_TILE_SIZE, |
| 1139 | + "COL_TILE_SIZE": COL_TILE_SIZE, |
| 1140 | + }, |
| 1141 | + num_warps=num_warps, |
| 1142 | + ) |
| 1143 | + results.append(config) |
| 1144 | + return results |
| 1145 | + |
| 1146 | + @triton.autotune( |
| 1147 | + configs=_get_mxfp8_dim1_kernel_autotune_configs(), |
| 1148 | + key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"], |
| 1149 | + ) |
| 1150 | + @triton.jit |
| 1151 | + def to_mxfp8_dim1_kernel( |
| 1152 | + x_ptr, # pointer to input tensor |
| 1153 | + output_col_major_ptr, # pointer to column-major output tensor (column-normalized) |
| 1154 | + col_scale_ptr, # pointer to the output e8m0 scales, one per inner block |
| 1155 | + n_rows, # number of rows in the tensor |
| 1156 | + n_cols, # number of columns in the tensor |
| 1157 | + ROW_TILE_SIZE: tl.constexpr, # can be autotuned |
| 1158 | + COL_TILE_SIZE: tl.constexpr, # can be autotuned |
| 1159 | + INNER_BLOCK_SIZE: tl.constexpr, # should be 32 for MX |
| 1160 | + ): |
| 1161 | + # TODO(future): better name |
| 1162 | + BLOCKS_PER_ROW_TILE: tl.constexpr = ROW_TILE_SIZE // INNER_BLOCK_SIZE |
| 1163 | + |
| 1164 | + # TODO(future): better name |
| 1165 | + RENAME_ME_TILE_SIZE: tl.constexpr = ( |
| 1166 | + ROW_TILE_SIZE * COL_TILE_SIZE // INNER_BLOCK_SIZE |
| 1167 | + ) |
| 1168 | + |
| 1169 | + # Get program ID |
| 1170 | + pid_row = tl.program_id(0) |
| 1171 | + pid_col = tl.program_id(1) |
| 1172 | + |
| 1173 | + # Calculate starting row and column for this tile |
| 1174 | + start_row = pid_row * ROW_TILE_SIZE |
| 1175 | + start_col = pid_col * COL_TILE_SIZE |
| 1176 | + |
| 1177 | + # Create offsets for the block |
| 1178 | + row_offsets = tl.arange(0, ROW_TILE_SIZE) |
| 1179 | + col_offsets = tl.arange(0, COL_TILE_SIZE) |
| 1180 | + |
| 1181 | + # Compute global row/col positions |
| 1182 | + rows = start_row + row_offsets[:, None] # Convert to 2D for proper broadcasting |
| 1183 | + cols = start_col + col_offsets[None, :] |
| 1184 | + |
| 1185 | + # Create masks for out-of-bounds accesses |
| 1186 | + row_mask = rows < n_rows |
| 1187 | + col_mask = cols < n_cols |
| 1188 | + mask = row_mask & col_mask |
| 1189 | + |
| 1190 | + # Compute memory offsets for row-major layout (rows, cols) |
| 1191 | + row_major_offsets = (rows * n_cols + cols).to(tl.int32) |
| 1192 | + |
| 1193 | + # Compute memory offsets for column-major layout (cols, rows) |
| 1194 | + col_major_offsets = (cols * n_rows + rows).to(tl.int32) |
| 1195 | + |
| 1196 | + # Load the entire block in a single operation |
| 1197 | + # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) |
| 1198 | + x_block = tl.load(x_ptr + row_major_offsets, mask=mask) |
| 1199 | + |
| 1200 | + # Transpose dim0 and dim1 |
| 1201 | + # shape: (COL_TILE_SIZE, ROW_TILE_SIZE) |
| 1202 | + x_block_t = tl.trans(x_block) |
| 1203 | + |
| 1204 | + # Reshape to inner tile size |
| 1205 | + # TODO: make this generic and nice |
| 1206 | + # inner_block_size = 32 |
| 1207 | + # shape: (COL_TILE_SIZE, ROW_TILE_SIZE) -> (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, INNER_BLOCK_SIZE) |
| 1208 | + x_block_t_r = x_block_t.reshape(RENAME_ME_TILE_SIZE, INNER_BLOCK_SIZE) |
| 1209 | + |
| 1210 | + # Calculate the absolute values of elements in the block |
| 1211 | + x_block_abs_t_r = tl.abs(x_block_t_r) |
| 1212 | + |
| 1213 | + # Calculate the e8m0 scale for each (1, INNER_BLOCK_SIZE) block |
| 1214 | + # shape: (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) |
| 1215 | + col_scale_r, col_scale_e8m0_r = _triton_calculate_scale(x_block_abs_t_r, axis=1) |
| 1216 | + |
| 1217 | + # Divide each inner block by its scale |
| 1218 | + # Broadcasting col_scale to match x_block's shape |
| 1219 | + # x_block_t shape (COL_TILE_SIZE, ROW_TILE_SIZE) |
| 1220 | + # x_block_t_r shape (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, INNER_BLOCK_SIZE) |
| 1221 | + # col_scale shape (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) -> (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, 1) |
| 1222 | + col_normalized_t_r = x_block_t_r / col_scale_r[:, None] |
| 1223 | + |
| 1224 | + # Reshape back to original tile size |
| 1225 | + col_normalized_t = tl.reshape(col_normalized_t_r, COL_TILE_SIZE, ROW_TILE_SIZE) |
| 1226 | + |
| 1227 | + # Undo the transpose |
| 1228 | + col_normalized = tl.trans(col_normalized_t) |
| 1229 | + |
| 1230 | + # Quantize to float8 |
| 1231 | + col_normalized = col_normalized.to(tl.float8e4nv) |
| 1232 | + |
| 1233 | + # Store the column-normalized result in column-major format |
| 1234 | + tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask) |
| 1235 | + |
| 1236 | + # reshape col_scale_e8m0_r to col_scale_e8m0 |
| 1237 | + # shape: (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) -> (COL_TILE_SIZE, ROW_TILE_SIZE // INNER_BLOCK_SIZE,) |
| 1238 | + # col_scale_e8m0 = col_scale_e8m0_r.reshape(COL_TILE_SIZE, ROW_TILE_SIZE // INNER_BLOCK_SIZE) |
| 1239 | + col_scale_e8m0 = col_scale_e8m0_r.reshape( |
| 1240 | + COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE |
| 1241 | + ) |
| 1242 | + |
| 1243 | + col_scale_start_offsets = ( |
| 1244 | + (pid_col * COL_TILE_SIZE * (n_rows // ROW_TILE_SIZE)) |
| 1245 | + * BLOCKS_PER_ROW_TILE # scales owned by the column tiles to the left |
| 1246 | + + pid_row * BLOCKS_PER_ROW_TILE # scales owned by the row tiles above, within this column |
| 1247 | + ) |
| 1248 | + |
| 1249 | + col_scale_start_ptr = col_scale_ptr + col_scale_start_offsets |
| 1250 | + |
| 1251 | + # calculate col_scale_indices, this is a bit convoluted |
| 1252 | + # start with a sequential index [0, COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE) |
| 1253 | + # from example: [0, 1, 2, 3, 4, 5, 6, 7] |
| 1254 | + col_scale_indices = tl.arange( |
| 1255 | + 0, COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE |
| 1256 | + ) |
| 1257 | + |
| 1258 | + # needs better name |
| 1259 | + jump_vals_per_col = (n_rows - ROW_TILE_SIZE) // INNER_BLOCK_SIZE |
| 1260 | + |
| 1261 | + # example transformation (specifics depend on tile sizes): |
| 1262 | + # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13] |
| 1263 | + col_scale_indices = col_scale_indices + ( |
| 1264 | + tl.floor(col_scale_indices / BLOCKS_PER_ROW_TILE) * jump_vals_per_col |
| 1265 | + ).to(tl.int32) |
| 1266 | + |
| 1267 | + # TODO(future): mask |
| 1268 | + tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0) |
| 1269 | + |
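To make the scale-index arithmetic concrete, the sketch below reproduces the `[0, 1, 2, ..., 7] -> [0, 1, 4, 5, 8, 9, 12, 13]` example from the comments in plain host-side Python, using deliberately small (hypothetical) tile sizes that are not among the autotuned configs:

```python
# Host-side sketch of the col_scale index mapping, with tiny hypothetical sizes
ROW_TILE_SIZE, COL_TILE_SIZE, INNER_BLOCK_SIZE = 64, 4, 32
n_rows = 128  # two row tiles per column

BLOCKS_PER_ROW_TILE = ROW_TILE_SIZE // INNER_BLOCK_SIZE           # 2
jump_vals_per_col = (n_rows - ROW_TILE_SIZE) // INNER_BLOCK_SIZE  # 2

indices = list(range(COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE))  # [0..7]
# every BLOCKS_PER_ROW_TILE consecutive scales belong to one column of this
# tile; skip over the scales owned by the other row tiles of that column
strided = [i + (i // BLOCKS_PER_ROW_TILE) * jump_vals_per_col for i in indices]
print(strided)  # [0, 1, 4, 5, 8, 9, 12, 13]
```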
| 1270 | + def to_mxfp8_dim1(x, inner_block_size=32): |
| 1271 | + """ |
| 1272 | + Input: |
| 1273 | + * `x` - input tensor, in row major memory layout |
| 1274 | + * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes |
| 1275 | +
| 1276 | + Output: |
| 1277 | + * `output_col_major`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim1 |
| 1278 | + * `col_scale`: the `e8m0` scales used to cast `x` to mxfp8 across dim1 |
| 1279 | + """ |
| 1280 | + assert x.is_contiguous(), "`x` must be contiguous" |
| 1281 | + assert x.dtype == torch.bfloat16 |
| 1282 | + assert inner_block_size <= 32 |
| 1283 | + |
| 1284 | + # Get tensor shape |
| 1285 | + n_rows, n_cols = x.shape |
| 1286 | + |
| 1287 | + # Create output tensors |
| 1288 | + output_col_major = torch.empty( |
| 1289 | + (n_cols, n_rows), dtype=torch.float8_e4m3fn, device=x.device |
| 1290 | + ) |
| 1291 | + |
| 1292 | + # Create scale tensors |
| 1293 | + col_scale = torch.empty( |
| 1294 | + n_cols, n_rows // inner_block_size, dtype=torch.uint8, device=x.device |
| 1295 | + ) |
| 1296 | + |
| 1297 | + # Calculate grid dimensions based on tile size |
| 1298 | + grid = lambda META: ( |
| 1299 | + triton.cdiv(n_rows, META["ROW_TILE_SIZE"]), |
| 1300 | + triton.cdiv(n_cols, META["COL_TILE_SIZE"]), |
| 1301 | + ) |
| 1302 | + |
| 1303 | + # Launch the kernel |
| 1304 | + to_mxfp8_dim1_kernel[grid]( |
| 1305 | + x_ptr=x, |
| 1306 | + output_col_major_ptr=output_col_major, |
| 1307 | + col_scale_ptr=col_scale, |
| 1308 | + n_rows=n_rows, |
| 1309 | + n_cols=n_cols, |
| 1310 | + INNER_BLOCK_SIZE=inner_block_size, |
| 1311 | + ) |
| 1312 | + |
| 1313 | + return ( |
| 1314 | + output_col_major.t(), |
| 1315 | + col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu), |
| 1316 | + ) |
| 1317 | + |
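A minimal usage sketch, not part of the PR: it assumes a CUDA device and the torch >= 2.8 + triton branch above, and picks shapes divisible by the autotuned tile sizes since the scale store is not masked yet.

```python
# Hedged usage sketch for to_mxfp8_dim1 (assumes CUDA, torch >= 2.8, triton)
import torch

x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
data_mx, scale_e8m0 = to_mxfp8_dim1(x, inner_block_size=32)

print(data_mx.shape, data_mx.dtype)        # torch.Size([256, 512]) torch.float8_e4m3fn
print(scale_e8m0.shape, scale_e8m0.dtype)  # torch.Size([4096, 1]) torch.float8_e8m0fnu
```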
| 1318 | + def to_mxfp8_dim1_reference( |
| 1319 | + x_hp: torch.Tensor, block_size |
| 1320 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 1321 | + """ |
| 1322 | + A reference version of `to_mxfp8_dim1`. |
| 1323 | + """ |
| 1324 | + from torchao.prototype.mx_formats.mx_tensor import to_mx |
| 1325 | + |
| 1326 | + # cast across dim1 |
| 1327 | + x_hp_d1 = x_hp.t().contiguous() |
| 1328 | + scale_e8m0_dim1, x_hp_d1_normalized = to_mx( |
| 1329 | + x_hp_d1, torch.float8_e4m3fn, block_size |
| 1330 | + ) |
| 1331 | + scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu) |
| 1332 | + return ( |
| 1333 | + x_hp_d1_normalized.t(), |
| 1334 | + scale_e8m0_dim1, |
| 1335 | + ) |
| 1336 | + |
| 1337 | +else: |
| 1338 | + |
| 1339 | + def to_mxfp8_dim1(x, inner_block_size=32): |
| 1340 | + raise AssertionError("needs torch version 2.8+ and triton") |
| 1341 | + |
| 1342 | + def to_mxfp8_dim1_reference( |
| 1343 | + x_hp: torch.Tensor, block_size |
| 1344 | + ) -> Tuple[torch.Tensor, torch.Tensor]: |
| 1345 | + raise AssertionError("needs torch version 2.8+ and triton") |
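Finally, a hedged sanity-check sketch (not part of the PR) comparing the Triton path against the reference. It assumes a CUDA device, the torch >= 2.8 + triton branch, and that both paths round the e8m0 exponent identically; it compares only the fp8 data payloads, since the two functions are not guaranteed to lay out their scales identically.

```python
# Sanity-check sketch: Triton kernel vs. `to_mx`-based reference
# (hypothetical helper; assumes CUDA + the torch >= 2.8 / triton branch above).
import torch

def _check_dim1_kernel_matches_reference(rows: int = 256, cols: int = 512) -> None:
    x = torch.randn(rows, cols, dtype=torch.bfloat16, device="cuda")
    data_triton, _scale_triton = to_mxfp8_dim1(x, inner_block_size=32)
    data_ref, _scale_ref = to_mxfp8_dim1_reference(x, block_size=32)
    # upcast before comparing, since float8 tensors do not support subtraction
    torch.testing.assert_close(data_triton.float(), data_ref.float(), atol=0, rtol=0)
```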