13 changes: 9 additions & 4 deletions auto_round/inference/backend.py
@@ -107,6 +107,11 @@ class BackendInfo:
"act_dynamic",
]

MX_TENSOR_DATA_TYPES = [
"mx_fp",
"mx_fp_rceil",
]


def feature_multiply_checker(in_feature, out_feature, config, in_feature_multiplier, out_feature_multiplier=None):
if out_feature_multiplier is None:
@@ -230,13 +235,13 @@ def fp8_static_scheme_checker(
packing_format=LLM_COMPRESSOR_FORMAT,
sym=[True],
compute_dtype=["float32", "float16", "bfloat16"],
data_type=["mx_fp", "max_fp_rceil"],
data_type=MX_TENSOR_DATA_TYPES,
group_size=[32],
bits=[8],
act_bits=[8],
act_group_size=[32],
act_sym=[True],
act_data_type=["mx_fp_rceil"],
act_data_type=MX_TENSOR_DATA_TYPES,
act_dynamic=[True],
priority=0,
checkers=[feature_multiply_checker_32],
@@ -250,13 +255,13 @@ def fp8_static_scheme_checker(
packing_format=LLM_COMPRESSOR_FORMAT,
sym=[True],
compute_dtype=["float32", "float16", "bfloat16"],
data_type=["mx_fp"],
data_type=MX_TENSOR_DATA_TYPES,
group_size=[32],
bits=[4],
act_bits=[4],
act_group_size=[32],
act_sym=[True],
act_data_type=["mx_fp_rceil"],
act_data_type=MX_TENSOR_DATA_TYPES,
act_dynamic=[True],
priority=0,
checkers=[feature_multiply_checker_32],
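
Note: the new MX_TENSOR_DATA_TYPES constant gives the MXFP8 and MXFP4 backend entries a single source of truth for the supported MX weight and activation data types, and removes the hard-coded "max_fp_rceil" typo from the original MXFP8 entry. A minimal sketch of the kind of membership check this consolidation supports; the helper name below is illustrative and not part of the PR:

    # Sketch only: mirrors the constant added in auto_round/inference/backend.py.
    MX_TENSOR_DATA_TYPES = [
        "mx_fp",
        "mx_fp_rceil",
    ]

    def is_supported_mx_dtype(data_type: str) -> bool:
        # Hypothetical helper: a typo such as "max_fp_rceil" now fails this check
        # instead of living on in one backend entry's hard-coded list.
        return data_type in MX_TENSOR_DATA_TYPES

    assert is_supported_mx_dtype("mx_fp_rceil")
    assert not is_supported_mx_dtype("max_fp_rceil")
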
8 changes: 8 additions & 0 deletions auto_round/testing_utils.py
@@ -268,3 +268,11 @@ def decorator(test_func: Callable) -> Callable:
return unittest.skipUnless(require_package_version(package, version_spec, on_fail="skip"), reason)(test_func)

return decorator


def has_module(model: torch.nn.Module, target_module_type: torch.nn.Module) -> bool:
"""Check if the model contains a specific module type."""
for _, module in model.named_modules():
if isinstance(module, target_module_type):
return True
return False
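
Note: has_module now lives in auto_round.testing_utils so the CPU and CUDA tests can share it; it walks model.named_modules() and returns True on the first instance of the target module type. A small usage sketch (the nn.Linear target is illustrative only):

    import torch
    from auto_round.testing_utils import has_module

    # Illustrative usage: any torch.nn.Module and target module type follow the same pattern.
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
    assert has_module(model, torch.nn.Linear)      # present in the model
    assert not has_module(model, torch.nn.Conv2d)  # absent, so False
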
69 changes: 69 additions & 0 deletions test/test_cpu/test_mxfp_save_load.py
@@ -0,0 +1,69 @@
import shutil
import tempfile

import pytest
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM

from auto_round import AutoRound
from auto_round import schemes as ar_schemes
from auto_round.experimental import qmodules as ar_qmodules
from auto_round.export.export_to_autoround import AutoRoundFormat
from auto_round.export.export_to_autoround import qlinear_fp as ar_qlinear_fp
from auto_round.inference.backend import MX_TENSOR_DATA_TYPES
from auto_round.testing_utils import has_module

testing_scheme_name_lst = [
AutoRoundFormat.MXFP8.value,
AutoRoundFormat.MXFP4.value,
]
QMODULE_MAPPING = {
AutoRoundFormat.MXFP8.value: ar_qmodules.MXFP8QuantLinear,
AutoRoundFormat.MXFP4.value: ar_qmodules.MXFP4QuantLinear,
}
SCHEMES_MAPPING = {
AutoRoundFormat.MXFP8.value: ar_schemes.MXFP8,
AutoRoundFormat.MXFP4.value: ar_schemes.MXFP4,
}


@pytest.mark.parametrize("scheme_name", testing_scheme_name_lst)
@pytest.mark.parametrize("weight_data_type", MX_TENSOR_DATA_TYPES)
@pytest.mark.parametrize("act_data_type", MX_TENSOR_DATA_TYPES)
@torch.inference_mode()
def test_e2e_quant_and_load(scheme_name, weight_data_type, act_data_type):
# Use a temporary directory for saving the quantized model
with tempfile.TemporaryDirectory() as temp_dir:
model_name = "/tf_dataset/auto_round/models/Qwen/Qwen2.5-0.5B-Instruct"
config = AutoConfig.from_pretrained(model_name)
config.num_hidden_layers = 2 # Use a smaller model for testing

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = Qwen2ForCausalLM(config)
scheme = SCHEMES_MAPPING[scheme_name]
scheme.data_type = weight_data_type
scheme.act_data_type = act_data_type

Contributor review comment: after #903 is merged, it's better to add a mixed-bits unit test and verify inference.

# Initialize AutoRound for quantization
autoround = AutoRound(
model,
tokenizer,
scheme=scheme,
iters=0,
nsamples=2,
)

# Quantize and save the model to the temporary directory
quantized_model_path = f"{temp_dir}/tmp_autoround"
autoround.quantize_and_save(format="auto_round", output_dir=quantized_model_path)

# Perform inference with the quantized model
model = AutoModelForCausalLM.from_pretrained(
quantized_model_path,
torch_dtype="auto",
)
model.eval()
assert has_module(
model, QMODULE_MAPPING[scheme_name]
), f"Expected {QMODULE_MAPPING[scheme_name].__name__} in the model."
9 changes: 1 addition & 8 deletions test/test_cuda/test_mxfp_and_nvfp_quant.py
@@ -10,6 +10,7 @@
from auto_round.experimental import qmodules as ar_qmodules
from auto_round.export.export_to_autoround import AutoRoundFormat
from auto_round.export.export_to_autoround import qlinear_fp as ar_qlinear_fp
from auto_round.testing_utils import has_module

testing_schemes = [AutoRoundFormat.MXFP8.value, AutoRoundFormat.MXFP4.value, AutoRoundFormat.NVFP4.value]
QMODULE_MAPPING = {
@@ -19,14 +20,6 @@
}


def has_module(model: torch.nn.Module, target_module_type: torch.nn.Module) -> bool:
"""Check if the model contains a specific module type."""
for _, module in model.named_modules():
if isinstance(module, target_module_type):
return True
return False


@pytest.mark.parametrize("scheme", testing_schemes)
@torch.inference_mode()
def test_e2e_quant_and_infer(scheme):