
Commit 5de9fbc

support disable_opt_rtn in auto-scheme (#913)
1 parent 46812de commit 5de9fbc

6 files changed: +126 -114 lines changed

Binary file changed (48.3 KB), not shown.

auto_round/auto_scheme/gen_auto_scheme.py

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,7 @@ def __init__(
             if self.auto_scheme.enable_torch_compile is None
             else self.auto_scheme.enable_torch_compile
         )
+        self.disable_opt_rtn = self.auto_scheme.disable_opt_rtn
         self._check_configs()

     def _check_configs(self) -> None:
@@ -89,6 +90,7 @@ def get_layer_config(self) -> dict[str, dict]:
             self.tokenizer,
             device_map=self.device_map,
             enable_torch_compile=self.enable_torch_compile,
+            disable_opt_rtn=self.disable_opt_rtn,
         )
         layer_config = self.fallback_gguf_layer_config(layer_config)
         return layer_config

auto_round/compressors/base.py

Lines changed: 6 additions & 8 deletions
@@ -463,8 +463,8 @@ def _gen_auto_scheme(
         # mainly using quant_layers and fixed by users
         from auto_round.auto_scheme.gen_auto_scheme import GenScheme

-        if self.enable_torch_compile is False:
-            logger.warning("we strongly recommend to enable torch compile for AutoScheme to save VRAM")
+        if not self.enable_torch_compile and self.super_bits is None:
+            logger.warning("we strongly recommend to set `enable_torch_compile` to True for AutoScheme to save VRAM")
         gen_scheme = GenScheme(
             scheme,
             self.model,
@@ -1275,14 +1275,12 @@ def get_imatrix_hook(module, input, output):

             if not hasattr(module, "imatrix"):
                 module.imatrix = squared
-                module.imatrix_cnt = input.shape[0]
             else:
                 module.imatrix += squared.to(module.imatrix.device)
-                module.imatrix_cnt += input.shape[0]

         hook_handles = []
         for name, module in model.named_modules():
-            if isinstance(module, self.supported_types) and check_to_quantized(module):
+            if type(module) in self.supported_types and check_to_quantized(module):
                 hook = module.register_forward_hook(get_imatrix_hook)
                 hook_handles.append(hook)
         return hook_handles
@@ -1452,7 +1450,9 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         for module in tqdm(modules, desc="Update weight global scale for fuse module"):
             update_fused_layer_global_scales(module)

-        has_gguf_k = any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", []))
+        has_gguf_k = (
+            any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None
+        )

         self._quantize_embedding_layer()

@@ -1595,8 +1595,6 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
                 set_amax_for_all_moe_layers(block, attr_name="act_max")
             # Normalize imatrix and quantize layers
             for _, m in block.named_modules():
-                if hasattr(m, "imatrix"):
-                    m.imatrix /= m.imatrix_cnt
                 if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names:
                     self._quantize_layer_via_rtn(m.tmp_name)
                     all_to_quantized_module_names.remove(m.tmp_name)
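For context, a minimal standalone sketch of the importance-matrix hook pattern touched by this hunk: after this change the hook keeps only a running sum of squared inputs (no imatrix_cnt and no later normalization). The function and variable names below are illustrative, not the library's API.

    import torch

    def attach_imatrix_hook(module: torch.nn.Linear) -> torch.utils.hooks.RemovableHandle:
        # Accumulate per-input-channel sums of squared activations on the module,
        # mirroring the simplified accumulation above (sum only, no sample count).
        def get_imatrix_hook(mod, inputs, output):
            x = inputs[0].detach()
            squared = (x.reshape(-1, x.shape[-1]) ** 2).sum(dim=0)
            if not hasattr(mod, "imatrix"):
                mod.imatrix = squared
            else:
                mod.imatrix += squared.to(mod.imatrix.device)

        return module.register_forward_hook(get_imatrix_hook)

    layer = torch.nn.Linear(8, 4)
    handle = attach_imatrix_hook(layer)
    layer(torch.randn(3, 8))  # layer.imatrix now holds an (8,)-shaped squared-sum
    handle.remove()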

auto_round/data_type/gguf.py

Lines changed: 105 additions & 100 deletions
@@ -16,6 +16,8 @@

 from auto_round.data_type.register import register_dtype
 from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
+from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES
+from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants
 from auto_round.logger import logger
 from auto_round.utils import get_reciprocal

@@ -283,48 +285,11 @@ def quant_tensor_asym_dq(
     return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin}


-@register_dtype("rtn_int_asym_dq")
-def quant_tensor_gguf_asym_dq(
-    tensor,
-    bits=4,
-    v=0,
-    min_scale=1.0,
-    max_scale=1.0,
-    scale_dtype=torch.float16,
-    tensor_min=None,
-    tensor_max=None,
-    q_scale_thresh=1e-5,
-    imatrix=None,
-    **kwargs,
-):
-    """Quantizes and dequantizes a tensor using asymmetric integer quantization for formats like Q2_K, Q4_K, and Q5_K.
-    Only fit for iters 0
-
-    Args:
-        tensor (torch.Tensor): Input tensor to quantize.
-        bits (int): Number of bits for quantization.
-        group_size (int): Group size for per-group quantization.
-        v (float): Perturbation added before rounding.
-        min_scale (float): Minimum allowed scale value.
-        max_scale (float): Maximum allowed scale value.
-        scale_dtype (torch.dtype): Data type for quantized scale.
-        tensor_min (torch.Tensor, optional): Minimum values for the tensor groups.
-        tensor_max (torch.Tensor, optional): Maximum values for the tensor groups.
-        q_scale_thresh (float): Threshold to clamp the quantized scale.
-        super_group_size (int): Number of groups to bundle for secondary quantization.
-        super_bits (int): Number of bits used in secondary quantization.
-        imatrix (torch.Tensor, optional): Importance matrix for weighted quantization.
-
-    Returns:
-        Tuple: (Quantized-dequantized tensor, scale dictionary, zero-point dictionary)
-    """
-    orig_dtype = tensor.dtype
-    maxq = 2**bits - 1
-    group_size = 16 if bits == 2 else 32
+@torch.no_grad()
+def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None):
     super_bits = 4 if bits == 2 else 6
     super_group_size = 16 if bits == 2 else 8
-    tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
-    tensor = tensor.to(torch.float32)
+    group_size = 16 if bits == 2 else 32
     if bits not in [2, 4, 5]:
         raise ValueError(f"bits={bits} not supported by rtn_int_asym_dq")
     quant_weights = None
@@ -430,8 +395,52 @@ def quant_tensor_gguf_asym_dq(
         d_wmin = d_wmin.unsqueeze(-1)
         scale = (d_scale * q_scale).view(-1, 1)
         wmin = (d_wmin * q_wmin).view(-1, 1)
-    inverse_scale = get_reciprocal(scale)
+    return scale, wmin, d_scale, d_wmin
+

+@register_dtype("rtn_int_asym_dq")
+def quant_tensor_gguf_asym_dq(
+    tensor: torch.Tensor,
+    bits: int = 4,
+    v=0,
+    scale_dtype=torch.float16,
+    imatrix=None,
+    scale=None,
+    wmin=None,
+    d_scale=None,
+    d_wmin=None,
+    **kwargs,
+):
+    """Quantizes and dequantizes a tensor using asymmetric integer quantization for formats like Q2_K, Q4_K, and Q5_K.
+    Only fit for iters 0
+
+    Args:
+        tensor (torch.Tensor): Input tensor to quantize.
+        bits (int): Number of bits for quantization.
+        group_size (int): Group size for per-group quantization.
+        v (float): Perturbation added before rounding.
+        min_scale (float): Minimum allowed scale value.
+        max_scale (float): Maximum allowed scale value.
+        scale_dtype (torch.dtype): Data type for quantized scale.
+        tensor_min (torch.Tensor, optional): Minimum values for the tensor groups.
+        tensor_max (torch.Tensor, optional): Maximum values for the tensor groups.
+        q_scale_thresh (float): Threshold to clamp the quantized scale.
+        super_group_size (int): Number of groups to bundle for secondary quantization.
+        super_bits (int): Number of bits used in secondary quantization.
+        imatrix (torch.Tensor, optional): Importance matrix for weighted quantization.
+
+    Returns:
+        Tuple: (Quantized-dequantized tensor, scale dictionary, zero-point dictionary)
+    """
+    orig_dtype = tensor.dtype
+    maxq = 2**bits - 1
+    group_size = 16 if bits == 2 else 32
+    tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
+    tensor = tensor.to(torch.float32)
+    if scale is None:
+        scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym(tensor, bits, scale_dtype, imatrix)
+
+    inverse_scale = get_reciprocal(scale)
     int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq)
     qdq_result = (scale * int_w - wmin).to(orig_dtype)
     qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
@@ -506,18 +515,58 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
     return scale.to(torch.float32), -rmin.to(torch.float32)


+@torch.no_grad()
+def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
+
+    from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K
+
+    group_size = 16
+
+    if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0):
+        if bits == 3:
+            scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True)
+            ##scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
+        elif bits == 6:
+            scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
+    else:
+        imatrix = imatrix.to(tensor.device)
+        weights = imatrix.reshape(1, -1)
+        weights = weights.expand(tensor.numel() // weights.numel(), -1)
+        quant_weights = weights.reshape(tensor.shape)
+        if torch.min(quant_weights) == 0:
+            logger.warning_once(
+                "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
+            )
+        zero_cnt = torch.sum(quant_weights == 0, dim=-1)
+        replace_index = zero_cnt > group_size // 2
+        if torch.sum(replace_index) > 0:
+            if bits == 6:
+                quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index]
+            else:
+                sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
+                tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor)
+                quant_weights[replace_index] = tmp_quant_weights[replace_index]
+        mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
+        if torch.sum(mean_replace_index) > 0:
+            ## use mean values to fill zero values
+            tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt)
+            tmp_quant_weights = (
+                tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape)
+            )
+            quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index]

+        scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
+    return scale
+
+
+#
 @register_dtype("rtn_int_sym_dq")
 def quant_tensor_gguf_sym_dq(
     tensor,
     bits=3,
-    v=0,
-    min_scale=1.0,
-    max_scale=1.0,
-    scale_dtype=torch.float16,
-    tensor_min=None,
-    tensor_max=None,
-    q_scale_thresh=1e-5,
     imatrix=None,
+    scale=None,
+    d_scale=None,
+    scale_dtype=torch.float16,
     **kwargs,
 ):
     """Quantize and de-quantize tensor asymmetrically. For Q3_K, Q6_K.
@@ -537,80 +586,36 @@ def quant_tensor_gguf_sym_dq(
     Returns:
         Quantized and de-quantized tensor, scale, zero-point
     """
-    from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, K_SCALE_SIZE, QK_K
-    from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants
+
+    from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K

     if bits not in [3, 6]:
         raise KeyError(f"bits={bits} is not supported by gguf_int_sym_dq, please check.")

     maxq = 2 ** (bits - 1)
     group_size = 16
+    tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
+    orig_dtype = tensor.dtype
     super_bits = 6 if bits == 3 else 8
     super_group_size = 16
-
-    tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
     ggml_type = f"q{bits}_k"
     block_size, type_size = GGML_QUANT_SIZES[ggml_type]
-    orig_dtype = tensor.dtype
-
     tensor = tensor.to(torch.float32)
     n_blocks = tensor.nelement() // block_size
     # (nb, 16, 16)
     tensor = tensor.reshape(n_blocks, super_group_size, QK_K // super_group_size)
+    if scale is None and d_scale is None:
+        scale = search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype)

-    if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0):
-        if bits == 3:
-            scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True)
-            ##scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
-        elif bits == 6:
-            scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
-    else:
-        imatrix = imatrix.to(tensor.device)
-
-        # if bits == 3:
-        #     # sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
-        #     # imatrix = imatrix.reshape(1, -1).expand(tensor.numel() // imatrix.numel(), -1).reshape(tensor.shape)
-        #     # quant_weights = imatrix * torch.sqrt(sigma2 + tensor * tensor)
-        #     # scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
-        #     weights = imatrix.reshape(1, -1)
-        #     weights = weights.expand(tensor.numel() // weights.numel(), -1)
-        #     quant_weights = weights.reshape(tensor.shape)
-        # elif bits == 6:
-
-        weights = imatrix.reshape(1, -1)
-        weights = weights.expand(tensor.numel() // weights.numel(), -1)
-        quant_weights = weights.reshape(tensor.shape)
-        if torch.min(quant_weights) == 0:
-            logger.warning_once(
-                "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
-            )
-        zero_cnt = torch.sum(quant_weights == 0, dim=-1)
-        replace_index = zero_cnt > group_size // 2
-        if torch.sum(replace_index) > 0:
-            if bits == 6:
-                quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index]
-            else:
-                sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
-                tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor)
-                quant_weights[replace_index] = tmp_quant_weights[replace_index]
-        mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
-        if torch.sum(mean_replace_index) > 0:
-            ## use mean values to fill zero values
-            tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt)
-            tmp_quant_weights = (
-                tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape)
-            )
-            quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index]
-
-        scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
+    scale = scale.to(scale_dtype)
     scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale)
     # conduct double quant
     scale, d_scale = double_quant_tensor_sym(scale, super_bits)

     scale = scale.unsqueeze(-1)
     zp = torch.full_like(scale, maxq)  # pylint: disable=E1130
     inverse_scale = get_reciprocal(scale)
-    int_w = torch.round(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq
+    int_w = round_ste(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq
     qdq_result = (scale * (int_w - zp)).to(orig_dtype)
     qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
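A rough usage sketch of the refactor above: the scale/min search is now a separate helper, so its result can be computed once and passed back into the quantizer on later calls. This assumes the signatures shown in this diff; the helper expects the same group-reshaped float32 view that quant_tensor_gguf_asym_dq builds internally, and the weight tensor below is illustrative only.

    import torch

    from auto_round.data_type.gguf import quant_tensor_gguf_asym_dq, search_gguf_scale_min_asym
    from auto_round.data_type.utils import reshape_pad_tensor_by_group_size

    weight = torch.randn(256, 256)

    # Run the (comparatively expensive) scale/min search once on the grouped view
    # (group_size is 32 for 4-bit, matching the internal call site).
    grouped, _, _ = reshape_pad_tensor_by_group_size(weight, 32)
    scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym(grouped.to(torch.float32), bits=4)

    # Reuse the precomputed statistics so the search is skipped inside the quantizer.
    qdq, scale_dict, zp_dict = quant_tensor_gguf_asym_dq(
        weight, bits=4, scale=scale, wmin=wmin, d_scale=d_scale, d_wmin=d_wmin
    )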

auto_round/schemes.py

Lines changed: 1 addition & 0 deletions
@@ -299,6 +299,7 @@ class AutoScheme:
     dataset: Optional[str] = None  # Import Notice no comma for each item
     device_map: Optional[Union[str, torch.device, int, dict]] = None
     enable_torch_compile: Optional[bool] = None
+    disable_opt_rtn: bool = True

     def __post_init__(self):
         if isinstance(self.options, str):
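A hedged usage sketch of the new dataclass field. Only fields visible in this diff are shown; any other arguments AutoScheme may require (for example a target average bit-width) are omitted here, and the option names are illustrative.

    from auto_round.schemes import AutoScheme

    # disable_opt_rtn defaults to True after this commit; set it to False if the
    # auto-scheme search should use the optimized-RTN path instead of plain RTN.
    auto_scheme = AutoScheme(
        options=["W4A16", "W2A16"],   # illustrative candidate schemes
        enable_torch_compile=True,    # recommended above to reduce VRAM
        disable_opt_rtn=False,
        # ...plus any other required AutoScheme fields not shown in this diff
    )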

docs/step_by_step.md

Lines changed: 12 additions & 6 deletions
@@ -351,15 +351,21 @@ ar.quantize_and_save()
 The tuning cost of AutoScheme is approximately 2 to 4 times that of model's bf16 size, depending on the options.
 We tested it on Nvidia A100 80G using torch v2.8.

-| Models | Scheme | VRAM Cost <br /> (torch compile) | Time Cost <br /> (torch compile) | VRAM Cost <br /> (w/o torch compile) | Time Cost <br /> (w/o torch compile) |
-| -------- | ----------------- | ---------------------------- | ----------------------------- | --------------------------------- | --------------------------------- |
-| Qwen3-8B | W2A16 / W4A16 / W8A16 | 34G | 30s × len of options | 61G | 40s × len of options |
-| Qwen3-8B | MXFP4 / MXFP8 | 36G | 60s × len of options | 54G | 120s × len of options |
-| Qwen3-8B | GGUF* | 54G | 30s × len of options | 50G | 23s × len of options |
+We will try to optimize the VRAM usage in the future.
+
+| Models | Scheme | VRAM Cost <br />(torch compile) | Time Cost<br /> torch compile | VRAM Cost <br />wo torch compile | Time Cost<br /> wo torch compile |
+| --------- | ----------------- | ------------------------------- | ----------------------------- | -------------------------------- | -------------------------------- |
+| Qwen3-8B | W2A16/W4A16/W8A16 | 34G | 30s * len of options | 61G | 40s * len of options |
+| Qwen3-8B | MXFP4/MXFP8 | 36G | 60s * len of options | 54G | 120s * len of options |
+| Qwen3-8B | GGUF* | 54G | 30s * len of options | 50G | 23S * len of options |
+| Qwen3-32B | W2A16/W4A16/W8A16 | OOM with 240G | --- | OOM with 240G | --- |
+| Qwen3-32B | MXFP4/MXFP8 | 160G | 200s * len of options | 200G | 240s * len of options |
+| Qwen3-32B | GGUF* | 210G | 80s * len of options | 200G | 60s * len of options |
+


 #### Limitations
-Embedding layer is supported in AutoScheme, it will use the best scheme in options.
+Embedding layer is not supported in AutoScheme, it will use the best scheme in options.


 ### RTN mode

0 commit comments
