@@ -25,6 +25,9 @@
     is_sm_at_least_89,
     is_sm_at_least_100,
 )
+from transformer_nuggets.mx.to_blocked import (
+    to_blocked,
+)
 
 torch.manual_seed(2)
 
@@ -265,6 +268,16 @@ def test_to_blocked():
     print(_to_blocked_single(scales))
     # looks right!
 
+def test_to_blocked_manual_v2():
+    scales = torch.arange(128 * 4 * 2).reshape(128 * 2, 4) / 4
+    torch.set_printoptions(profile="full", linewidth=280)
+    print('orig')
+    print(scales)
+    print('blocked')
+    print(to_blocked(scales))
+    # looks right!
+
+
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
@@ -324,49 +337,58 @@ def test_scaled_mm_mxfp8_mxtensor():
     # * baseline SQNR vs both experiments is ~27
     # * SQNR between experiment 1 and 2 is ~155 (near perfect match)
 
-    # M, K, N = 8192, 4096, 8192
-    M, K, N = 128, 128, 128
-    BLOCK_SIZE = 32
-    a_fp32 = torch.randn(M, K, device="cuda", dtype=torch.float32)
-    b_fp32 = torch.randn(N, K, device="cuda", dtype=torch.float32).t().contiguous()
-
-    a_mx = MXTensor.to_mx(a_fp32, torch.float8_e4m3fn, BLOCK_SIZE)
-    b_mx = MXTensor.to_mx(b_fp32, torch.float8_e4m3fn, BLOCK_SIZE).t()
-    a_s0 = a_mx._scale_e8m0.reshape(M, -1)
-    a_s1 = _to_blocked_single(a_s0)
-    b_s0 = b_mx._scale_e8m0.reshape(N, -1)
-    b_s1 = _to_blocked_single(b_s0)
-
-    # ones_scale = torch.full((M, K // BLOCK_SIZE), 127, dtype=torch.uint8, device="cuda")
-
-    out_ref = a_fp32 @ b_fp32.t()
-    print('baseline', out_ref)
-
-    out_mx_emulated = a_mx @ b_mx
-    print('mx_emulated', out_mx_emulated)
-
-    out_mx_real = torch._scaled_mm(
-        a_mx._data,
-        b_mx._data,
-        # a_scales is really b_scales, and vice versa. Probably switched in cuBLAS kernel?
-        _to_blocked_single(b_mx._scale_e8m0.reshape(N, -1)),
-        _to_blocked_single(a_mx._scale_e8m0.reshape(M, -1)),
-        None,
-        None,
-        torch.float32,
-        False,
-        None,
-        None,
-        DataType.E8M0,
+    print()
+    shapes_to_try = (
+        (128, 128, 128),
+        (128, 256, 512),
+        (256, 512, 128),
+        (512, 128, 256),
+        (4096, 4096, 4096),
+        (4096, 8192, 16384),
+        (8192, 16384, 4096),
+        (16384, 4096, 8192),
     )
-    print('mx_real', out_mx_real)
-
-    sqnr_baseline_to_emulated_mx = compute_error(out_ref, out_mx_emulated)
-    sqnr_baseline_to_real_mx = compute_error(out_ref, out_mx_real)
-    sqnr_emulated_mx_to_real_mx = compute_error(out_mx_emulated, out_mx_real)
-    print('sqnr baseline -> emulated_mx', sqnr_baseline_to_emulated_mx)
-    print('sqnr baseline -> real_mx', sqnr_baseline_to_real_mx)
-    print('sqnr emulated_mx -> real_mx', sqnr_emulated_mx_to_real_mx)
+    for M, K, N in shapes_to_try:
+        print('MKN', M, K, N)
+        BLOCK_SIZE = 32
+        a_fp32 = torch.randn(M, K, device="cuda", dtype=torch.float32)
+        b_fp32 = torch.randn(N, K, device="cuda", dtype=torch.float32)
+
+        a_mx = MXTensor.to_mx(a_fp32, torch.float8_e4m3fn, BLOCK_SIZE)
+        b_mx = MXTensor.to_mx(b_fp32, torch.float8_e4m3fn, BLOCK_SIZE).t()
+        a_s0 = a_mx._scale_e8m0.reshape(M, -1)
+        a_s1 = to_blocked(a_s0)
+        b_s0 = b_mx._scale_e8m0.reshape(N, -1)
+        b_s1 = to_blocked(b_s0)
+
+        out_ref = a_fp32 @ b_fp32.t()
+        # print('baseline', out_ref)
+
+        out_mx_emulated = a_mx @ b_mx
+        # print('mx_emulated', out_mx_emulated)
+
+        out_mx_real = torch._scaled_mm(
+            a_mx._data,
+            b_mx._data,
+            # a_scales is really b_scales, and vice versa. Probably switched in cuBLAS kernel?
+            b_s1,
+            a_s1,
+            None,
+            None,
+            torch.float32,
+            False,
+            None,
+            None,
+            DataType.E8M0,
+        )
+        # print('mx_real', out_mx_real)
+
+        sqnr_baseline_to_emulated_mx = compute_error(out_ref, out_mx_emulated)
+        sqnr_baseline_to_real_mx = compute_error(out_ref, out_mx_real)
+        sqnr_emulated_mx_to_real_mx = compute_error(out_mx_emulated, out_mx_real)
+        print('sqnr baseline -> emulated_mx', sqnr_baseline_to_emulated_mx)
+        print('sqnr baseline -> real_mx', sqnr_baseline_to_real_mx)
+        print('sqnr emulated_mx -> real_mx', sqnr_emulated_mx_to_real_mx)
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
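
For context, here is a minimal sketch of the rearrangement the newly imported to_blocked is expected to perform, assuming it implements the cuBLAS 128x4 blocked scale-factor layout (each 128-row by 4-column tile of scales stored as 32 contiguous rows of 16 values). The helper names, the zero-padding behavior, and the exact swizzle below are illustrative assumptions, not the transformer_nuggets source:

import torch

def _ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def to_blocked_sketch(scales: torch.Tensor) -> torch.Tensor:
    # Pad rows to a multiple of 128 and cols to a multiple of 4.
    rows, cols = scales.shape
    n_row_blocks = _ceil_div(rows, 128)
    n_col_blocks = _ceil_div(cols, 4)
    padded = torch.zeros(
        n_row_blocks * 128,
        n_col_blocks * 4,
        dtype=scales.dtype,
        device=scales.device,
    )
    padded[:rows, :cols] = scales
    # Carve the matrix into (128, 4) tiles.
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    # Within each tile, interleave 4 row-groups of 32 rows into (32, 16).
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
    return rearranged.flatten()

The flattened result, tile by tile, is the layout the test passes to torch._scaled_mm as the scale arguments (a_s1 and b_s1 above). This also explains why the per-row helper _to_blocked_single stops being sufficient once M, K, N grow beyond a single 128x4 tile, which is what the shape sweep in the new loop exercises.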
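The SQNR figures referenced in the comments (~27 dB baseline vs. mx, ~155 dB emulated vs. real) come from torchao's compute_error. Assuming it follows the usual SQNR-in-dB definition (worth checking against the torchao source; the function name below is hypothetical), it is equivalent to:

import torch

def compute_error_sketch(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: 20 * log10(||x|| / ||x - y||). Higher is better;
    # ~27 dB means the mx results track the fp32 baseline closely,
    # while ~155 dB means two results are a near perfect match.
    Ps = torch.linalg.vector_norm(x)
    Pn = torch.linalg.vector_norm(x - y)
    return 20 * torch.log10(Ps / Pn)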