13 changes: 4 additions & 9 deletions auto_round/inference/backend.py
@@ -107,11 +107,6 @@ class BackendInfo:
     "act_dynamic",
 ]
 
-MX_TENSOR_DATA_TYPES = [
-    "mx_fp",
-    "mx_fp_rceil",
-]
-
 
 def feature_multiply_checker(in_feature, out_feature, config, in_feature_multiplier, out_feature_multiplier=None):
     if out_feature_multiplier is None:
@@ -235,13 +230,13 @@ def fp8_static_scheme_checker(
     packing_format=LLM_COMPRESSOR_FORMAT,
     sym=[True],
     compute_dtype=["float32", "float16", "bfloat16"],
-    data_type=MX_TENSOR_DATA_TYPES,
+    data_type=["mx_fp", "mx_fp_rceil"],
     group_size=[32],
     bits=[8],
     act_bits=[8],
     act_group_size=[32],
     act_sym=[True],
-    act_data_type=MX_TENSOR_DATA_TYPES,
+    act_data_type=["mx_fp_rceil"],
     act_dynamic=[True],
     priority=0,
     checkers=[feature_multiply_checker_32],
@@ -255,13 +250,13 @@
     packing_format=LLM_COMPRESSOR_FORMAT,
     sym=[True],
     compute_dtype=["float32", "float16", "bfloat16"],
-    data_type=MX_TENSOR_DATA_TYPES,
+    data_type=["mx_fp"],
     group_size=[32],
     bits=[4],
     act_bits=[4],
     act_group_size=[32],
     act_sym=[True],
-    act_data_type=MX_TENSOR_DATA_TYPES,
+    act_data_type=["mx_fp_rceil"],
     act_dynamic=[True],
     priority=0,
     checkers=[feature_multiply_checker_32],
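For context on the entries above: each scheme row gates backend dispatch through feature_multiply_checker_32. The diff shows only the signature and first branch of feature_multiply_checker, so the following is a minimal sketch of a plausible implementation, assuming the checker simply verifies that a layer's feature dimensions are divisible by the given multiplier (32 matches the group_size used by these MX schemes); the partial binding is an assumption, not the file's actual definition.

from functools import partial

def feature_multiply_checker(in_feature, out_feature, config, in_feature_multiplier, out_feature_multiplier=None):
    # Mirrors the `if out_feature_multiplier is None:` branch visible in the
    # diff: a single multiplier applies to both dimensions.
    if out_feature_multiplier is None:
        out_feature_multiplier = in_feature_multiplier
    # Assumed body: the layer qualifies only if both dimensions divide evenly.
    return in_feature % in_feature_multiplier == 0 and out_feature % out_feature_multiplier == 0

# Assumed binding: require in/out features to be multiples of the MX group size.
feature_multiply_checker_32 = partial(feature_multiply_checker, in_feature_multiplier=32)

assert feature_multiply_checker_32(4096, 11008, config=None)      # both divisible by 32
assert not feature_multiply_checker_32(4096, 11009, config=None)  # 11009 % 32 != 0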
8 changes: 0 additions & 8 deletions auto_round/testing_utils.py
@@ -268,11 +268,3 @@ def decorator(test_func: Callable) -> Callable:
         return unittest.skipUnless(require_package_version(package, version_spec, on_fail="skip"), reason)(test_func)
 
     return decorator
-
-
-def has_module(model: torch.nn.Module, target_module_type: torch.nn.Module) -> bool:
-    """Check if the model contains a specific module type."""
-    for _, module in model.named_modules():
-        if isinstance(module, target_module_type):
-            return True
-    return False
69 changes: 0 additions & 69 deletions test/test_cpu/test_mxfp_save_load.py

This file was deleted.

9 changes: 8 additions & 1 deletion test/test_cuda/test_mxfp_and_nvfp_quant.py
@@ -10,7 +10,6 @@
 from auto_round.experimental import qmodules as ar_qmodules
 from auto_round.export.export_to_autoround import AutoRoundFormat
 from auto_round.export.export_to_autoround import qlinear_fp as ar_qlinear_fp
-from auto_round.testing_utils import has_module
 
 testing_schemes = [AutoRoundFormat.MXFP8.value, AutoRoundFormat.MXFP4.value, AutoRoundFormat.NVFP4.value]
 QMODULE_MAPPING = {
@@ -20,6 +19,14 @@
 }
 
 
+def has_module(model: torch.nn.Module, target_module_type: type[torch.nn.Module]) -> bool:
+    """Check if the model contains a specific module type."""
+    for _, module in model.named_modules():
+        if isinstance(module, target_module_type):
+            return True
+    return False
+
+
 @pytest.mark.parametrize("scheme", testing_schemes)
 @torch.inference_mode()
 def test_e2e_quant_and_infer(scheme):
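The relocated has_module helper is a plain containment check over named_modules(). A hedged usage sketch follows; the toy model and assertions are illustrative only, not the test's actual body, which presumably asserts that nn.Linear layers were replaced by QMODULE_MAPPING[scheme] after quantization.

import torch.nn as nn

def has_module(model: nn.Module, target_module_type: type[nn.Module]) -> bool:
    """Check if the model contains a specific module type (as in the diff above)."""
    for _, module in model.named_modules():
        if isinstance(module, target_module_type):
            return True
    return False

# Toy example: an unquantized model still contains nn.Linear and no conv layers.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU())
assert has_module(model, nn.Linear)
assert not has_module(model, nn.Conv2d)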