6 changes: 6 additions & 0 deletions auto_round/compressors/base.py
@@ -1135,8 +1135,10 @@ def get_imatrix_hook(module, input, output):

if not hasattr(module, "imatrix"):
module.imatrix = squared
module.imatrix_cnt = input.shape[0]
else:
module.imatrix += squared.to(module.imatrix.device)
module.imatrix_cnt += input.shape[0]

hook_handles = []
for name, module in model.named_modules():
@@ -1454,6 +1456,10 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
set_amax_for_all_moe_layers(block, attr_name="act_max")
# Normalize imatrix and quantize layers
for _, m in block.named_modules():
# fix issue: Ling-flash-2.0 q2_k_s fails inference on CUDA but works on CPU
# https://huggingface.co/Intel/Ling-flash-2.0-gguf-q2ks-mixed-AutoRound/discussions/1
if hasattr(m, "imatrix"):
m.imatrix /= m.imatrix_cnt
if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names:
self._quantize_layer_via_rtn(m.tmp_name)
all_to_quantized_module_names.remove(m.tmp_name)
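For context on the hook change above: the new `imatrix_cnt` counter lets the accumulated squared activations be normalized to a per-channel mean before quantization, so the statistic no longer grows with the number of calibration samples. Below is a minimal, self-contained sketch of that pattern; `make_imatrix_hook` and the toy `Linear` layer are illustrative names, and `squared` is assumed to be the per-channel sum of squared activations computed just above the shown hunk.

```python
import torch

def make_imatrix_hook():
    """Hedged sketch: accumulate per-channel squared activations plus a row count."""
    def hook(module, inputs, output):
        x = inputs[0] if isinstance(inputs, tuple) else inputs
        x = x.reshape(-1, x.shape[-1]).to(torch.float32)  # flatten batch/sequence dims
        squared = torch.sum(x * x, dim=0)                 # per-input-channel sum of squares
        if not hasattr(module, "imatrix"):
            module.imatrix = squared
            module.imatrix_cnt = x.shape[0]
        else:
            module.imatrix += squared.to(module.imatrix.device)
            module.imatrix_cnt += x.shape[0]
    return hook

layer = torch.nn.Linear(8, 4)
handle = layer.register_forward_hook(make_imatrix_hook())
for _ in range(3):
    layer(torch.randn(2, 5, 8))              # toy calibration batches
handle.remove()
layer.imatrix /= layer.imatrix_cnt           # normalize, as _quantize_via_rtn_blockwise now does
print(layer.imatrix.shape)                   # torch.Size([8])
```

Dividing by the row count gives a mean of squared activations per input channel, which is the statistic the blockwise RTN path consumes.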
2 changes: 1 addition & 1 deletion auto_round/compressors/utils.py
@@ -510,7 +510,7 @@ def gguf_args_check(args_or_ar, formats: list[str] = None, model_type=ModelType.
if model_architecture not in ModelBase._model_classes[ModelType.TEXT]:
logger.warning(
f"Current version of gguf export does not support for {model_architecture},"
" will re-download dependency file."
" will re-download dependency file. Please restart the task."
)
redownload = True
except ModuleNotFoundError as e:
82 changes: 37 additions & 45 deletions auto_round/data_type/gguf.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Union

import torch

@@ -285,6 +286,39 @@ def quant_tensor_asym_dq(
return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin}


def _imatrix_handle_zero(imatrix: Union[torch.Tensor, float], weight: torch.Tensor, bits: int):
"""Replace zero entries in the importance matrix so they cannot blank out the weighted scale search."""
if not isinstance(imatrix, torch.Tensor):
return imatrix

group_size = 16 if bits == 2 else 32
imatrix = imatrix.reshape(-1, imatrix.shape[-1])
if torch.min(imatrix) == 0:
logger.warning_once(
"calibration activations contain zeros; please use more data (increase `nsamples`) to improve accuracy"
)

zero_cnt = torch.sum(imatrix == 0, dim=-1)
replace_index = zero_cnt > group_size // 2
if torch.sum(replace_index) > 0:
## fallback to no imatrix
if bits == 2:
tmp_quant_weights = torch.abs(weight)
elif bits == 4 or bits == 5:
sigma2 = torch.sum(torch.pow(weight, 2), dim=-1, keepdim=True) / 32  ## Note: 32 intentionally differs from QK_K
av_x = torch.sqrt(sigma2)
tmp_quant_weights = torch.abs(weight) + av_x
elif bits == 6:
tmp_quant_weights = weight * weight
else:
## keep the previous symmetric fallback so tmp_quant_weights is always defined (e.g. 3-bit)
sigma2 = 2 * torch.sum(torch.pow(weight, 2), dim=-1, keepdim=True) / QK_K
tmp_quant_weights = torch.sqrt(sigma2 + weight * weight)
tmp_quant_weights = tmp_quant_weights.to(imatrix.dtype)
imatrix[replace_index, :] = tmp_quant_weights[replace_index, :]
mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
if torch.sum(mean_replace_index) > 0:
## use mean values to fill zero values
tmp_quant_weights = torch.sum(imatrix, dim=-1) / (imatrix.shape[1] - zero_cnt)
tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, imatrix.shape[1])
replace_idx = imatrix == 0
imatrix[replace_idx] = tmp_quant_weights[replace_idx]
return imatrix.reshape(weight.shape)


@torch.no_grad()
def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None):
super_bits = 4 if bits == 2 else 6
@@ -337,30 +371,7 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri
weights = weights.expand(tensor.numel() // weights.numel(), -1)
quant_weights = weights.reshape(tensor.shape)

if torch.min(quant_weights) == 0:
logger.warning_once(
"please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
)

zero_cnt = torch.sum(quant_weights == 0, dim=-1)
replace_index = zero_cnt > group_size // 2
if torch.sum(replace_index) > 0:
## fallback to no imatrix
if bits == 2:
tmp_quant_weights = torch.abs(tensor)
elif bits == 4 or bits == 5:
sigma2 = (
torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / 32
) ## Note 32 is different from QK_K
av_x = torch.sqrt(sigma2)
tmp_quant_weights = torch.abs(tensor) + av_x
quant_weights[replace_index, :] = tmp_quant_weights[replace_index, :]
mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
if torch.sum(mean_replace_index) > 0:
## use mean values to fill zero values
tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[1] - zero_cnt)
tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1])
quant_weights[mean_replace_index, :] = tmp_quant_weights[mean_replace_index, :]
quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits)

# sigma2 = torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
# if imatrix is None:
Expand Down Expand Up @@ -532,27 +543,8 @@ def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
weights = imatrix.reshape(1, -1)
weights = weights.expand(tensor.numel() // weights.numel(), -1)
quant_weights = weights.reshape(tensor.shape)
if torch.min(quant_weights) == 0:
logger.warning_once(
"please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
)
zero_cnt = torch.sum(quant_weights == 0, dim=-1)
replace_index = zero_cnt > group_size // 2
if torch.sum(replace_index) > 0:
if bits == 6:
quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index]
else:
sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor)
quant_weights[replace_index] = tmp_quant_weights[replace_index]
mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
if torch.sum(mean_replace_index) > 0:
## use mean values to fill zero values
tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt)
tmp_quant_weights = (
tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape)
)
quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index]

quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits)

scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
return scale
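A quick, hedged usage sketch of the new helper on toy data follows (it assumes `_imatrix_handle_zero` remains importable from `auto_round.data_type.gguf` as added in this diff; the shapes and values are made up). Rows whose imatrix group is mostly zero fall back to the no-imatrix weighting, while rows with only a few zeros get those entries filled with the row mean.

```python
import torch
from auto_round.data_type.gguf import _imatrix_handle_zero  # private helper added in this diff

torch.manual_seed(0)
weight = torch.randn(4, 32)          # 4 rows, one 32-wide group each (bits=4 -> group_size 32)
imatrix = torch.rand(4, 32) + 0.1    # strictly positive importance values

imatrix[0, :] = 0.0                  # > group_size // 2 zeros -> fallback: |w| + sqrt(sigma2)
imatrix[1, :3] = 0.0                 # a few zeros -> filled with the row mean of the remaining entries

fixed = _imatrix_handle_zero(imatrix.clone(), weight, bits=4)
assert torch.all(fixed > 0)          # no zero weights are left for the scale search

# Non-tensor imatrix values (e.g. a plain float) pass through unchanged.
assert _imatrix_handle_zero(1.0, weight, bits=4) == 1.0
```

Consolidating this logic in one place is what lets both `search_gguf_scale_min_asym` and `search_gguf_scale_min_sym` shrink to a single `_imatrix_handle_zero` call above.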
6 changes: 5 additions & 1 deletion auto_round/data_type/int.py
@@ -17,6 +17,7 @@

from auto_round.data_type.register import register_dtype
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
from auto_round.logger import logger
from auto_round.utils import get_reciprocal


@@ -44,7 +45,7 @@ def search_scales(data: torch.Tensor, bits: int, qw: Union[None, torch.Tensor, f


@register_dtype("rtn_int_sym")
def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5, imatrix=None, **kwargs):
def quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5, imatrix=None, **kwargs):
"""Quantize and de-quantize tensor asymmetrically. full range, credict goes to llamacpp community

Args:
@@ -62,6 +63,7 @@ def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5
Returns:
Quantized and de-quantized tensor, scale, zero-point
"""
from auto_round.data_type.gguf import _imatrix_handle_zero

tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
maxq = 2 ** (bits - 1)
@@ -73,6 +75,8 @@ def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5
imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1)
imatrix = imatrix.reshape(tensor.shape)

imatrix = _imatrix_handle_zero(imatrix, tensor, bits)

scale = search_scales(tensor, bits, qw=imatrix)
scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh))
int_w = round_ste(tensor / scale + v)
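For intuition about why the zero-handled imatrix matters in this file: inside `quant_tensor_rtn_sym` it ends up as the `qw` weighting for the full-range symmetric scale search. The snippet below is a simplified, hedged illustration of such a weighted search (a toy grid search, not the library's actual `search_scales`/`make_qx_quants` implementation); zero weights would make whole columns invisible to the error term, which is exactly what `_imatrix_handle_zero` prevents.

```python
from typing import Optional
import torch

def toy_weighted_sym_scale(x: torch.Tensor, bits: int = 4, qw: Optional[torch.Tensor] = None,
                           n_grid: int = 20) -> torch.Tensor:
    """Per-row grid search for a full-range symmetric scale, weighted by qw."""
    maxq = 2 ** (bits - 1)
    amax = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8)
    qw = torch.ones_like(x) if qw is None else qw
    best_scale = amax / maxq
    best_err = torch.full_like(amax, float("inf"))
    for i in range(n_grid):
        scale = amax / maxq * (0.7 + 0.6 * i / (n_grid - 1))      # sweep around amax / maxq
        q = torch.clamp(torch.round(x / scale), -maxq, maxq - 1)  # full-range symmetric integers
        err = torch.sum(qw * (q * scale - x) ** 2, dim=-1, keepdim=True)
        better = err < best_err
        best_err = torch.where(better, err, best_err)
        best_scale = torch.where(better, scale, best_scale)
    return best_scale

scales = toy_weighted_sym_scale(torch.randn(8, 32), bits=4, qw=torch.rand(8, 32) + 0.1)
print(scales.shape)  # torch.Size([8, 1])
```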
2 changes: 1 addition & 1 deletion auto_round/export/export_to_gguf/convert.py
@@ -412,7 +412,7 @@ def prepare_tensors(cls):
skip = False
for tensor_info in cls.gguf_writer.tensors:
if new_name in tensor_info:
logger.warning(f"{new_name} already add to gguf_writer, skip")
logger.info(f"{new_name} already added to gguf_writer, skip")
skip = True
break
if skip: