Commit 46812de

Revert "Extend mxfp loading dtypes (#907)" (#915)
This reverts commit 0c2619c.
Parent: 1d91207 · Commit: 46812de

4 files changed: +12 −87 lines changed
auto_round/inference/backend.py

Lines changed: 4 additions & 9 deletions

@@ -107,11 +107,6 @@ class BackendInfo:
     "act_dynamic",
 ]
 
-MX_TENSOR_DATA_TYPES = [
-    "mx_fp",
-    "mx_fp_rceil",
-]
-
 
 def feature_multiply_checker(in_feature, out_feature, config, in_feature_multiplier, out_feature_multiplier=None):
     if out_feature_multiplier is None:
@@ -235,13 +230,13 @@ def fp8_static_scheme_checker(
     packing_format=LLM_COMPRESSOR_FORMAT,
     sym=[True],
     compute_dtype=["float32", "float16", "bfloat16"],
-    data_type=MX_TENSOR_DATA_TYPES,
+    data_type=["mx_fp", "max_fp_rceil"],
     group_size=[32],
     bits=[8],
     act_bits=[8],
     act_group_size=[32],
     act_sym=[True],
-    act_data_type=MX_TENSOR_DATA_TYPES,
+    act_data_type=["mx_fp_rceil"],
     act_dynamic=[True],
     priority=0,
     checkers=[feature_multiply_checker_32],
@@ -255,13 +250,13 @@ def fp8_static_scheme_checker(
     packing_format=LLM_COMPRESSOR_FORMAT,
     sym=[True],
     compute_dtype=["float32", "float16", "bfloat16"],
-    data_type=MX_TENSOR_DATA_TYPES,
+    data_type=["mx_fp"],
     group_size=[32],
     bits=[4],
     act_bits=[4],
     act_group_size=[32],
     act_sym=[True],
-    act_data_type=MX_TENSOR_DATA_TYPES,
+    act_data_type=["mx_fp_rceil"],
     act_dynamic=[True],
     priority=0,
     checkers=[feature_multiply_checker_32],
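
Both MXFP checker entries above gate layer eligibility with feature_multiply_checker_32. Only the signature of feature_multiply_checker is visible in the diff context, so the following is a minimal sketch of how such a divisibility checker and its 32-multiple specialization could look; the actual body in backend.py may differ, and the partial-application form shown here is an assumption:

    from functools import partial

    # Hedged sketch: only the signature below appears in the diff context; the real
    # implementation in auto_round/inference/backend.py may differ.
    def feature_multiply_checker(in_feature, out_feature, config, in_feature_multiplier, out_feature_multiplier=None):
        if out_feature_multiplier is None:
            out_feature_multiplier = in_feature_multiplier
        # A layer qualifies only if both dimensions are multiples of the required factors.
        return in_feature % in_feature_multiplier == 0 and out_feature % out_feature_multiplier == 0

    # feature_multiply_checker_32, referenced by the checker entries above, would then
    # require both features to be multiples of 32 (the MXFP group size), e.g.:
    feature_multiply_checker_32 = partial(feature_multiply_checker, in_feature_multiplier=32)

    # Example: a 4096x11008 linear layer passes, a 4096x11006 layer does not.
    assert feature_multiply_checker_32(4096, 11008, config=None)
    assert not feature_multiply_checker_32(4096, 11006, config=None)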

auto_round/testing_utils.py

Lines changed: 0 additions & 8 deletions

@@ -268,11 +268,3 @@ def decorator(test_func: Callable) -> Callable:
         return unittest.skipUnless(require_package_version(package, version_spec, on_fail="skip"), reason)(test_func)
 
     return decorator
-
-
-def has_module(model: torch.nn.Module, target_module_type: torch.nn.Module) -> bool:
-    """Check if the model contains a specific module type."""
-    for _, module in model.named_modules():
-        if isinstance(module, target_module_type):
-            return True
-    return False

test/test_cpu/test_mxfp_save_load.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

test/test_cuda/test_mxfp_and_nvfp_quant.py

Lines changed: 8 additions & 1 deletion

@@ -10,7 +10,6 @@
 from auto_round.experimental import qmodules as ar_qmodules
 from auto_round.export.export_to_autoround import AutoRoundFormat
 from auto_round.export.export_to_autoround import qlinear_fp as ar_qlinear_fp
-from auto_round.testing_utils import has_module
 
 testing_schemes = [AutoRoundFormat.MXFP8.value, AutoRoundFormat.MXFP4.value, AutoRoundFormat.NVFP4.value]
 QMODULE_MAPPING = {
@@ -20,6 +19,14 @@
 }
 
 
+def has_module(model: torch.nn.Module, target_module_type: torch.nn.Module) -> bool:
+    """Check if the model contains a specific module type."""
+    for _, module in model.named_modules():
+        if isinstance(module, target_module_type):
+            return True
+    return False
+
+
 @pytest.mark.parametrize("scheme", testing_schemes)
 @torch.inference_mode()
 def test_e2e_quant_and_infer(scheme):
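
The has_module helper moved into this test walks model.named_modules() and reports whether any submodule is an instance of the given type. A small standalone usage sketch follows; TinyModel and the assertions are illustrative only, not part of the test file:

    import torch

    def has_module(model: torch.nn.Module, target_module_type: torch.nn.Module) -> bool:
        """Check if the model contains a specific module type."""
        for _, module in model.named_modules():
            if isinstance(module, target_module_type):
                return True
        return False

    # Illustrative module with a single Linear layer (hypothetical, for demonstration).
    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)

    model = TinyModel()
    assert has_module(model, torch.nn.Linear)       # a Linear submodule is present
    assert not has_module(model, torch.nn.Conv2d)   # no Conv2d anywhere in the model

In the test itself, it is presumably used to confirm that quantization swapped in the expected qmodule class from QMODULE_MAPPING for the chosen scheme.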
