
Commit be7806c

Update

[ghstack-poisoned]

1 parent 8daa0d0


test/prototype/mx_formats/test_mx_linear.py (64 additions & 42 deletions)

```diff
@@ -25,6 +25,9 @@
     is_sm_at_least_89,
     is_sm_at_least_100,
 )
+from transformer_nuggets.mx.to_blocked import (
+    to_blocked,
+)
 
 torch.manual_seed(2)
 
```

```diff
@@ -265,6 +268,16 @@ def test_to_blocked():
     print(_to_blocked_single(scales))
     # looks right!
 
+def test_to_blocked_manual_v2():
+    scales = torch.arange(128 * 4 * 2).reshape(128 * 2, 4) / 4
+    torch.set_printoptions(profile="full", linewidth=280)
+    print('orig')
+    print(scales)
+    print('blocked')
+    print(to_blocked(scales))
+    # looks right!
+
+
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
```
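
`test_to_blocked_manual_v2` verifies the output by eye ("looks right!"), which is fine for a scratch test. Worth noting: a (256, 4) scale matrix is exactly two full 128x4 tiles, so no padding occurs and the blocked form is a pure permutation of the input values. That gives a machine-checkable invariant to go with the printout; a sketch, reusing our hypothetical `to_blocked_sketch` from above:

```python
import torch

scales = torch.arange(128 * 4 * 2).reshape(128 * 2, 4) / 4
blocked = to_blocked_sketch(scales)

# No padding for two full 128x4 tiles: the element count is preserved
# and the blocked tensor is a permutation of the original values.
assert blocked.numel() == scales.numel() == 1024
assert torch.equal(blocked.sort().values, scales.flatten().sort().values)
```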

```diff
@@ -324,49 +337,58 @@ def test_scaled_mm_mxfp8_mxtensor():
     # * baseline SQNR vs both experiments is ~27
     # * SQNR between experiment 1 and 2 is ~155 (near perfect match)
 
-    # M, K, N = 8192, 4096, 8192
-    M, K, N = 128, 128, 128
-    BLOCK_SIZE = 32
-    a_fp32 = torch.randn(M, K, device="cuda", dtype=torch.float32)
-    b_fp32 = torch.randn(N, K, device="cuda", dtype=torch.float32).t().contiguous()
-
-    a_mx = MXTensor.to_mx(a_fp32, torch.float8_e4m3fn, BLOCK_SIZE)
-    b_mx = MXTensor.to_mx(b_fp32, torch.float8_e4m3fn, BLOCK_SIZE).t()
-    a_s0 = a_mx._scale_e8m0.reshape(M, -1)
-    a_s1 = _to_blocked_single(a_s0)
-    b_s0 = b_mx._scale_e8m0.reshape(N, -1)
-    b_s1 = _to_blocked_single(b_s0)
-
-    # ones_scale = torch.full((M, K // BLOCK_SIZE), 127, dtype=torch.uint8, device="cuda")
-
-    out_ref = a_fp32 @ b_fp32.t()
-    print('baseline', out_ref)
-
-    out_mx_emulated = a_mx @ b_mx
-    print('mx_emulated', out_mx_emulated)
-
-    out_mx_real = torch._scaled_mm(
-        a_mx._data,
-        b_mx._data,
-        # a_scales is really b_scales, and vice versa. Probably switched in cuBLAS kernel?
-        _to_blocked_single(b_mx._scale_e8m0.reshape(N, -1)),
-        _to_blocked_single(a_mx._scale_e8m0.reshape(M, -1)),
-        None,
-        None,
-        torch.float32,
-        False,
-        None,
-        None,
-        DataType.E8M0,
+    print()
+    shapes_to_try = (
+        (128, 128, 128),
+        (128, 256, 512),
+        (256, 512, 128),
+        (512, 128, 256),
+        (4096, 4096, 4096),
+        (4096, 8192, 16384),
+        (8192, 16384, 4096),
+        (16384, 4096, 8192),
     )
-    print('mx_real', out_mx_real)
-
-    sqnr_baseline_to_emulated_mx = compute_error(out_ref, out_mx_emulated)
-    sqnr_baseline_to_real_mx = compute_error(out_ref, out_mx_real)
-    sqnr_emulated_mx_to_real_mx = compute_error(out_mx_emulated, out_mx_real)
-    print('sqnr baseline -> emulated_mx', sqnr_baseline_to_emulated_mx)
-    print('sqnr baseline -> real_mx', sqnr_baseline_to_real_mx)
-    print('sqnr emulated_mx -> real_mx', sqnr_emulated_mx_to_real_mx)
+    for M, K, N in shapes_to_try:
+        print('MKN', M, K, N)
+        BLOCK_SIZE = 32
+        a_fp32 = torch.randn(M, K, device="cuda", dtype=torch.float32)
+        b_fp32 = torch.randn(N, K, device="cuda", dtype=torch.float32)
+
+        a_mx = MXTensor.to_mx(a_fp32, torch.float8_e4m3fn, BLOCK_SIZE)
+        b_mx = MXTensor.to_mx(b_fp32, torch.float8_e4m3fn, BLOCK_SIZE).t()
+        a_s0 = a_mx._scale_e8m0.reshape(M, -1)
+        a_s1 = to_blocked(a_s0)
+        b_s0 = b_mx._scale_e8m0.reshape(N, -1)
+        b_s1 = to_blocked(b_s0)
+
+        out_ref = a_fp32 @ b_fp32.t()
+        # print('baseline', out_ref)
+
+        out_mx_emulated = a_mx @ b_mx
+        # print('mx_emulated', out_mx_emulated)
+
+        out_mx_real = torch._scaled_mm(
+            a_mx._data,
+            b_mx._data,
+            # a_scales is really b_scales, and vice versa. Probably switched in cuBLAS kernel?
+            b_s1,
+            a_s1,
+            None,
+            None,
+            torch.float32,
+            False,
+            None,
+            None,
+            DataType.E8M0,
+        )
+        # print('mx_real', out_mx_real)
+
+        sqnr_baseline_to_emulated_mx = compute_error(out_ref, out_mx_emulated)
+        sqnr_baseline_to_real_mx = compute_error(out_ref, out_mx_real)
+        sqnr_emulated_mx_to_real_mx = compute_error(out_mx_emulated, out_mx_real)
+        print('sqnr baseline -> emulated_mx', sqnr_baseline_to_emulated_mx)
+        print('sqnr baseline -> real_mx', sqnr_baseline_to_real_mx)
+        print('sqnr emulated_mx -> real_mx', sqnr_emulated_mx_to_real_mx)
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
```
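
All three comparisons in the loop go through `compute_error`, which reports SQNR (signal-to-quantization-noise ratio) in dB: 20 * log10(||ref|| / ||ref - actual||). A sketch of an equivalent helper, assuming `compute_error` here matches torchao's usual test utility of that name:

```python
import torch

def compute_error(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # SQNR in dB: ratio of signal power to quantization-noise power.
    signal = torch.linalg.norm(ref)
    noise = torch.linalg.norm(ref - actual)
    return 20 * torch.log10(signal / noise)
```

Read this way, the numbers in the test's comment are consistent: ~27 dB for baseline vs. either mx path is the expected fp8 quantization error, while ~155 dB between the emulated and real paths means the two agree to within floating-point rounding.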
