4 files changed (+54, −0), under test/prototype/mx_formats and torchao/prototype/mx_formats.

@@ -401,3 +401,21 @@ def test_filter_fn():
     swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn)  # noqa: E501
     assert type(m2[0]) == MXInferenceLinear
     assert type(m2[1]) == torch.nn.Linear
+
+
+def test_training_print_str():
+    m = nn.Sequential(nn.Linear(32, 32))
+    config = MXLinearConfig()
+    swap_linear_with_mx_linear(m, config=config)
+    s = str(m)
+    assert "bl_sz=32" in s
+    assert "kernel=emulated" in s
+
+
+def test_inference_print_str():
+    m = nn.Sequential(nn.Linear(32, 32))
+    config = MXLinearConfig()
+    swap_linear_with_mx_inference_linear(m, config=config)
+    s = str(m)
+    assert "bl_sz=32" in s
+    assert "kernel=emulated" in s
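
Note: these assertions rely on the fact that nn.Module's default repr includes each submodule's extra_repr() output inside its parentheses. A minimal standalone sketch of that mechanism (illustrative only, not part of this PR; the Demo class is hypothetical):

import torch.nn as nn

class Demo(nn.Module):
    def extra_repr(self) -> str:
        # stand-in for what MXLinearConfig.short_str() would contribute
        return "bl_sz=32, kernel=emulated"

print(nn.Sequential(Demo()))
# Sequential(
#   (0): Demo(bl_sz=32, kernel=emulated)
# )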

@@ -12,6 +12,7 @@

 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP4,
+    DTYPE_TO_SHORT_STR,
     SUPPORTED_ELEM_DTYPES,
 )
@@ -143,3 +144,22 @@ def from_recipe_name(
             )
         else:
             raise AssertionError(f"unknown recipe_name {recipe_name}")
+
+    def short_str(self) -> str:
+        """
+        Returns a concise representation of the current config.
+        """
+        s = f"bl_sz={self.block_size}, lp_dtype={DTYPE_TO_SHORT_STR[self.elem_dtype]}"
+        if self.elem_dtype_weight_override is not None:
+            s += (
+                f", lp_w_override={DTYPE_TO_SHORT_STR[self.elem_dtype_weight_override]}"
+            )
+        if self.elem_dtype_grad_output_override is not None:
+            s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}"
+        s += f", kernel={self.gemm_kernel_choice.value}"
+        if self.use_fp8_dim1_cast_triton_kernel:
+            s += ", use_fp8_dim1_cast_triton_kernel=True"
+        if self.use_fp4_custom_triton_dequant_kernel:
+            s += ", use_fp4_custom_triton_dequant_kernel=True"
+        # TODO(future PR): split training from inference and add fp6 here
+        return s
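
Note: the tests above only pin down the "bl_sz=32" and "kernel=emulated" tokens. A sketch of the full string for a default config, assuming the default elem_dtype is torch.float8_e4m3fn and that MXLinearConfig is importable from torchao.prototype.mx_formats.config (both assumptions, not stated in this diff):

from torchao.prototype.mx_formats.config import MXLinearConfig

config = MXLinearConfig()
print(config.short_str())
# expected under the assumptions above: "bl_sz=32, lp_dtype=f8e4m3, kernel=emulated"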

@@ -22,6 +22,14 @@
     DTYPE_FP4,
 ]

+DTYPE_TO_SHORT_STR = {
+    torch.float8_e4m3fn: "f8e4m3",
+    torch.float8_e5m2: "f8e5m2",
+    DTYPE_FP6_E2M3: "f6e2m3",
+    DTYPE_FP6_E3M2: "f6e3m2",
+    DTYPE_FP4: "f4e2m1",
+}
+
 F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
 F8E5M2_MAX = torch.finfo(torch.float8_e5m2).max  # 57344.0
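
Note: a minimal lookup sketch for the new mapping. It only uses names visible in this diff, and assumes the constants module is importable as torchao.prototype.mx_formats.constants; the fp6/fp4 keys reuse the constants defined earlier in the file rather than torch dtypes:

import torch
from torchao.prototype.mx_formats.constants import DTYPE_FP4, DTYPE_TO_SHORT_STR

print(DTYPE_TO_SHORT_STR[torch.float8_e4m3fn])  # "f8e4m3"
print(DTYPE_TO_SHORT_STR[DTYPE_FP4])            # "f4e2m1"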

@@ -213,6 +213,10 @@ def forward(self, x):
             y = y + self.bias
         return y
 
+    def extra_repr(self):
+        s = f"{super().extra_repr()}, {self.config.short_str()}"
+        return s
+
 
 class MXInferenceLinear(torch.nn.Linear):
     """

@@ -255,6 +259,10 @@ def forward(self, x):
         y = F.linear(x, w_hp, self.bias)
         return y
 
+    def extra_repr(self):
+        s = f"{super().extra_repr()}, {self.config.short_str()}"
+        return s
+
 
 def replace_with_custom_fn_if_matches_filter(
     model, replacement_fn, filter_fn, cur_fqn=""
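
Note: putting the pieces together, an end-to-end sketch of the new repr after swapping. The MXLinear class name, the lp_dtype token, and the import paths are assumptions based on the surrounding diff and tests, not confirmed by this PR:

import torch.nn as nn
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = nn.Sequential(nn.Linear(32, 32))
swap_linear_with_mx_linear(m, config=MXLinearConfig())
print(m)
# expected to contain something like:
#   MXLinear(in_features=32, out_features=32, bias=True, bl_sz=32, lp_dtype=f8e4m3, kernel=emulated)
# because nn.Linear.extra_repr() contributes the in_features/out_features/bias fields
# and the new extra_repr() appends config.short_str() after them.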