
Commit 5e33cbc

fix bug of imatrix contains 0 (#955)
1 parent e8bc353 commit 5e33cbc

File tree

5 files changed, +50 −48 lines

auto_round/compressors/base.py

Lines changed: 6 additions & 0 deletions
@@ -1135,8 +1135,10 @@ def get_imatrix_hook(module, input, output):
 
            if not hasattr(module, "imatrix"):
                module.imatrix = squared
+               module.imatrix_cnt = input.shape[0]
            else:
                module.imatrix += squared.to(module.imatrix.device)
+               module.imatrix_cnt += input.shape[0]
 
        hook_handles = []
        for name, module in model.named_modules():

@@ -1454,6 +1456,10 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
                set_amax_for_all_moe_layers(block, attr_name="act_max")
            # Normalize imatrix and quantize layers
            for _, m in block.named_modules():
+               # fix issue: Ling-flash-2.0-q2_k_s fails to infer on CUDA but works on CPU
+               # https://huggingface.co/Intel/Ling-flash-2.0-gguf-q2ks-mixed-AutoRound/discussions/1
+               if hasattr(m, "imatrix"):
+                   m.imatrix /= m.imatrix_cnt
                if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names:
                    self._quantize_layer_via_rtn(m.tmp_name)
                    all_to_quantized_module_names.remove(m.tmp_name)
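
Why the new counter matters: before this change the hook accumulated a raw sum of squared activations across all calibration batches, so the magnitude of module.imatrix grew with the number of samples; the added imatrix_cnt lets _quantize_via_rtn_blockwise normalize it to a per-sample mean before quantization. Below is a minimal standalone sketch of this accumulate-then-normalize pattern; the toy _Toy class, the update function, and the 2-D activation shape are illustrative assumptions, since in auto_round the logic runs inside a forward hook and squared is computed from the hooked input.

import torch

class _Toy:
    """Hypothetical stand-in for the nn.Module being hooked."""

def update(m, x):
    # assumption: x is (batch, hidden); squared is the per-channel sum of squares
    squared = torch.sum(x * x, dim=0)
    if not hasattr(m, "imatrix"):
        m.imatrix = squared
        m.imatrix_cnt = x.shape[0]
    else:
        m.imatrix += squared
        m.imatrix_cnt += x.shape[0]

m = _Toy()
for batch in (torch.randn(4, 8), torch.randn(2, 8)):
    update(m, batch)
m.imatrix /= m.imatrix_cnt  # per-sample mean, as now done blockwise after calibration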

auto_round/compressors/utils.py

Lines changed: 1 addition & 1 deletion
@@ -510,7 +510,7 @@ def gguf_args_check(args_or_ar, formats: list[str] = None, model_type=ModelType.
        if model_architecture not in ModelBase._model_classes[ModelType.TEXT]:
            logger.warning(
                f"Current version of gguf export does not support for {model_architecture},"
-               " will re-download dependency file."
+               " will re-download dependency file. Please restart the task."
            )
            redownload = True
    except ModuleNotFoundError as e:

auto_round/data_type/gguf.py

Lines changed: 37 additions & 45 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Callable, Union
 
 import torch
 

@@ -285,6 +286,39 @@ def quant_tensor_asym_dq(
     return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin}
 
 
+def _imatrix_handle_zero(imatrix: Union[torch.Tensor, float], weight: torch.Tensor, bits: int):
+    if not isinstance(imatrix, torch.Tensor):
+        return imatrix
+
+    group_size = 16 if bits == 2 else 32
+    imatrix = imatrix.reshape(-1, imatrix.shape[-1])
+    if torch.min(imatrix) == 0:
+        logger.warning_once(
+            "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
+        )
+
+    zero_cnt = torch.sum(imatrix == 0, dim=-1)
+    replace_index = zero_cnt > group_size // 2
+    if torch.sum(replace_index) > 0:
+        ## fallback to no imatrix
+        if bits == 2:
+            tmp_quant_weights = torch.abs(weight)
+        elif bits == 4 or bits == 5:
+            sigma2 = torch.sum(torch.pow(weight, 2), dim=-1, keepdim=True) / 32  ## Note 32 is different from QK_K
+            av_x = torch.sqrt(sigma2)
+            tmp_quant_weights = torch.abs(weight) + av_x
+        tmp_quant_weights = tmp_quant_weights.to(imatrix.dtype)
+        imatrix[replace_index, :] = tmp_quant_weights[replace_index, :]
+    mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
+    if torch.sum(mean_replace_index) > 0:
+        ## use mean values to fill zero values
+        tmp_quant_weights = torch.sum(imatrix, dim=-1) / (imatrix.shape[1] - zero_cnt)
+        tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, imatrix.shape[1])
+        replace_idx = imatrix == 0
+        imatrix[replace_idx] = tmp_quant_weights[replace_idx]
+    return imatrix.reshape(weight.shape)
+
+
 @torch.no_grad()
 def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None):
     super_bits = 4 if bits == 2 else 6

@@ -337,30 +371,7 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri
     weights = weights.expand(tensor.numel() // weights.numel(), -1)
     quant_weights = weights.reshape(tensor.shape)
 
-    if torch.min(quant_weights) == 0:
-        logger.warning_once(
-            "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
-        )
-
-    zero_cnt = torch.sum(quant_weights == 0, dim=-1)
-    replace_index = zero_cnt > group_size // 2
-    if torch.sum(replace_index) > 0:
-        ## fallback to no imatrix
-        if bits == 2:
-            tmp_quant_weights = torch.abs(tensor)
-        elif bits == 4 or bits == 5:
-            sigma2 = (
-                torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / 32
-            )  ## Note 32 is different from QK_K
-            av_x = torch.sqrt(sigma2)
-            tmp_quant_weights = torch.abs(tensor) + av_x
-        quant_weights[replace_index, :] = tmp_quant_weights[replace_index, :]
-    mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
-    if torch.sum(mean_replace_index) > 0:
-        ## use mean values to fill zero values
-        tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[1] - zero_cnt)
-        tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1])
-        quant_weights[mean_replace_index, :] = tmp_quant_weights[mean_replace_index, :]
+    quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits)
 
     # sigma2 = torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
     # if imatrix is None:

@@ -532,27 +543,8 @@ def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
        weights = imatrix.reshape(1, -1)
        weights = weights.expand(tensor.numel() // weights.numel(), -1)
        quant_weights = weights.reshape(tensor.shape)
-       if torch.min(quant_weights) == 0:
-           logger.warning_once(
-               "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
-           )
-       zero_cnt = torch.sum(quant_weights == 0, dim=-1)
-       replace_index = zero_cnt > group_size // 2
-       if torch.sum(replace_index) > 0:
-           if bits == 6:
-               quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index]
-           else:
-               sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
-               tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor)
-               quant_weights[replace_index] = tmp_quant_weights[replace_index]
-       mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
-       if torch.sum(mean_replace_index) > 0:
-           ## use mean values to fill zero values
-           tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt)
-           tmp_quant_weights = (
-               tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape)
-           )
-           quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index]
+
+       quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits)
 
    scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
    return scale
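
For a concrete picture of the two repair paths that the extracted _imatrix_handle_zero helper centralizes, here is a small self-contained sketch on a toy tensor; shapes and values are illustrative, and the real helper also covers the 4/5-bit fallback and the non-tensor pass-through shown in the hunk above. A row whose group is mostly zeros falls back to weight-derived statistics, while a row with only a few zeros gets those entries filled with the mean of its nonzero values.

import torch

bits, group_size = 2, 16             # 2-bit path: group_size = 16
weight = torch.randn(2, 16)
imatrix = torch.rand(2, 16) + 0.1    # strictly positive baseline
imatrix[0, :12] = 0                  # > group_size // 2 zeros -> fallback row
imatrix[1, 3] = 0                    # a few zeros -> mean-filled entry

zero_cnt = torch.sum(imatrix == 0, dim=-1)
replace_index = zero_cnt > group_size // 2                           # row 0
imatrix[replace_index, :] = torch.abs(weight)[replace_index, :]      # bits == 2 branch
mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)  # row 1
if torch.sum(mean_replace_index) > 0:
    row_mean = torch.sum(imatrix, dim=-1) / (imatrix.shape[1] - zero_cnt)
    fill = row_mean.view(-1, 1).expand(-1, imatrix.shape[1])
    replace_idx = imatrix == 0
    imatrix[replace_idx] = fill[replace_idx]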

auto_round/data_type/int.py

Lines changed: 5 additions & 1 deletion
@@ -17,6 +17,7 @@
 
 from auto_round.data_type.register import register_dtype
 from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
+from auto_round.logger import logger
 from auto_round.utils import get_reciprocal
 
 

@@ -44,7 +45,7 @@ def search_scales(data: torch.Tensor, bits: int, qw: Union[None, torch.Tensor, f
 
 
 @register_dtype("rtn_int_sym")
-def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5, imatrix=None, **kwargs):
+def quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5, imatrix=None, **kwargs):
     """Quantize and de-quantize tensor asymmetrically. full range, credict goes to llamacpp community
 
     Args:

@@ -62,6 +63,7 @@ def quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5
     Returns:
         Quantized and de-quantized tensor, scale, zero-point
     """
+    from auto_round.data_type.gguf import _imatrix_handle_zero
 
     tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
     maxq = 2 ** (bits - 1)

@@ -73,6 +75,8 @@ def quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5
         imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1)
         imatrix = imatrix.reshape(tensor.shape)
 
+    imatrix = _imatrix_handle_zero(imatrix, tensor, bits)
+
     scale = search_scales(tensor, bits, qw=imatrix)
     scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh))
     int_w = round_ste(tensor / scale + v)
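
Besides fixing the rnt/rtn typo in the kernel name, the hunks above route the RTN symmetric path through the same zero-repair helper. A hedged usage sketch follows (toy shapes; the asserts reflect the helper's behavior as added in this commit): the isinstance guard passes non-tensor inputs through untouched, which is what makes the unconditional call safe when no imatrix was collected.

import torch
from auto_round.data_type.gguf import _imatrix_handle_zero

weight = torch.randn(4, 32)

# No calibration statistics collected: the guard returns the input unchanged.
assert _imatrix_handle_zero(None, weight, bits=4) is None

# A dead calibration channel no longer zeroes out its quantization weight.
imatrix = torch.rand(4, 32) + 0.1
imatrix[0, 0] = 0
fixed = _imatrix_handle_zero(imatrix.clone(), weight, bits=4)
assert torch.all(fixed > 0)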

auto_round/export/export_to_gguf/convert.py

Lines changed: 1 addition & 1 deletion
@@ -412,7 +412,7 @@ def prepare_tensors(cls):
            skip = False
            for tensor_info in cls.gguf_writer.tensors:
                if new_name in tensor_info:
-                   logger.warning(f"{new_name} already add to gguf_writer, skip")
+                   logger.info(f"{new_name} already add to gguf_writer, skip")
                    skip = True
                    break
            if skip:
