Merged
Commits (19)
2b1577b
enable dynamic quantization config saving
WeiweiZhang1 Sep 16, 2025
56a2218
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 16, 2025
db99785
fixtypo
WeiweiZhang1 Sep 16, 2025
81e8086
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 16, 2025
d5b9a46
Merge branch 'main' into enable_dynamic_quantization_config_saving
WeiweiZhang1 Sep 24, 2025
4e58090
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2025
c75ebdc
rebase code, refine config saving
WeiweiZhang1 Sep 24, 2025
ae20df7
Merge branch 'main' into enable_dynamic_quantization_config_saving
WeiweiZhang1 Sep 24, 2025
21ff4b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2025
b97f3fc
refine ut
WeiweiZhang1 Sep 24, 2025
be7af05
Merge branch 'enable_dynamic_quantization_config_saving' of https://g…
WeiweiZhang1 Sep 24, 2025
b91bf20
fix UT
WeiweiZhang1 Sep 24, 2025
cd5c693
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2025
98486d4
Merge branch 'main' into enable_dynamic_quantization_config_saving
WeiweiZhang1 Oct 16, 2025
2d45996
enable hf loading for regex, add UTs
WeiweiZhang1 Oct 16, 2025
a8bdf81
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 16, 2025
abc4fd0
Merge branch 'main' into enable_dynamic_quantization_config_saving
WeiweiZhang1 Oct 22, 2025
8a668b2
refine export, enhance gptq UT
WeiweiZhang1 Oct 22, 2025
8085d1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 22, 2025
7 changes: 5 additions & 2 deletions auto_round/compressors/base.py
@@ -439,7 +439,7 @@ def _gen_auto_scheme(
"Please save the model using the `fake` format for now."
)

layer_config, self.has_qlayer_outside_block = set_layer_config(
layer_config, self.has_qlayer_outside_block, self.regex_config = set_layer_config(
self.model,
self.layer_config,
self.scheme,
@@ -1653,7 +1653,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
enable_gguf_official_mixed = True
else:
enable_gguf_official_mixed = False
self.layer_config, self.has_qlayer_outside_block = set_layer_config(
self.layer_config, self.has_qlayer_outside_block, self.regex_config = set_layer_config(
self.model,
self.layer_config,
self.scheme,
@@ -2937,6 +2937,8 @@ def save_quantized(
"Support for exporting activation quantization is limited. "
"Please ensure that your configuration is supported."
)
# if format == "llm_compressor" and (is_nv_fp(self.data_type) or is_mx_fp(self.data_type)):
# format = format.replace("llm_compressor", f"llm_compressor:{self.data_type}")
if format == "llm_compressor" and (is_nv_fp(self.data_type) or is_mx_fp(self.data_type)):
format = format.replace("llm_compressor", f"llm_compressor:{self.data_type}")
if format == "llm_compressor" and is_static_wfp8afp8(self):
@@ -2985,6 +2987,7 @@ def save_quantized(
"act_data_type",
"super_bits",
"super_group_size",
"regex_config",
]
if isinstance(self.dataset, str):
serialization_keys.append("dataset")
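Side note: with `regex_config` added to the serialization keys, the saved `quantization_config` can carry regex-scoped overrides next to the global scheme. A minimal sketch of what that could look like, with made-up layer patterns and values:

# Illustrative only; patterns and values are hypothetical, not taken from this PR.
quantization_config = {
    "bits": 4,
    "group_size": 128,
    "sym": True,
    "regex_config": {
        "model.layers.*.mlp.down_proj": {"bits": 8, "group_size": 128, "sym": True},
        "lm_head": {"bits": 16},
    },
}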
164 changes: 131 additions & 33 deletions auto_round/export/export_to_autogptq/export.py
@@ -17,6 +17,8 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor
from dataclasses import fields
from typing import Any, Dict

import threadpoolctl as tctl

@@ -48,16 +50,28 @@

import auto_round.export.export_to_autogptq.qlinear_triton
from auto_round.export.utils import save_model
from auto_round.schemes import QuantizationScheme

GPTQ_REQUIRED_CONFIG_KEYS = (
"bits",
"group_size",
"sym",
)

from auto_round.logger import logger
from auto_round.utils import (
SUPPORTED_LAYER_TYPES,
check_start_with_block_name,
check_to_quantized,
copy_python_files_from_model_cache,
filter_quantization_config,
get_autogptq_packing_qlinear,
get_block_names,
get_module,
json_serialize,
matches_any_regex,
set_module,
to_standard_regex,
)

BLOCK_PATTERNS = [ ## copy from transformers optimum
@@ -66,6 +80,54 @@
"gpt_neox.layers",
"model.layers",
]
from auto_round.export.export_to_autoround.utils import check_neq_config


def convert_to_autogptq_dynamic(regex_config: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""
Convert AutoRound-style regex_config into AutoGPTQ-style QuantizerConfig.dynamic.

Rules:
- bits < 16 -> quantize -> positive match `+:regex`
- bits == 16 -> skip quantize -> negative match `-:regex`
"""
converted = {}
for name, cfg in regex_config.items():
bits = cfg.get("bits")
regex = to_standard_regex(name)

if bits is None:
continue # ignore invalid entries
elif bits < 16:
converted[f"+:{regex}"] = {"bits": bits}
for key in GPTQ_REQUIRED_CONFIG_KEYS:  # only save the keys GPTQ supports
converted[f"+:{regex}"][key] = regex_config[name][key]
else:
# skip quantization
converted[f"-:{regex}"] = {}
return converted
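For illustration, the mapping this helper performs on a small config; the concrete regex strings come from `to_standard_regex`, so they are left symbolic below:

# Sketch: convert_to_autogptq_dynamic on a hypothetical regex_config.
regex_config = {
    "model.layers.*.mlp.down_proj": {"bits": 8, "group_size": 128, "sym": True, "act_bits": 16},
    "lm_head": {"bits": 16},
}
# convert_to_autogptq_dynamic(regex_config) would return roughly:
# {
#     "+:<to_standard_regex('model.layers.*.mlp.down_proj')>": {"bits": 8, "group_size": 128, "sym": True},
#     "-:<to_standard_regex('lm_head')>": {},
# }
# Only GPTQ_REQUIRED_CONFIG_KEYS survive for "+:" entries; bits == 16 entries become
# negative "-:" matches that AutoGPTQ leaves unquantized.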


def convert_from_autogptq_dynamic(dynamic_config: dict) -> dict:
"""
Convert AutoGPTQ-style QuantizerConfig.dynamic into AutoRound-style extra_config.

Rules:
- '+:regex' => quantize => keep bits and other quantization keys
- '-:regex' => skip quantize => set bits to 16 (FP16 passthrough)
"""
converted = {}
for name, cfg in dynamic_config.items():
# Strip the +: or -:
if name.startswith("+:"):
regex = name[2:]
# keep all config fields (bits, group_size, sym, etc.)
converted[regex] = dict(cfg)
elif name.startswith("-:"):
regex = name[2:]
# mark skipped layers with bits=16
converted[regex] = {"bits": 16, "act_bits": 16}
return converted
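And the loading direction on a hand-written dynamic dict; the regex keys here are placeholders:

# Sketch: convert_from_autogptq_dynamic on example inputs.
dynamic_config = {
    "+:.*mlp\\.down_proj": {"bits": 8, "group_size": 128, "sym": True},
    "-:.*lm_head": {},
}
# convert_from_autogptq_dynamic(dynamic_config) returns:
# {
#     ".*mlp\\.down_proj": {"bits": 8, "group_size": 128, "sym": True},
#     ".*lm_head": {"bits": 16, "act_bits": 16},
# }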


def pack_layer(name, model, backend, device=None):
@@ -132,58 +194,93 @@ def pack_layer(name, model, backend, device=None):
def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exllamav2", **kwargs):
"""Export the model to autogptq format to easily leverage cuda kernel."""

# --- Extract inputs & configs ---
model = kwargs["model"]
safe_serialization = True if "safe_serialization" not in kwargs.keys() else kwargs["safe_serialization"]
quant_block_list = kwargs.get("quant_block_list", get_block_names(model))
tokenizer = kwargs.get("tokenizer", None)
processor = kwargs.get("processor", None)
device = kwargs.get("device", None)
image_processor = kwargs.get("image_processor", None)
if output_dir is not None and os.path.exists(output_dir):
logger.warning(f"{output_dir} already exists, this may cause model conflict")
if output_dir is not None and tokenizer is not None and hasattr(tokenizer, "save_pretrained"):
tokenizer.save_pretrained(output_dir)
if output_dir is not None and processor is not None:
processor.save_pretrained(output_dir)
if output_dir is not None and image_processor is not None:
image_processor.save_pretrained(output_dir)
##check module quantized in block, this may have bug for mixed precision quantization
quantization_config = kwargs["serialization_dict"]
layer_config = kwargs["layer_config"]
quant_block_list = kwargs.get("quant_block_list", get_block_names(model))
tokenizer = kwargs.get("tokenizer")
processor = kwargs.get("processor")
image_processor = kwargs.get("image_processor")
device = kwargs.get("device")
safe_serialization = kwargs.get("safe_serialization", True)

# --- Save metadata (tokenizer, processor, etc.) ---
if output_dir:
if os.path.exists(output_dir):
logger.warning(f"{output_dir} already exists, may cause overwrite conflicts.")
for comp in (tokenizer, processor, image_processor):
if comp is not None and hasattr(comp, "save_pretrained"):
comp.save_pretrained(output_dir)

# --- Handle quantization structure ---
all_blocks = quant_block_list
flattened_list = [item for sublist in all_blocks for item in sublist]
common_prefix = os.path.commonprefix(flattened_list).rstrip(".")
if common_prefix not in BLOCK_PATTERNS:
logger.error("auto-gptq format may not support loading this quantized model")
flattened = [x for sub in all_blocks for x in sub]
common_prefix = os.path.commonprefix(flattened).rstrip(".")

if "BLOCK_PATTERNS" in kwargs and common_prefix not in kwargs["BLOCK_PATTERNS"]:
logger.error(f"Unsupported block prefix '{common_prefix}' for AutoGPTQ format.")
quantization_config["block_name_to_quantize"] = common_prefix
quantization_config.pop("to_quant_block_names", None)

## as layers maybe already packed, we need to check in layer_config
layer_config = kwargs["layer_config"]
# --- Build per-layer dynamic overrides ---
regex_config = quantization_config.pop("regex_config", {})
block_name_to_quantize = quantization_config.get("block_name_to_quantize")
extra_config = {}
lm_head_quantized = False
scheme_keys = [f.name for f in fields(QuantizationScheme)]
for layer_name, cfg in layer_config.items():
bits = cfg.get("bits", 16)
in_blocks = cfg.get("in_blocks", False)
# Handle non-block layers (e.g., LM head)
if not in_blocks and bits <= 8:
lm_head_quantized = True
extra_config[layer_name] = {k: cfg[k] for k in GPTQ_REQUIRED_CONFIG_KEYS}
continue
# Handle block layers
if in_blocks or (block_name_to_quantize and check_start_with_block_name(layer_name, block_name_to_quantize)):
neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys})
if neq_keys:
if matches_any_regex(layer_name, regex_config):
continue
extra_config[layer_name] = {k: cfg[k] for k in GPTQ_REQUIRED_CONFIG_KEYS}

# --- Merge regex_config + extra_config into GPTQ dynamic config ---
dynamic = {}
if regex_config:
dynamic.update(convert_to_autogptq_dynamic(regex_config))
if extra_config:
dynamic.update(convert_to_autogptq_dynamic(extra_config))
if dynamic:
quantization_config["dynamic"] = dynamic

# --- Block-wise quantization verification ---
for n, m in model.named_modules():
m.tmp_name = n

all_to_quantized = True
modules_in_block_to_quantize = []
for block_names in all_blocks:
first_block = get_module(model, block_names[0])
for n, m in first_block.named_modules():
if m.tmp_name not in layer_config.keys():
continue
if not check_to_quantized(layer_config[m.tmp_name]):
all_to_quantized = False
else:
modules_in_block_to_quantize.append(n)
modules_in_block_to_quantize = [modules_in_block_to_quantize]
if not dynamic: # Only uniform precision
for block_names in all_blocks:
first_block = get_module(model, block_names[0])
for n, m in first_block.named_modules():
if m.tmp_name not in layer_config:
continue
if not check_to_quantized(layer_config[m.tmp_name]):
all_to_quantized = False
else:
modules_in_block_to_quantize.append(n)
modules_in_block_to_quantize = [modules_in_block_to_quantize]

if all_to_quantized:
modules_in_block_to_quantize = None

for n, m in model.named_modules():
for _, m in model.named_modules():
delattr(m, "tmp_name")

if not inplace:
model = copy.deepcopy(model.to("cpu"))

layer_config = kwargs["layer_config"]
names = list(layer_config.keys())
max_workers = 1
if not torch.cuda.is_available() and not torch.xpu.is_available():
@@ -202,6 +299,7 @@ def wrapper(name):
pass
if output_dir is None:
return model
quantization_config["lm_head"] = lm_head_quantized
quantization_config["provider"] = "auto-round"
quantization_config["quant_method"] = "gptq"
quantization_config.pop("dataset", None) ## pile-10k is not supported in gptq
8 changes: 8 additions & 0 deletions auto_round/export/export_to_autoround/export.py
@@ -44,6 +44,7 @@
is_nv_fp,
is_standard_fp,
set_module,
to_standard_regex,
)


@@ -340,8 +341,15 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
if cfg[key] is not None:
extra_config[layer_name][key] = cfg[key]

regex_config = quantization_config.pop("regex_config")
if regex_config is not None:
for name in regex_config.keys():
regex_name = to_standard_regex(name)
extra_config[regex_name] = {**{k: regex_config[name][k] for k in scheme_keys}}

if len(extra_config) > 0:
quantization_config["extra_config"] = extra_config
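For the auto_round format the regex overrides land in `extra_config`, keyed by the standardized regex, next to any per-layer entries. A hypothetical end result (the exact regex text is an assumption about `to_standard_regex`):

# Illustrative extra_config mixing an exact layer name with a regex_config-derived key.
extra_config = {
    "model.layers.0.self_attn.q_proj": {"bits": 8, "group_size": 128, "sym": True},
    "model\\.layers\\..*\\.mlp\\.down_proj": {"bits": 8, "group_size": 128, "sym": True},
}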

names = list(layer_config.keys())
max_workers = 1
if not torch.cuda.is_available() and not torch.xpu.is_available():
7 changes: 7 additions & 0 deletions auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py
@@ -41,6 +41,7 @@
is_nv_fp,
set_amax_for_all_moe_layers,
set_module,
to_standard_regex,
)
from auto_round.wrapper import WrapperWALayer

@@ -211,6 +212,12 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
if cfg[key] is not None:
extra_config[layer_name][key] = cfg[key]

regex_config = quantization_config.pop("regex_config")
if regex_config is not None:
for name in regex_config.keys():
regex_name = to_standard_regex(name)
extra_config[regex_name] = {**{k: regex_config[name][k] for k in scheme_keys}}

if len(extra_config) > 0:
quantization_config["extra_config"] = extra_config
names = list(layer_config.keys())
8 changes: 7 additions & 1 deletion auto_round/export/export_to_awq/export.py
@@ -137,6 +137,7 @@ def wrapper(name):
return model

quantization_config = kwargs["serialization_dict"]
regex_config = quantization_config.pop("regex_config", {})  # AWQ does not support saving mixed-bit configs

if output_dir is None:
return compressed_model
@@ -145,11 +146,16 @@
for key in layer_config.keys():
if not check_to_quantized(layer_config[key]) and not any(name in key for name in modules_to_not_convert):
modules_to_not_convert.append(key)
for key, cfg in regex_config.items():
bits = cfg.get("bits")
if bits > 8:  # save fp-layer regexes
modules_to_not_convert.append(key)

quantization_config["provider"] = "auto-round"
quantization_config["quant_method"] = "awq"
quantization_config["zero_point"] = not quantization_config["sym"]
quantization_config["version"] = "gemm"
quantization_config["modules_to_not_convert"] = modules_to_not_convert
quantization_config["modules_to_not_convert"] = list(dict.fromkeys(modules_to_not_convert))
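Since AWQ cannot encode per-layer bit widths, regex entries kept above 8 bits are simply folded into `modules_to_not_convert`, and `dict.fromkeys` drops duplicates while preserving order. A tiny sketch with hypothetical names:

# Order-preserving de-duplication of the skip list.
modules_to_not_convert = ["lm_head", "model.layers.*.mlp.gate", "lm_head"]
deduped = list(dict.fromkeys(modules_to_not_convert))
# deduped == ["lm_head", "model.layers.*.mlp.gate"]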
## check modules quantized in block; this may be buggy for mixed-precision quantization
filter_quantization_config(quantization_config)
if hasattr(compressed_model, "config"):
14 changes: 6 additions & 8 deletions auto_round/export/export_to_llmcompressor/export_to_fp.py
@@ -25,6 +25,7 @@
from tqdm import tqdm

from auto_round.export.export_to_autoround.qlinear_fp import QuantLinear
from auto_round.export.export_to_llmcompressor.utils import generate_ignore_regex_list
from auto_round.export.utils import save_model
from auto_round.logger import logger
from auto_round.utils import (
@@ -114,9 +115,8 @@ def pack_layer(name, model, backend, device=None):
scale = layer.scale
global_scale = getattr(layer, "weight_global_scale", None)
input_global_scale = getattr(layer, "input_global_scale", None)
# zero = layer.zp
# zero = layer.zp  # no zeros to handle, as MXFP does not support asymmetric quantization
qlayer.pack(layer, scale, global_scale=global_scale, input_global_scale=input_global_scale, device=device)
## no zeros to handle, as mxfp not support asym quantization
qlayer.to(orig_device)


@@ -155,6 +155,9 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
device = kwargs.get("device", None)
tokenizer = kwargs.get("tokenizer", None)
processor = kwargs.get("processor", None)
ar_quantization_config = kwargs["serialization_dict"]
regex_config = ar_quantization_config.pop("regex_config")
layer_config = kwargs["layer_config"]
extra_config = {}

if act_bits <= 8:
@@ -199,12 +202,7 @@ def wrapper(name):
for _ in executor.map(wrapper, names):
pass

# TODO fix the ignore re match issue, compile with fp8 & int8 config
ignore = ["lm_head"]
for layer_name in layer_config:
if layer_config[layer_name]["bits"] > 8: ## find ignore layers
ignore.append(layer_name)
ignore = list(set(ignore))
ignore = generate_ignore_regex_list(regex_config=regex_config, layer_config=layer_config)
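Here `generate_ignore_regex_list` replaces the old hand-rolled list; presumably it still covers `lm_head` plus any layer or regex entry kept above 8 bits. An assumed shape of its output, purely for orientation:

# Assumption: ignore entries as llm-compressor style names/regexes; values are illustrative.
ignore = ["lm_head", "re:.*mlp\\.gate", "model.layers.0.self_attn.q_proj"]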

# get llm-compressor format config
check_compressed_tensors_supported()