Binary file modified auto_round/auto_scheme/default_alg.abi3.so
Binary file not shown.
2 changes: 2 additions & 0 deletions auto_round/auto_scheme/gen_auto_scheme.py
@@ -51,6 +51,7 @@ def __init__(
if self.auto_scheme.enable_torch_compile is None
else self.auto_scheme.enable_torch_compile
)
self.disable_opt_rtn = self.auto_scheme.disable_opt_rtn
self._check_configs()

def _check_configs(self) -> None:
@@ -89,6 +90,7 @@ def get_layer_config(self) -> dict[str, dict]:
self.tokenizer,
device_map=self.device_map,
enable_torch_compile=self.enable_torch_compile,
disable_opt_rtn=self.disable_opt_rtn,
)
layer_config = self.fallback_gguf_layer_config(layer_config)
return layer_config
14 changes: 6 additions & 8 deletions auto_round/compressors/base.py
@@ -463,8 +463,8 @@ def _gen_auto_scheme(
# mainly using quant_layers and fixed by users
from auto_round.auto_scheme.gen_auto_scheme import GenScheme

if self.enable_torch_compile is False:
logger.warning("we strongly recommend to enable torch compile for AutoScheme to save VRAM")
if not self.enable_torch_compile and self.super_bits is None:
logger.warning("we strongly recommend to set `enable_torch_compile` to True for AutoScheme to save VRAM")
gen_scheme = GenScheme(
scheme,
self.model,
@@ -1275,14 +1275,12 @@ def get_imatrix_hook(module, input, output):

if not hasattr(module, "imatrix"):
module.imatrix = squared
module.imatrix_cnt = input.shape[0]
else:
module.imatrix += squared.to(module.imatrix.device)
module.imatrix_cnt += input.shape[0]

hook_handles = []
for name, module in model.named_modules():
if isinstance(module, self.supported_types) and check_to_quantized(module):
if type(module) in self.supported_types and check_to_quantized(module):
hook = module.register_forward_hook(get_imatrix_hook)
hook_handles.append(hook)
return hook_handles
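
For readers skimming the hunk, the accumulation pattern after this change can be sketched as below. The helper name, the `supported_types` default, and the exact definition of `squared` are assumptions (they sit outside the shown lines); what is taken from the diff is the exact-type check and the fact that `imatrix` is now an unnormalized running sum, since `imatrix_cnt` and the later normalization were removed.

```python
import torch

def attach_imatrix_hooks(model, supported_types=(torch.nn.Linear,)):
    # Illustrative helper, not part of this PR; it mirrors the registration loop
    # and hook behavior in the hunk above.
    handles = []

    def get_imatrix_hook(module, inputs, output):
        x = inputs[0].detach()
        # Assumed definition: column-wise sum of squared input features.
        squared = x.pow(2).reshape(-1, x.shape[-1]).sum(dim=0)
        if not hasattr(module, "imatrix"):
            module.imatrix = squared
        else:
            module.imatrix += squared.to(module.imatrix.device)

    for _, module in model.named_modules():
        # Exact type check, as in the diff, so wrapped subclasses are not hooked by accident.
        if type(module) in supported_types:
            handles.append(module.register_forward_hook(get_imatrix_hook))
    return handles
```

Handles returned this way can later be detached with `handle.remove()`.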
@@ -1452,7 +1450,9 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
for module in tqdm(modules, desc="Update weight global scale for fuse module"):
update_fused_layer_global_scales(module)

has_gguf_k = any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", []))
has_gguf_k = (
any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None
)

self._quantize_embedding_layer()

@@ -1595,8 +1595,6 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
set_amax_for_all_moe_layers(block, attr_name="act_max")
# Normalize imatrix and quantize layers
for _, m in block.named_modules():
if hasattr(m, "imatrix"):
m.imatrix /= m.imatrix_cnt
if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names:
self._quantize_layer_via_rtn(m.tmp_name)
all_to_quantized_module_names.remove(m.tmp_name)
205 changes: 105 additions & 100 deletions auto_round/data_type/gguf.py
@@ -16,6 +16,8 @@

from auto_round.data_type.register import register_dtype
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES
from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants
from auto_round.logger import logger
from auto_round.utils import get_reciprocal

@@ -283,48 +285,11 @@ def quant_tensor_asym_dq(
return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin}


@register_dtype("rtn_int_asym_dq")
def quant_tensor_gguf_asym_dq(
tensor,
bits=4,
v=0,
min_scale=1.0,
max_scale=1.0,
scale_dtype=torch.float16,
tensor_min=None,
tensor_max=None,
q_scale_thresh=1e-5,
imatrix=None,
**kwargs,
):
"""Quantizes and dequantizes a tensor using asymmetric integer quantization for formats like Q2_K, Q4_K, and Q5_K.
Only fit for iters 0

Args:
tensor (torch.Tensor): Input tensor to quantize.
bits (int): Number of bits for quantization.
group_size (int): Group size for per-group quantization.
v (float): Perturbation added before rounding.
min_scale (float): Minimum allowed scale value.
max_scale (float): Maximum allowed scale value.
scale_dtype (torch.dtype): Data type for quantized scale.
tensor_min (torch.Tensor, optional): Minimum values for the tensor groups.
tensor_max (torch.Tensor, optional): Maximum values for the tensor groups.
q_scale_thresh (float): Threshold to clamp the quantized scale.
super_group_size (int): Number of groups to bundle for secondary quantization.
super_bits (int): Number of bits used in secondary quantization.
imatrix (torch.Tensor, optional): Importance matrix for weighted quantization.

Returns:
Tuple: (Quantized-dequantized tensor, scale dictionary, zero-point dictionary)
"""
orig_dtype = tensor.dtype
maxq = 2**bits - 1
group_size = 16 if bits == 2 else 32
@torch.no_grad()
def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None):
super_bits = 4 if bits == 2 else 6
super_group_size = 16 if bits == 2 else 8
tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
tensor = tensor.to(torch.float32)
group_size = 16 if bits == 2 else 32
if bits not in [2, 4, 5]:
raise ValueError(f"bits={bits} not supported by rtn_int_asym_dq")
quant_weights = None
@@ -430,8 +395,52 @@ def quant_tensor_gguf_asym_dq(
d_wmin = d_wmin.unsqueeze(-1)
scale = (d_scale * q_scale).view(-1, 1)
wmin = (d_wmin * q_wmin).view(-1, 1)
inverse_scale = get_reciprocal(scale)
return scale, wmin, d_scale, d_wmin
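
For reference, a toy sketch of the double-quant step visible just above, where each per-group scale is itself quantized so only an integer code `q_scale` and one `d_scale` per super-group are stored, and the effective scale is reconstructed as `d_scale * q_scale` (values below are made up):

```python
import torch

super_bits = 6
scales = torch.tensor([[0.011, 0.014, 0.009, 0.013]])        # per-group scales (toy)
d_scale = scales.max(dim=-1, keepdim=True).values / (2**super_bits - 1)
q_scale = torch.round(scales / d_scale)                       # integer codes
restored = d_scale * q_scale                                  # what gets used downstream
print(q_scale, restored)
```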


@register_dtype("rtn_int_asym_dq")
def quant_tensor_gguf_asym_dq(
tensor: torch.Tensor,
bits: int = 4,
v=0,
scale_dtype=torch.float16,
imatrix=None,
scale=None,
wmin=None,
d_scale=None,
d_wmin=None,
**kwargs,
):
"""Quantizes and dequantizes a tensor using asymmetric integer quantization for formats like Q2_K, Q4_K, and Q5_K.
Only fit for iters 0

Args:
tensor (torch.Tensor): Input tensor to quantize.
bits (int): Number of bits for quantization.
v (float): Perturbation added before rounding.
scale_dtype (torch.dtype): Data type for the quantized scale.
imatrix (torch.Tensor, optional): Importance matrix for weighted quantization.
scale, wmin, d_scale, d_wmin (torch.Tensor, optional): Precomputed quantization
parameters; when None, they are searched via `search_gguf_scale_min_asym`.

Returns:
Tuple: (Quantized-dequantized tensor, scale dictionary, zero-point dictionary)
"""
orig_dtype = tensor.dtype
maxq = 2**bits - 1
group_size = 16 if bits == 2 else 32
tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
tensor = tensor.to(torch.float32)
if scale is None:
scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym(tensor, bits, scale_dtype, imatrix)

inverse_scale = get_reciprocal(scale)
int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq)
qdq_result = (scale * int_w - wmin).to(orig_dtype)
qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
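
A standalone numeric sketch of the asymmetric quant/de-quant arithmetic above (group and super-group handling omitted, toy values):

```python
import torch

bits = 4
maxq = 2**bits - 1
w = torch.tensor([[-0.9, -0.1, 0.3, 1.2]])

# wmin is stored as a positive offset; scale maps [-wmin, max] onto [0, maxq].
wmin = (-w.min(dim=-1, keepdim=True).values).clamp(min=0)
scale = (w.max(dim=-1, keepdim=True).values + wmin) / maxq

int_w = torch.clamp(torch.round((w + wmin) / scale), 0, maxq)
qdq = scale * int_w - wmin
print(int_w)   # tensor([[ 0.,  6.,  9., 15.]])
print(qdq)     # ≈ [[-0.90, -0.06,  0.36,  1.20]]
```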
@@ -506,18 +515,58 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
return scale.to(torch.float32), -rmin.to(torch.float32)


@torch.no_grad()
def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K

group_size = 16

if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0):
if bits == 3:
scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True)
##scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
elif bits == 6:
scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
else:
imatrix = imatrix.to(tensor.device)
weights = imatrix.reshape(1, -1)
weights = weights.expand(tensor.numel() // weights.numel(), -1)
quant_weights = weights.reshape(tensor.shape)
if torch.min(quant_weights) == 0:
logger.warning_once(
"please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
)
zero_cnt = torch.sum(quant_weights == 0, dim=-1)
replace_index = zero_cnt > group_size // 2
if torch.sum(replace_index) > 0:
if bits == 6:
quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index]
else:
sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor)
quant_weights[replace_index] = tmp_quant_weights[replace_index]
mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
if torch.sum(mean_replace_index) > 0:
## use mean values to fill zero values
tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt)
tmp_quant_weights = (
tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape)
)
quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index]

scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
return scale


#
@register_dtype("rtn_int_sym_dq")
def quant_tensor_gguf_sym_dq(
tensor,
bits=3,
v=0,
min_scale=1.0,
max_scale=1.0,
scale_dtype=torch.float16,
tensor_min=None,
tensor_max=None,
q_scale_thresh=1e-5,
imatrix=None,
scale=None,
d_scale=None,
scale_dtype=torch.float16,
**kwargs,
):
"""Quantize and de-quantize tensor asymmetrically. For Q3_K, Q6_K.
@@ -537,80 +586,36 @@ def quant_tensor_gguf_sym_dq(
Returns:
Quantized and de-quantized tensor, scale, zero-point
"""
from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, K_SCALE_SIZE, QK_K
from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants

from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K

if bits not in [3, 6]:
raise KeyError(f"bits={bits} is not supported by gguf_int_sym_dq, please check.")

maxq = 2 ** (bits - 1)
group_size = 16
tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
orig_dtype = tensor.dtype
super_bits = 6 if bits == 3 else 8
super_group_size = 16

tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
ggml_type = f"q{bits}_k"
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
orig_dtype = tensor.dtype

tensor = tensor.to(torch.float32)
n_blocks = tensor.nelement() // block_size
# (nb, 16, 16)
tensor = tensor.reshape(n_blocks, super_group_size, QK_K // super_group_size)
if scale is None and d_scale is None:
scale = search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype)

if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0):
if bits == 3:
scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True)
##scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
elif bits == 6:
scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
else:
imatrix = imatrix.to(tensor.device)

# if bits == 3:
# # sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
# # imatrix = imatrix.reshape(1, -1).expand(tensor.numel() // imatrix.numel(), -1).reshape(tensor.shape)
# # quant_weights = imatrix * torch.sqrt(sigma2 + tensor * tensor)
# # scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
# weights = imatrix.reshape(1, -1)
# weights = weights.expand(tensor.numel() // weights.numel(), -1)
# quant_weights = weights.reshape(tensor.shape)
# elif bits == 6:

weights = imatrix.reshape(1, -1)
weights = weights.expand(tensor.numel() // weights.numel(), -1)
quant_weights = weights.reshape(tensor.shape)
if torch.min(quant_weights) == 0:
logger.warning_once(
"please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
)
zero_cnt = torch.sum(quant_weights == 0, dim=-1)
replace_index = zero_cnt > group_size // 2
if torch.sum(replace_index) > 0:
if bits == 6:
quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index]
else:
sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor)
quant_weights[replace_index] = tmp_quant_weights[replace_index]
mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
if torch.sum(mean_replace_index) > 0:
## use mean values to fill zero values
tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt)
tmp_quant_weights = (
tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape)
)
quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index]

scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
scale = scale.to(scale_dtype)
scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale)
# conduct double quant
scale, d_scale = double_quant_tensor_sym(scale, super_bits)

scale = scale.unsqueeze(-1)
zp = torch.full_like(scale, maxq) # pylint: disable=E1130
inverse_scale = get_reciprocal(scale)
int_w = torch.round(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq
int_w = round_ste(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq
qdq_result = (scale * (int_w - zp)).to(orig_dtype)
qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
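
And a matching toy sketch of the symmetric path: the zero point is pinned at `maxq = 2**(bits - 1)`, and values are clipped to `[-maxq, maxq - 1]`. The real scale comes from `make_qx_quants`/`make_q3_quants` plus the double quant of scales; the max-based scale below is only a stand-in, and `round_ste` behaves like `torch.round` in the forward pass.

```python
import torch

bits = 3
maxq = 2 ** (bits - 1)                       # 4; also used as the fixed zero point
w = torch.tensor([[-0.8, -0.2, 0.1, 0.7]])

scale = w.abs().max(dim=-1, keepdim=True).values / (maxq - 1)  # simplistic stand-in for the search
int_w = torch.round(w / scale).clip(-maxq, maxq - 1) + maxq
qdq = scale * (int_w - maxq)
print(int_w)   # tensor([[1., 3., 4., 7.]])
print(qdq)     # ≈ [[-0.80, -0.27,  0.00,  0.80]]
```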

1 change: 1 addition & 0 deletions auto_round/schemes.py
@@ -299,6 +299,7 @@ class AutoScheme:
dataset: Optional[str] = None # Import Notice no comma for each item
device_map: Optional[Union[str, torch.device, int, dict]] = None
enable_torch_compile: Optional[bool] = None
disable_opt_rtn: bool = True

def __post_init__(self):
if isinstance(self.options, str):
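
A minimal usage sketch for the new field, assuming the `AutoRound`/`AutoScheme` entry points shown in this repo's docs; the model name, option strings, and `avg_bits` target are illustrative rather than taken from this PR:

```python
from auto_round import AutoRound          # import paths assumed from the project docs
from auto_round.schemes import AutoScheme

# disable_opt_rtn now defaults to True for AutoScheme, so the scheme search
# presumably scores layers with plain RTN; pass False to opt back in.
scheme = AutoScheme(
    options=("W2A16", "W4A16"),  # illustrative option list
    avg_bits=3.0,                # assumed field, not shown in this diff
    disable_opt_rtn=True,        # field added by this PR
)
ar = AutoRound(model="Qwen/Qwen3-8B", scheme=scheme)  # illustrative model
ar.quantize_and_save()
```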
18 changes: 12 additions & 6 deletions docs/step_by_step.md
@@ -351,15 +351,21 @@ ar.quantize_and_save()
The tuning cost of AutoScheme is approximately 2 to 4 times that of model's bf16 size, depending on the options.
We tested it on Nvidia A100 80G using torch v2.8.

| Models | Scheme | VRAM Cost <br /> (torch compile) | Time Cost <br /> (torch compile) | VRAM Cost <br /> (w/o torch compile) | Time Cost <br /> (w/o torch compile) |
| -------- | ----------------- | ---------------------------- | ----------------------------- | --------------------------------- | --------------------------------- |
| Qwen3-8B | W2A16 / W4A16 / W8A16 | 34G | 30s × len of options | 61G | 40s × len of options |
| Qwen3-8B | MXFP4 / MXFP8 | 36G | 60s × len of options | 54G | 120s × len of options |
| Qwen3-8B | GGUF* | 54G | 30s × len of options | 50G | 23s × len of options |
We will try to optimize the VRAM usage in the future.

| Models | Scheme | VRAM Cost <br />(w/ torch compile) | Time Cost <br />(w/ torch compile) | VRAM Cost <br />(w/o torch compile) | Time Cost <br />(w/o torch compile) |
| --------- | ----------------- | ---------------------------------- | ---------------------------------- | ----------------------------------- | ----------------------------------- |
| Qwen3-8B | W2A16/W4A16/W8A16 | 34G | 30s * len of options | 61G | 40s * len of options |
| Qwen3-8B | MXFP4/MXFP8 | 36G | 60s * len of options | 54G | 120s * len of options |
| Qwen3-8B | GGUF* | 54G | 30s * len of options | 50G | 23s * len of options |
| Qwen3-32B | W2A16/W4A16/W8A16 | OOM with 240G | --- | OOM with 240G | --- |
| Qwen3-32B | MXFP4/MXFP8 | 160G | 200s * len of options | 200G | 240s * len of options |
| Qwen3-32B | GGUF* | 210G | 80s * len of options | 200G | 60s * len of options |
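
As a rough cross-check of the "2 to 4 times the model's BF16 size" estimate against the Qwen3-8B rows above:

```python
# Back-of-the-envelope check (illustrative): 8B parameters at 2 bytes each in BF16.
params_b = 8
bf16_gb = params_b * 2                 # ~16 GB of weights
print(bf16_gb * 2, "to", bf16_gb * 4)  # ~32 to 64 GB, bracketing the 34G / 61G entries above
```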



#### Limitations
Embedding layer is supported in AutoScheme, it will use the best scheme in options.
The embedding layer is not supported in AutoScheme; it will use the best scheme among the options.


### RTN mode