
Commit 1d91207

fix torch compile issue in AutoScheme (#909)

1 parent fe62213 · commit 1d91207

File tree

6 files changed: +56 -17 lines changed


README.md

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ and [fbaldassarri](https://huggingface.co/fbaldassarri). For usage instructions,
 
 
 ## 🆕 What's New
-[2025/10] AutoRound team proposed a fast algorithm to generate mixed bits/datatypes schemes in minutes. Please
+[2025/10] We proposed a fast algorithm to generate mixed bits/datatypes schemes in minutes. Please
 refer to the documentation for accuracy [results](./docs/auto_scheme_acc.md) and [this guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme) for usage instructions.
 
 [2025/09] AutoRound now includes experimental support for the mxfp4 and nvfp4 dtypes. For accuracy results, see the [documentation](./docs/mxnv_acc.md)

@@ -68,7 +68,7 @@ Support **AutoRound, AutoAWQ, AutoGPTQ, and GGUF** for maximum compatibility. De
 Quantize 7B models in about 10 minutes on a single GPU. Details are shown in [quantization costs](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#quantization-costs)
 
 **Fast mixed bits/data-types scheme generation**
-Automatically configure in minutes, with about 2X-4X the model’s BF16 VRAM size as overhead.
+Automatically configure in minutes, with about 2X-4X the model’s BF16 VRAM size as overhead. Accuracy [results](./docs/auto_scheme_acc.md) and [user guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme).
 
 **10+ VLMs Support**
 Out-of-the-box quantization for 10+ vision-language models [example models](https://huggingface.co/collections/OPEA/vlms-autoround-675bc712fdd6a55ebaf11bfa), [support matrix](https://github.com/intel/auto-round/tree/main/auto_round/mllm#support-matrix)
(binary file changed, 0 Bytes; content not shown)

auto_round/auto_scheme/gen_auto_scheme.py

Lines changed: 7 additions & 3 deletions
@@ -30,11 +30,11 @@ class GenScheme:
 
     def __init__(
         self,
-        auto_scheme: AutoScheme,  # TODO support shared layer
+        auto_scheme: AutoScheme,
         model: torch.nn.Module,
         quant_layer_names: Iterable[str],
         fixed_layer_scheme: dict[str, dict],
-        dataset: str = "pile-10k",  # TODO use auto-round dataset
+        dataset: str = "pile-10k",
         device_map: Union[str, torch.device, int, dict, None] = None,
         tokenizer=None,
         enable_torch_compile=False,

@@ -46,7 +46,11 @@ def __init__(
         self.fixed_layer_scheme = fixed_layer_scheme
         self.dataset = dataset
         self.device_map = device_map if self.auto_scheme.device_map is None else self.auto_scheme.device_map
-        self.enable_torch_compile = enable_torch_compile
+        self.enable_torch_compile = (
+            enable_torch_compile
+            if self.auto_scheme.enable_torch_compile is None
+            else self.auto_scheme.enable_torch_compile
+        )
         self._check_configs()
 
     def _check_configs(self) -> None:
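The change above makes the AutoScheme-level flag take precedence over the value handed to `GenScheme`. Below is a minimal, standalone sketch of that resolution order; the `resolve_enable_torch_compile` helper is illustrative and not part of the repository.

```python
from typing import Optional


def resolve_enable_torch_compile(scheme_value: Optional[bool], compressor_value: bool) -> bool:
    """Mirror the new GenScheme logic: use the AutoScheme field when it is set,
    otherwise fall back to the value passed in by the compressor."""
    return compressor_value if scheme_value is None else scheme_value


# AutoScheme.enable_torch_compile left at None -> the compressor setting is used.
assert resolve_enable_torch_compile(None, True) is True
# An explicit AutoScheme.enable_torch_compile overrides the compressor setting.
assert resolve_enable_torch_compile(False, True) is False
assert resolve_enable_torch_compile(True, False) is True
```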

auto_round/compressors/base.py

Lines changed: 5 additions & 3 deletions
@@ -463,6 +463,8 @@ def _gen_auto_scheme(
         # mainly using quant_layers and fixed by users
         from auto_round.auto_scheme.gen_auto_scheme import GenScheme
 
+        if self.enable_torch_compile is False:
+            logger.warning("we strongly recommend to enable torch compile for AutoScheme to save VRAM")
         gen_scheme = GenScheme(
             scheme,
             self.model,

@@ -583,9 +585,9 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled")
 
-        if is_debug_mode() and self.enable_torch_compile:
-            self.enable_torch_compile = False
-            logger.warning("reset enable_torch_compile to `False` as debug mode is enabled")
+        # if is_debug_mode() and self.enable_torch_compile:
+        #     self.enable_torch_compile = False
+        #     logger.warning("reset enable_torch_compile to `False` as debug mode is enabled")
 
         if (self.data_type.startswith("fp") or self.act_data_type.startswith("fp")) and self.enable_torch_compile:
             self.enable_torch_compile = False

auto_round/schemes.py

Lines changed: 1 addition & 0 deletions
@@ -298,6 +298,7 @@ class AutoScheme:
     seqlen: Optional[int] = None
     dataset: Optional[str] = None  # Important notice: no comma for each item
     device_map: Optional[Union[str, torch.device, int, dict]] = None
+    enable_torch_compile: Optional[bool] = None
 
     def __post_init__(self):
         if isinstance(self.options, str):
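With the new dataclass field in place, torch compile can be requested on the scheme itself rather than only on the compressor. A small hedged sketch follows; the `avg_bits` and `options` values are taken from the usage examples in the docs below and are only placeholders.

```python
from auto_round import AutoScheme

# The new field defaults to None, which means "defer to the compressor-level
# enable_torch_compile setting" (see the GenScheme change above).
scheme = AutoScheme(avg_bits=3.0, options=("W2A16", "W4A16", "W8A16"))
assert scheme.enable_torch_compile is None

# Setting it explicitly overrides the compressor-level value during scheme generation.
compiled_scheme = AutoScheme(avg_bits=3.0, options=("W2A16", "W4A16", "W8A16"), enable_torch_compile=True)
```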

docs/step_by_step.md

Lines changed: 41 additions & 9 deletions
@@ -119,7 +119,7 @@ AutoRound supports several Schemes:
 - **W8A16**(bits:8,group_size:128,sym:True,act_bits:16)
 - **W3A16**(bits:3,group_size:128,sym:True,act_bits:16)
 - **W2A16**(bits:2,group_size:128,sym:True,act_bits:16)
-- **Mixed bits Weight only**
+- **Mixed Bits Weight only**
 - **NVFP4**(Experimental feature, recommend exporting to llm-compressor format. data_type:nvfp4,act_data_type:nvfp4,static_global_scale,group_size 16)
 - **MXFP4**(**Research feature, no real kernel**, data_type:mxfp4,act_data_type:mxfp4,rceil,group_size 32)
 - **FPW8A16**(**Research feature, no real kernel**, data_type:fp8,act_data_type:16,group_size 0 -> per tensor)

@@ -160,15 +160,15 @@ CPU, Intel GPU, HPU and CUDA for both quantization and inference.
 auto-round --model facebook/opt-125m --scheme "W4A16" --format "auto_gptq,auto_awq,auto_round"
 ```
 
-- **Best Settings:**
+- **AutoRoundBest recipe:**
 
 This setting provides the best accuracy in most scenarios but is 4-5X slower than the standard AutoRound recipe. It is especially recommended for 2-bit quantization and is a good choice if sufficient resources are available.
 
 ```bash
 auto-round-best --model facebook/opt-125m --scheme "W4A16" --format "auto_gptq,auto_awq,auto_round"
 ```
 
-- **Light Settings:**
+- **AutoRoundLight Settings:**
 
 This setting offers the best speed (2-3X faster than AutoRound), but it may cause a significant accuracy drop for small models and 2-bit quantization. It is recommended for 4-bit settings and models larger than 3B.
 

@@ -195,7 +195,9 @@ output_dir = "./tmp_autoround"
 ar.quantize_and_save(output_dir, format="auto_gptq,auto_awq,auto_round")
 ```
 
-#### Mixed bits Usage
+#### Mixed Bits Usage
+AutoRound (>0.8) offers AutoScheme to generate a mixed-bits recipe automatically; please refer to the [AutoScheme](#autoscheme) section for more details.
+
 Auto-GPTQ and Auto-AWQ only support a limited set of mixed-bit configurations. If you're unsure about the details, we recommend using the AutoRound format.
 
 vLLM and SGLang fuse MoE and QKV layers, so it's recommended not to assign different bit widths to these layers.

@@ -279,8 +281,11 @@ W2G64 Average Accuracy of 13 tasks and Time Cost Results(Testing was conducted o
 
 AutoScheme provides an automatic algorithm to generate mixed bits/data_type quantization recipes. For accuracy results, please refer to [this doc](./auto_scheme_acc.md).
 
+We strongly recommend setting `enable_torch_compile` to True to save VRAM.
+
 **Please note that mixed data types are supported during tuning, but cannot be exported to real models at this time.**
-### CLI Usage
+
+#### CLI Usage
 Use `iters=200` for tuning.
 ~~~bash
 auto_round \

@@ -292,25 +297,25 @@ auto_round \
     --format fake
 ~~~
 
-### API Usage
+#### API Usage
 ~~~
 avg_bits = 3.0
 scheme = AutoScheme(avg_bits=avg_bits, options=("W2A16G64", "W4A16", "W8A16"))
 ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1)
 ar.quantize_and_save()
 ~~~
 
-### Hyperparameters in AutoScheme
+#### Hyperparameters in AutoScheme
 `avg_bits(float)`: Target average bits for the whole model; only layers to be quantized are counted in the average-bits calculation.
 
 `options(Union[str, list[Union[QuantizationScheme, str]]])`: the quantization schemes to choose from. It can be a string like "W4A16", or a list of strings or QuantizationScheme objects.
 
 `ignore_scale_zp_bits(bool)`: Whether to ignore the bits of scale and zero point in the average-bits calculation. Default is False.
 
-`shared_layers (Optional[Iterable[Iterable[str]]])`: only supported in the API for now.
-
 `device_map (Optional[str,dict,torch.device])`: only supported in the API for now. AutoScheme uses more VRAM than AutoRound tuning, so you can set a different device_map for it.
 
+`shared_layers (Optional[Iterable[Iterable[str]]])`: only supported in the API for now.
+
 In some serving frameworks, certain layers (e.g., QKV or MoE) are fused to accelerate inference. These fused layers may require the same data type and bit configuration. The shared_layers option simplifies this setup by supporting both regex and full-name matching. **Note that regex matching is applied in a block-wise manner.**
 

@@ -329,6 +334,33 @@ ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1)
 model, layer_config = ar.quantize()
 ```
 
+Besides, if you want to fix the scheme for some layers, you can set it via `layer_config` in the AutoRound API.
+```python
+from auto_round import AutoRound, AutoScheme
+
+model_name = "Qwen/Qwen3-8B"
+avg_bits = 3.0
+scheme = AutoScheme(avg_bits=avg_bits, options=("GGUF:Q2_K_S", "GGUF:Q4_K_S"), ignore_scale_zp_bits=True)
+layer_config = {"lm_head": "GGUF:Q6_K"}
+
+ar = AutoRound(model=model_name, scheme=scheme, layer_config=layer_config, iters=0)
+ar.quantize_and_save()
+```
+
+#### AutoScheme Cost
+The tuning cost of AutoScheme is approximately 2 to 4 times the model's BF16 size, depending on the options.
+We tested it on an Nvidia A100 80G using torch v2.8.
+
+| Models   | Scheme                | VRAM Cost <br /> (torch compile) | Time Cost <br /> (torch compile) | VRAM Cost <br /> (w/o torch compile) | Time Cost <br /> (w/o torch compile) |
+| -------- | --------------------- | -------------------------------- | -------------------------------- | ------------------------------------ | ------------------------------------ |
+| Qwen3-8B | W2A16 / W4A16 / W8A16 | 34G                              | 30s × len of options             | 61G                                  | 40s × len of options                 |
+| Qwen3-8B | MXFP4 / MXFP8         | 36G                              | 60s × len of options             | 54G                                  | 120s × len of options                |
+| Qwen3-8B | GGUF*                 | 54G                              | 30s × len of options             | 50G                                  | 23s × len of options                 |
+
+#### Limitations
+The embedding layer is supported in AutoScheme; it will use the best scheme among the options.
+
 
 ### RTN mode
 AutoRound also supports RTN (Round-To-Nearest) mode for fast, calibration-free baseline quantization. Try setting `iters=0` and use `group_size=32` for better results.
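Tying the documentation above back to the code changes in this commit, the sketch below combines the two VRAM-related knobs the guide mentions: the new `enable_torch_compile` field and a separate `device_map` for scheme generation. It is a hedged illustration, not repository code; the model name, device index, and options are placeholders.

```python
from auto_round import AutoRound, AutoScheme

scheme = AutoScheme(
    avg_bits=3.0,
    options=("W2A16", "W4A16", "W8A16"),
    enable_torch_compile=True,  # recommended in the guide to save VRAM during scheme search
    device_map=0,               # illustrative: AutoScheme may need more VRAM than tuning,
                                # so a dedicated device (or a dict) can be given here
)

ar = AutoRound(model="Qwen/Qwen3-8B", scheme=scheme, iters=0)
model, layer_config = ar.quantize()  # layer_config holds the generated per-layer scheme
```

Per the cost table above, enabling torch compile roughly halves AutoScheme's VRAM for the weight-only options on Qwen3-8B (34G vs. 61G).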
