
Commit da55e27

optimize rtn for int woq (#924)
1 parent 5bb16b0 commit da55e27

9 files changed (+171, -54 lines)

README.md

Lines changed: 6 additions & 6 deletions
@@ -27,13 +27,13 @@ and [fbaldassarri](https://huggingface.co/fbaldassarri). For usage instructions,

 ## 🆕 What's New
-[2025/10] We proposed a fast algorithm to generate mixed bits/datatypes schemes in minutes. Please
+[2025/10] We proposed a fast algorithm to generate **mixed bits/datatypes** schemes in minutes. Please
 refer to the documentation for accuracy [results](./docs/auto_scheme_acc.md) and [this guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme) for usage instructions.

-[2025/09] AutoRound now includes experimental support for the mxfp4 and nvfp4 dtypes. For accuracy results, see the [documentation](./docs/mxnv_acc.md)
+[2025/09] AutoRound now includes experimental support for the **mxfp4 and nvfp4 dtypes**. For accuracy results, see the [documentation](./docs/mxnv_acc.md)
 . We currently recommend exporting to the LLM-Compressor format.

-[2025/08] AutoRound now provides experimental support for an improved INT2 algorithm via `--enable_alg_ext`. See this [documentation](./docs/alg_202508.md)
+[2025/08] AutoRound now provides experimental support for **an improved INT2 algorithm** via `--enable_alg_ext`. See this [documentation](./docs/alg_202508.md)
 for some accuracy results.

 [2025/07] AutoRound now offers experimental support for **GGUF** format, and recommends using optimized RTN mode (--iters 0) for

@@ -67,7 +67,7 @@ Support **AutoRound, AutoAWQ, AutoGPTQ, and GGUF** for maximum compatibility. De

 **Affordable Quantization Cost**
 Quantize 7B models in about 10 minutes on a single GPU. Details are shown in [quantization costs](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#quantization-costs)

-**Fast mixed bits/data-types scheme generation**
+**Fast Mixed Bits/Dtypes Scheme Generation**
 Automatically configure in minutes, with about 2X-4X the model’s BF16 VRAM size as overhead. Accuracy [results](./docs/auto_scheme_acc.md) and [user guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme).

 **10+ VLMs Support**

@@ -76,8 +76,8 @@ Out-of-the-box quantization for 10+ vision-language models [example models](http

 **Layerwise Mixed Bits Quantization**
 Assign different bits per layer for fine-grained accuracy/performance trade-offs. Details are shown in [mixed bits quantization](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#mixed-bits-usage)

-**Round-to-Nearest Mode**
-Use `--iters 0` for fast, calibration-free quantization with some accuracy drop. Details are shown in [rtn mode](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#rtn-mode)
+**Optimized Round-to-Nearest Mode**
+Use `--iters 0` for fast, calibration-free quantization with some accuracy drop for 4 bits. Details are shown in [opt_rtn mode](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#opt-rtn-mode)

 **Multiple Recipes**
 Choose from `auto-round-best`, `auto-round`, and `auto-round-light` to suit your needs. Details are shown in [quantization recipes](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#recipe-recommendation)

auto_round/compressors/base.py

Lines changed: 19 additions & 16 deletions
@@ -1244,7 +1244,7 @@ def _quant_rtn_with_imatrix(self, all_to_quantized_module_names: list[str]) -> N
         Returns:
             None
         """
-        logger.info("start to compute imatrix for GGUF quantization")
+        logger.info("start to compute imatrix")

         # Load dataset
         from auto_round.calib_dataset import get_dataloader

@@ -1343,15 +1343,13 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
         if _is_fp8_linear(m):
             m = convert_fp8_layer_to_linear(m, self.amp_dtype)
             set_module(self.model, name, m)
-
-        # Step 1: Use optimized RTN data type if available
-        if not self.disable_opt_rtn and not m.data_type.startswith("rtn_"):
-            from auto_round.data_type import QUANT_FUNC_WITH_DTYPE
-
-            rtn_dtype = "rtn_" + m.data_type
-            if rtn_dtype in QUANT_FUNC_WITH_DTYPE:
-                m.data_type = rtn_dtype
-                self.layer_config[name]["data_type"] = m.data_type
+        #
+        # # Step 1: Use optimized RTN data type if available
+        # if not self.disable_opt_rtn:
+        #     rtn_data_type = self._check_rtn_dytpe(m.data_type, m.bits, m.sym)
+        #     if rtn_data_type is not None:
+        #         m.data_type = rtn_data_type
+        #         self.layer_config[name]["data_type"] = m.data_type

         # Step 2: Try quantization on GPU first, fall back to CPU if OOM
         # if only export gguf, using gguf-packing instead of rtn

@@ -1367,6 +1365,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
             enable_norm_bias_tuning=False,
             enable_round_tuning=False,
             enable_torch_compile=self.enable_torch_compile,
+            disable_opt_rtn=self.disable_opt_rtn,
         )
         m = m.unwrapper({})
         m.to("cpu")

@@ -1457,7 +1456,14 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         self._quantize_embedding_layer()

         self.model.to("cpu")
+
+        enable_imatrix = False
         if has_gguf_k and not self.disable_opt_rtn:
+            enable_imatrix = True
+        if self.data_type == "int" and self.sym:
+            enable_imatrix = True
+
+        if enable_imatrix:
             self._quant_rtn_with_imatrix(all_to_quantized_module_names)
         elif self.act_bits <= 8 and check_need_act_calibration(
             self.act_dynamic, self.act_data_type, self.act_bits

@@ -1800,8 +1806,8 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
         Returns:
             None
         """
-        ##TODO currently we take all the layers outside blocks as post block layers which is not optimal
-        ## if there is no input for layer, we use rtn
+        # TODO currently we take all the layers outside blocks as post block layers which is not optimal
+        # if there is no input for layer, we use rtn

         for layer_name in copy.deepcopy(layer_names):
             if layer_name not in layer_inputs:

@@ -1815,17 +1821,14 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
                 set_module(self.model, layer_name, new_layer)
                 layer = new_layer

-            if not self.disable_opt_rtn and "rtn_" + layer.data_type in QUANT_FUNC_WITH_DTYPE:
-                layer.data_type = "rtn_" + layer.data_type
-                logger.info("using optimized rtn method for quantizing %s", layer_name)
-                self.layer_config[layer_name]["data_type"] = layer.data_type
             wrapper_layer = WrapperLinear(
                 layer,
                 enable_round_tuning=False,
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
                 enable_torch_compile=self.enable_torch_compile,
                 device=self.device,
+                disable_opt_rtn=self.disable_opt_rtn,
             )
             new_layer = wrapper_layer.unwrapper({})
             set_module(self.model, layer_name, new_layer)

auto_round/data_type/int.py

Lines changed: 64 additions & 0 deletions
@@ -11,11 +11,75 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Union

 import torch

 from auto_round.data_type.register import register_dtype
 from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
+from auto_round.utils import get_reciprocal
+
+
+def search_scales(data: torch.Tensor, bits: int, qw: Union[None, torch.Tensor, float] = None) -> torch.Tensor:
+    nmax = pow(2, bits - 1)
+    imax = abs(data).argmax(axis=-1, keepdims=True)
+    group_max = torch.take_along_dim(data, imax, dim=-1)
+    iscales = -nmax * get_reciprocal(group_max)
+    scales = get_reciprocal(iscales)
+    L = torch.round(1.0 * iscales * data).clip(-nmax, nmax - 1)
+    if qw is None:
+        qw = 1.0
+    best_loss = torch.sum(((scales * L - data).to(torch.float32)) ** 2 * qw, dim=-1)
+    for _is in range(-18 * 5, 18 * 5 + 1):
+        if _is == 0:
+            continue
+        iscales = -(nmax - 0.01 * _is) * get_reciprocal(group_max)
+        tmp_L = torch.round(iscales * data).clip(-nmax, nmax - 1)
+        tmp_scales = get_reciprocal(iscales)
+        loss = torch.sum(((tmp_scales * tmp_L - data).to(torch.float32)) ** 2 * qw, dim=-1)
+        replace_id = loss < best_loss
+        scales[replace_id] = tmp_scales[replace_id]
+        best_loss[replace_id] = loss[replace_id]
+    return scales
+
+
+@register_dtype("rtn_int_sym")
+def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5, imatrix=None, **kwargs):
+    """Quantize and de-quantize a tensor symmetrically over the full range; credit goes to the llama.cpp community.
+
+    Args:
+        tensor: Tensor containing the tensor to be quantized
+        bits: Number of bits for quantization (e.g., 2, 3, 4, 8)
+        group_size: Number of elements to share scale for quantization
+        v: Rounding value perturbation
+        imatrix (Tensor, optional): Importance matrix used to weight the reconstruction error. Defaults to None.
+        q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability
+
+    Returns:
+        Quantized and de-quantized tensor, scale, maxq
+    """
+
+    tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
+    maxq = 2 ** (bits - 1)
+    if imatrix is None:
+        imatrix = 1.0
+    else:
+        imatrix = imatrix.reshape(1, -1)
+
+        imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1)
+        imatrix = imatrix.reshape(tensor.shape)
+
+    scale = search_scales(tensor, bits, qw=imatrix)
+    scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh))
+    int_w = round_ste(tensor / scale + v)
+    q = torch.clamp(int_w, -maxq, maxq - 1)
+    qdq_result = (scale * q).to(tensor.dtype)
+    qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
+    return qdq_result, scale, maxq


 @register_dtype("int_sym")

auto_round/data_type/utils.py

Lines changed: 23 additions & 25 deletions
@@ -87,7 +87,7 @@ def revert_tensor_by_pad(data: torch.Tensor, orig_shape: tuple, pad_len: int):
     return data_new


-def get_quant_func(dtype, bits, sym):
+def get_quant_func(dtype: str, bits: int, sym: bool, disable_opt_rtn=False) -> tuple[callable, str]:
     """Retrieve the quantization function based on data type, bit width, and symmetry.

     This function returns the appropriate quantization function from the QUANT_FUNC_WITH_DTYPE

@@ -98,40 +98,38 @@
         dtype (str): The data type for the quantization (e.g., 'int', 'mxfp4').
         bits (int): The bit width for the quantization (e.g., 2,4,8).
         sym (bool): A flag indicating whether the quantization is symmetric (True) or asymmetric (False).
+        disable_opt_rtn (bool): whether to disable the optimized RTN lookup.

     Returns:
         function: The quantization function corresponding to the specified parameters.
+        str: the matched data-type key.
     """
-    key = dtype
-    if key in QUANT_FUNC_WITH_DTYPE.keys():
-        return QUANT_FUNC_WITH_DTYPE[key], key

-    if sym:
-        key = dtype + "_sym"
-    else:
-        key = dtype + "_asym"
+    def pad_sym(data_type):
+        if sym:
+            data_sym = data_type + "_sym"
+        else:
+            data_sym = data_type + "_asym"
+        return data_sym

-    if key in QUANT_FUNC_WITH_DTYPE.keys():
-        return QUANT_FUNC_WITH_DTYPE[key], key
+    def pad_bits(data_type):
+        return data_type + str(bits)

-    ##need to add bits and sym infos
-    if sym:
-        key = dtype + str(bits) + "_sym"
-    else:
-        key = dtype + str(bits) + "_asym"
+    if not disable_opt_rtn:
+        rtn_data_type = "rtn_" + dtype
+        data_types = [rtn_data_type, pad_bits(rtn_data_type), pad_sym(rtn_data_type), pad_sym(pad_bits(rtn_data_type))]
+        for data_type in data_types:
+            from auto_round.data_type import QUANT_FUNC_WITH_DTYPE

-    if key in QUANT_FUNC_WITH_DTYPE.keys():
-        return QUANT_FUNC_WITH_DTYPE[key], key
-
-    if sym:
-        key = dtype + str(bits)
-    else:
-        key = dtype + str(bits)
+            if data_type in QUANT_FUNC_WITH_DTYPE:
+                return QUANT_FUNC_WITH_DTYPE[data_type], data_type

-    if key in QUANT_FUNC_WITH_DTYPE.keys():
-        return QUANT_FUNC_WITH_DTYPE[key], key
+    data_types = [dtype, pad_bits(dtype), pad_sym(dtype), pad_sym(pad_bits(dtype))]
+    for data_type in data_types:
+        from auto_round.data_type import QUANT_FUNC_WITH_DTYPE

-    raise ValueError(f"{dtype} is not supported")
+        if data_type in QUANT_FUNC_WITH_DTYPE:
+            return QUANT_FUNC_WITH_DTYPE[data_type], data_type


 def round_ste(x: torch.Tensor):
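With `disable_opt_rtn=False`, `get_quant_func` now tries the `rtn_`-prefixed keys first (optionally suffixed with the bit width and `_sym`/`_asym`) and only then falls back to the regular dtype keys. A small sketch of the expected resolution for INT4 symmetric weights, assuming `rtn_int_sym` and `int_sym` are the registered kernels:

~~~python
from auto_round.data_type.utils import get_quant_func

# Optimized RTN enabled: the rtn_ variants are preferred.
func, key = get_quant_func("int", bits=4, sym=True, disable_opt_rtn=False)
print(key)  # expected to resolve to "rtn_int_sym"

# Optimized RTN disabled: fall back to the plain lookup order.
func, key = get_quant_func("int", bits=4, sym=True, disable_opt_rtn=True)
print(key)  # expected to resolve to "int_sym"
~~~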

auto_round/export/export_to_autoround/export.py

Lines changed: 2 additions & 2 deletions
@@ -230,7 +230,7 @@ def pack_layer(layer_name, model, backend, device=None):
         zp = int(zp.flatten()[0])

     qlayer.to("cpu")
-    ##force to float32 to be compatible with torch 2.0
+    # Force to float32 to be compatible with torch 2.0
     sig = inspect.signature(qlayer.pack)
     param_count = len(sig.parameters)
     if param_count == 2:

@@ -296,7 +296,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex

         return save_quantized_as_autoround(output_dir, inplace=inplace, backend="auto_round", **kwargs)

-    ##if using sym, we change to gptq sym kernel to avoid compiling from auto_round source
+    # If using sym, we change to gptq sym kernel to avoid compiling from auto_round source
    if (
        (kwargs.get("sym") is None or kwargs.get("sym"))
        and ("gptq" not in backend and "awq" not in backend)

auto_round/utils.py

Lines changed: 1 addition & 0 deletions
@@ -2929,6 +2929,7 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
        if name in all_module_names:
            m = get_module(model, name)
            if len(list(m.children())) == 0 and type(m) not in supported_types:
+                layer_config.pop(name)
                logger.warning(f"{name} is not supported in current scheme, ignoring its setting in `layer_config`")
                continue
auto_round/wrapper.py

Lines changed: 6 additions & 2 deletions
@@ -80,6 +80,7 @@ def __init__(
         device="cpu",
         enable_round_tuning=True,
         enable_torch_compile=False,
+        disable_opt_rtn=True,
         **kwargs,
     ):
         """Initializes the WrapperLinear module.

@@ -92,6 +93,7 @@
         """
         super(WrapperLinear, self).__init__()
         self.orig_layer = orig_layer
+        self.disable_opt_rtn = disable_opt_rtn
         self.output_device = device
         self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device
         self.enable_minmax_tuning = enable_minmax_tuning

@@ -146,13 +148,15 @@ def _init_tuning_params_and_quant_func(self):
         self._init_params("min_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16))
         self._init_params("max_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16))

-        self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym)
+        self.weight_quant_func, self.data_type = get_quant_func(
+            orig_layer.data_type, orig_layer.bits, orig_layer.sym, self.disable_opt_rtn
+        )
         if self.enable_torch_compile:
             self.weight_quant_func = compile_func(self.weight_quant_func, self.device)

         if self.enable_act_quant:
             self.act_quant_func, self.act_data_type = get_quant_func(
-                orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym
+                orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym, self.disable_opt_rtn
             )
             if self.enable_torch_compile:
                 self.act_quant_func = compile_func(self.act_quant_func, self.device)

docs/opt_rtn.md

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
### 🧮 Evaluation Results (LM-Eval)
For 2-bit and 3-bit quantization, we strongly recommend not using `--iters 0`, except for GGUF:Q2_K_S, which uses a different quantization algorithm.

4BIT = W4A16
3BIT = W3A16
2BIT = W2A16G64

RTN mode

~~~bash
auto-round --model xxx --disable_opt_rtn --iters 0
~~~

OPT RTN mode

~~~bash
auto-round --model xxx --iters 0
~~~


| Model | RTN/OPT | AVG | HellaSwag | LAMBADA | MMLU | PIQA | WinoGrande |
|--------------------------------|----------|---------|-----------|---------|--------|--------|------------|
| **Meta-Llama-3.1-8B-Instruct** | RTN-4BIT | 0.69328 | 0.5896 | 0.7013 | 0.6538 | 0.7987 | 0.7230 |
| | OPT-4BIT | 0.69560 | 0.5882 | 0.7074 | 0.6631 | 0.7916 | 0.7277 |
| | RTN-3BIT | 0.64562 | 0.5410 | 0.6695 | 0.5449 | 0.7742 | 0.6985 |
| | OPT-3BIT | 0.65970 | 0.5490 | 0.6893 | 0.5711 | 0.7677 | 0.7214 |
| | RTN-2BIT | 0.33008 | 0.2918 | 0.0474 | 0.2321 | 0.5740 | 0.5051 |
| | OPT-2BIT | 0.38908 | 0.3241 | 0.1560 | 0.2822 | 0.6235 | 0.5596 |
| **Qwen2.5-7B-Instruct** | RTN-4BIT | 0.69560 | 0.6114 | 0.6713 | 0.7011 | 0.7878 | 0.7064 |
| | OPT-4BIT | 0.70034 | 0.6143 | 0.6945 | 0.7115 | 0.7845 | 0.6969 |
| | RTN-3BIT | 0.64144 | 0.5585 | 0.6092 | 0.6455 | 0.7476 | 0.6464 |
| | OPT-3BIT | 0.66764 | 0.5756 | 0.7013 | 0.6597 | 0.7481 | 0.6535 |
| | RTN-2BIT | 0.31856 | 0.2804 | 0.0351 | 0.2379 | 0.5256 | 0.5138 |
| | OPT-2BIT | 0.45146 | 0.3645 | 0.2992 | 0.4043 | 0.6415 | 0.5478 |
| **Qwen3-8B** | RTN-4BIT | 0.66240 | 0.5619 | 0.6150 | 0.7077 | 0.7573 | 0.6701 |
| | OPT-4BIT | 0.66992 | 0.5619 | 0.6346 | 0.7102 | 0.7633 | 0.6796 |
| | RTN-3BIT | 0.57322 | 0.4992 | 0.4260 | 0.6002 | 0.7361 | 0.6046 |
| | OPT-3BIT | 0.63698 | 0.5226 | 0.5814 | 0.6718 | 0.7437 | 0.6654 |
| | RTN-2BIT | 0.31150 | 0.2679 | 0.0041 | 0.2536 | 0.5283 | 0.5036 |
| | OPT-2BIT | 0.44254 | 0.3749 | 0.2005 | 0.4202 | 0.6670 | 0.5501 |
| **Qwen3-14B** | RTN-4BIT | 0.70448 | 0.5999 | 0.6511 | 0.7565 | 0.7998 | 0.7151 |
| | OPT-4BIT | 0.70798 | 0.6031 | 0.6627 | 0.7534 | 0.8009 | 0.7198 |
| | RTN-3BIT | 0.65876 | 0.5746 | 0.5467 | 0.7065 | 0.7628 | 0.7032 |
| | OPT-3BIT | 0.68610 | 0.5683 | 0.6633 | 0.7258 | 0.7699 | 0.7032 |
| | RTN-2BIT | 0.39398 | 0.3764 | 0.0607 | 0.3836 | 0.6480 | 0.5012 |
| | OPT-2BIT | 0.50080 | 0.4554 | 0.2451 | 0.4899 | 0.7138 | 0.5998 |
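For reference, a rough Python-API counterpart of the two CLI commands above; the exact constructor arguments (model loading by name, `scheme`, `disable_opt_rtn`) are assumed to mirror the CLI flags and may differ between releases:

~~~python
from auto_round import AutoRound

# OPT RTN mode: calibration-free quantization with the optimized RTN path enabled.
ar = AutoRound("Qwen/Qwen3-8B", scheme="W4A16", iters=0)
ar.quantize_and_save("./Qwen3-8B-w4a16-opt-rtn")

# RTN mode: additionally disable the optimized RTN path.
ar = AutoRound("Qwen/Qwen3-8B", scheme="W4A16", iters=0, disable_opt_rtn=True)
ar.quantize_and_save("./Qwen3-8B-w4a16-rtn")
~~~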
