2
2
import torch
3
3
4
4
from torchao .float8 .float8_utils import compute_error
5
- from torchao .ops import mx_fp8_bf16
6
- from torchao .prototype .mx_formats .mx_tensor import MXTensor
5
+ from torchao .ops import mx_fp4_bf16 , mx_fp8_bf16
6
+ from torchao .prototype .mx_formats .mx_tensor import DTYPE_FP4 , MXTensor
7
7
from torchao .prototype .mx_formats .utils import to_blocked
8
- from torchao .utils import (
9
- TORCH_VERSION_AT_LEAST_2_4 ,
10
- is_sm_at_least_100 ,
11
- )
8
+ from torchao .utils import TORCH_VERSION_AT_LEAST_2_4 , is_sm_at_least_100
12
9
13
10
# Skip the entire module when running on PyTorch older than 2.4.
if not TORCH_VERSION_AT_LEAST_2_4:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
17
def run_matrix_test(M: int, K: int, N: int, format: str = "fp8") -> float:
    """
    Run a scaled matrix multiplication in the given MX format and compare
    it against a high-precision reference.

    Args:
        M, K, N: Matrix dimensions. K must be a multiple of the MX block
            size (32) so the scale tensors can be reshaped below.
        format: Either "fp8" (torch.float8_e4m3fn elements) or "fp4"
            (DTYPE_FP4 elements).

    Returns:
        float: SQNR (Signal-to-Quantization-Noise Ratio) between the fused
        kernel output and the bf16 dequantized reference.

    Raises:
        ValueError: If ``format`` is neither "fp8" nor "fp4".
    """
    # Fail loudly on an unknown format instead of silently taking the
    # fp4 code path.
    if format not in ("fp8", "fp4"):
        raise ValueError(f"Unsupported format: {format!r}; expected 'fp8' or 'fp4'")

    dtype = torch.bfloat16
    device = torch.device("cuda")

    a = torch.rand((M, K), dtype=dtype, device=device)
    b = torch.rand((N, K), dtype=dtype, device=device)

    # Pick the element dtype and the matching fused kernel for the format.
    fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16

    # Quantize both operands to MX with a block size of 32.
    a_mx = MXTensor.to_mx(a, fmt, 32)
    b_mx = MXTensor.to_mx(b, fmt, 32)

    a_data = a_mx._data
    b_data = b_mx._data
    assert b_data.is_contiguous()
    # The kernel expects B transposed relative to its (N, K) storage.
    b_data = b_data.transpose(-1, -2)

    # One e8m0 scale per 32-element block along K.
    a_scale = a_mx._scale_e8m0.view(M, K // 32)
    b_scale = b_mx._scale_e8m0.view(N, K // 32)

    # Rearrange the scales into the blocked layout the kernel consumes.
    a_scale_block = to_blocked(a_scale)
    b_scale_block = to_blocked(b_scale)

    # Reference path: dequantize to bf16 and matmul in high precision.
    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
        -1, -2
    )
    out = mx_func(a_data, b_data, a_scale_block, b_scale_block)

    return compute_error(out_hp, out).item()
62
44
63
45
64
46
@pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
@@ -68,35 +50,25 @@ def run_matrix_test(M: int, K: int, N: int) -> float:
68
50
@pytest.mark.parametrize(
    "size",
    [
        (128, 128, 128),
        (256, 256, 256),
        (384, 384, 384),  # Small
        (512, 512, 512),
        (768, 768, 768),  # Medium
        (1024, 1024, 1024),
        (8192, 8192, 8192),  # Large
        (128, 256, 384),
        (256, 384, 512),  # Non-square
        (129, 256, 384),
        (133, 512, 528),  # Non-aligned
    ],
    ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
)
@pytest.mark.parametrize("format", ["fp8", "fp4"])
def test_matrix_multiplication(size, format):
    """
    Run the MX matmul kernel for each size/format combination and verify
    the SQNR against the bf16 reference meets the minimum quality threshold.
    """
    M, K, N = size
    sqnr = run_matrix_test(M, K, N, format)
    threshold = 80.0
    assert (
        sqnr >= threshold
    ), f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
0 commit comments