# LICENSE file in the root directory of this source tree.

import copy
- import itertools

import pytest
import torch
    MXLinearConfig,
    MXLinearRecipeName,
)
- from torchao.prototype.mx_formats.constants import DTYPE_FP4, SUPPORTED_ELEM_DTYPES
+ from torchao.prototype.mx_formats.constants import (
+     DTYPE_FP4,
+     DTYPE_FP6_E2M3,
+     DTYPE_FP6_E3M2,
+     SUPPORTED_ELEM_DTYPES,
+ )
from torchao.prototype.mx_formats.mx_linear import (
    MXInferenceLinear,
    MXLinear,
@@ -48,38 +52,65 @@ def run_around_tests():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
-     "elem_dtype", itertools.product(SUPPORTED_ELEM_DTYPES, repeat=3)
+     "elem_dtype",
+     (
+         # test each dtype
+         (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
+         (DTYPE_FP6_E3M2, DTYPE_FP6_E3M2, DTYPE_FP6_E3M2),
+         (DTYPE_FP6_E2M3, DTYPE_FP6_E2M3, DTYPE_FP6_E2M3),
+         (DTYPE_FP4, DTYPE_FP4, DTYPE_FP4),
+         # only test one type of mixed-dtype overrides, to save testing time
+         (torch.float8_e4m3fn, DTYPE_FP4, DTYPE_FP4),
+     ),
)
@pytest.mark.parametrize("bias", [True, False])
- @pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)])
- def test_linear_eager(elem_dtype, bias, input_shape):
+ @pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
+ @pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
+ def test_linear_eager_vs_hp(
+     elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel
+ ):
    """
    Smoke test for training linear module with mx weight, compares the following:
    * baseline: float32
    * experiment: emulated MX
    """
+     if use_fp8_dim1_cast_triton_kernel:
+         if elem_dtype != (
+             torch.float8_e4m3fn,
+             torch.float8_e4m3fn,
+             torch.float8_e4m3fn,
+         ):
+             pytest.skip("unsupported configuration")
+         elif not is_sm_at_least_89():
+             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+
    # elem_dtype is a tuple of (input, weight, gradient) dtypes.
    grad_shape = list(input_shape)
-     grad_shape[-1] = 8
+     grad_shape[-1] = 256

    m = nn.Sequential(
-         nn.Linear(8, 8, bias=bias, device="cuda"),
+         nn.Linear(256, 256, bias=bias, device="cuda", dtype=torch.bfloat16),
    )
    m_mx = copy.deepcopy(m)
    config = MXLinearConfig(
        block_size=4,
        elem_dtype=elem_dtype[0],
        elem_dtype_weight_override=elem_dtype[1],
        elem_dtype_grad_output_override=elem_dtype[2],
+         use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
    )
    swap_linear_with_mx_linear(m_mx, config=config)

-     x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
+     x_ref = torch.randn(
+         *input_shape, device="cuda", dtype=torch.bfloat16
+     ).requires_grad_()
    x = copy.deepcopy(x_ref)
    g = torch.randn(*grad_shape, device="cuda")
-     with torch.autocast("cuda", dtype=torch.bfloat16):
-         y_ref = m(x_ref)
-         y_mx = m_mx(x)
+
+     y_ref = m(x_ref)
+     y_mx = m_mx(x)
+
+     assert y_mx.dtype == x.dtype

    y_ref.backward(g)
    y_mx.backward(g)
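For readers skimming the diff, a minimal standalone sketch (not part of the change) of the swap pattern this test exercises. It assumes MXLinearConfig is importable from torchao.prototype.mx_formats.config, swap_linear_with_mx_linear from torchao.prototype.mx_formats.mx_linear (both consistent with the imports and calls above), and a CUDA device with capability >= 8.9 when the new triton cast kernel is enabled.

# Hedged sketch of the eager MX swap pattern exercised by the test above.
# Assumed import paths; requires a CUDA GPU (capability >= 8.9 if the
# use_fp8_dim1_cast_triton_kernel flag added in this diff is enabled).
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m_hp = nn.Sequential(nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16))
m_mx = copy.deepcopy(m_hp)

config = MXLinearConfig(
    block_size=32,  # elements sharing one MX scale; the test above uses 4
    elem_dtype=torch.float8_e4m3fn,
    use_fp8_dim1_cast_triton_kernel=True,  # fp8-only fast path added in this diff
)
swap_linear_with_mx_linear(m_mx, config=config)

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y_hp = m_hp(copy.deepcopy(x))  # high-precision baseline
y_mx = m_mx(x)                 # emulated MX experiment
assert y_mx.dtype == x.dtype   # MX linear keeps the high-precision activation dtype
y_mx.backward(torch.randn_like(y_mx))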
@@ -112,7 +143,6 @@ def test_linear_eager(elem_dtype, bias, input_shape):
)
@pytest.mark.parametrize("mkn", [(128, 256, 512), (256, 512, 128), (512, 128, 256)])
def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
-     M, K, N = 128, 128, 128
    M, K, N = mkn

    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").requires_grad_()
@@ -143,9 +173,9 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
    y_sqnr = compute_error(y_real, y_emulated)
    w_sqnr = compute_error(m_real[0].weight.grad, m_emulated[0].weight.grad)
    g_sqnr = compute_error(x_copy.grad, x.grad)
-     assert y_sqnr > 100.0, f"y_sqnr {y_sqnr} too low!"
-     assert w_sqnr > 100.0, f"w_sqnr {w_sqnr} too low!"
-     assert g_sqnr > 100.0, f"g_sqnr {g_sqnr} too low!"
+     assert y_sqnr > 90.0, f"y_sqnr {y_sqnr} too low!"
+     assert w_sqnr > 90.0, f"w_sqnr {w_sqnr} too low!"
+     assert g_sqnr > 90.0, f"g_sqnr {g_sqnr} too low!"


# TODO(future): enable compile support
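The relaxed 90 dB thresholds above are expressed as SQNR. As a reference for the units, below is a small self-contained sketch of the metric that compute_error is assumed to return (signal-to-quantization-noise ratio in dB); the sqnr_db name is illustrative only, not torchao API.

# Illustrative-only helper; assumes compute_error above returns SQNR in dB,
# i.e. 10 * log10(signal power / error power). Not part of the torchao API.
import torch


def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    error = reference - candidate
    return 10 * torch.log10(reference.pow(2).mean() / error.pow(2).mean())


ref = torch.randn(128, 256)
near = ref + 1e-5 * torch.randn_like(ref)  # tiny perturbation
assert sqnr_db(ref, near) > 90.0  # roughly 100 dB when noise is 1e-5 of the signal scale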
@@ -169,6 +199,7 @@ def test_activation_checkpointing():


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize(
    "recipe_name",
    [
@@ -182,7 +213,8 @@ def test_activation_checkpointing():
@pytest.mark.parametrize("bias", [False, True])
# TODO(future PR): figure out why torch.compile does not match eager when
# autocast is on
- def test_linear_compile(recipe_name, bias):
+ @pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
+ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
    """
    Verify that compile does not change numerics of MX linear fw + bw
    """
@@ -198,20 +230,36 @@ def test_linear_compile(recipe_name, bias):
        # TODO(future PR): fix this, things are clearly broken with bias=True
        pytest.skip("this test is broken for non-emulated recipes with bias=True")

+     if use_fp8_dim1_cast_triton_kernel:
+         if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cutlass"):
+             pytest.skip("unsupported configuration")
+         if not is_sm_at_least_89():
+             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+         if hp_dtype != torch.bfloat16:
+             pytest.skip("unsupported configuration")
+
+     if hp_dtype == torch.bfloat16 and recipe_name != "mxfp8_cublas":
+         # TODO(future PR): properly enable float32 + bfloat16 for every
+         # recipe, this needs a cleanup of out_dtype (needs to match in-hp-dtype, even
+         # if the underlying gemm kernel only supports bf16 output)
+         pytest.skip("unsupported configuration")
+
    M, K, N = 128, 256, 512
    input_shape = (M, K)
    grad_shape = (M, N)
    m_mx = nn.Sequential(
-         nn.Linear(K, N, bias=bias, device="cuda"),
+         nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype),
    )
    config = MXLinearConfig.from_recipe_name(recipe_name)
+     config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+
    swap_linear_with_mx_linear(m_mx, config=config)
    m_mx_c = copy.deepcopy(m_mx)
    m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")

-     x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
+     x_ref = torch.randn(*input_shape, device="cuda", dtype=hp_dtype).requires_grad_()
    x = copy.deepcopy(x_ref)
-     g = torch.randn(*grad_shape, device="cuda")
+     g = torch.randn(*grad_shape, device="cuda", dtype=hp_dtype)

    y_ref = m_mx(x_ref)
    y = m_mx_c(x)
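As context for the hp_dtype and triton-kernel parameters added above, a hedged sketch of the eager-vs-compiled comparison this test performs. It assumes "mxfp8_emulated" is a valid recipe name (per the recipe_name parametrization), the same import paths as the earlier sketch, and a CUDA device; the exact tolerance used by the real test is not shown in this diff.

# Hedged sketch of the eager vs. torch.compile numerics check above.
# Assumed import paths; requires a CUDA device.
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

hp_dtype = torch.bfloat16
m_mx = nn.Sequential(nn.Linear(256, 512, bias=False, device="cuda", dtype=hp_dtype))
config = MXLinearConfig.from_recipe_name("mxfp8_emulated")
swap_linear_with_mx_linear(m_mx, config=config)

m_mx_c = torch.compile(copy.deepcopy(m_mx), fullgraph=True, backend="inductor")

x_ref = torch.randn(128, 256, device="cuda", dtype=hp_dtype).requires_grad_()
x = copy.deepcopy(x_ref)

y_ref = m_mx(x_ref)  # eager
y = m_mx_c(x)        # compiled
# Compile should not change numerics; default tolerances used here as a stand-in.
torch.testing.assert_close(y, y_ref)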
@@ -283,7 +331,7 @@ def test_inference_compile_simple(elem_dtype):
    if elem_dtype is torch.float8_e4m3fn:
        assert sqnr >= 20.0
    else:
-         assert sqnr >= 13.5
+         assert sqnr >= 11.5


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")