|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +from typing import Any, Callable, Union |
14 | 15 |
|
15 | 16 | import torch |
16 | 17 |
|
@@ -285,6 +286,39 @@ def quant_tensor_asym_dq( |
285 | 286 | return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin} |
286 | 287 |
|
287 | 288 |
|
| 289 | +def _imatrix_handle_zero(imatrix: Union[torch.Tensor, float], weight: torch.Tensor, bits: int): |
| 290 | + if not isinstance(imatrix, torch.Tensor): |
| 291 | + return imatrix |
| 292 | + |
| 293 | + group_size = 16 if bits == 2 else 32 |
| 294 | + imatrix = imatrix.reshape(-1, imatrix.shape[-1]) |
| 295 | + if torch.min(imatrix) == 0: |
| 296 | + logger.warning_once( |
| 297 | + "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0" |
| 298 | + ) |
| 299 | + |
| 300 | + zero_cnt = torch.sum(imatrix == 0, dim=-1) |
| 301 | + replace_index = zero_cnt > group_size // 2 |
| 302 | + if torch.sum(replace_index) > 0: |
| 303 | + ## fallback to no imatrix |
| 304 | + if bits == 2: |
| 305 | + tmp_quant_weights = torch.abs(weight) |
| 306 | + elif bits == 4 or bits == 5: |
| 307 | + sigma2 = torch.sum(torch.pow(weight, 2), dim=-1, keepdim=True) / 32 ## Note 32 is different from QK_K |
| 308 | + av_x = torch.sqrt(sigma2) |
| 309 | + tmp_quant_weights = torch.abs(weight) + av_x |
| 310 | + tmp_quant_weights = tmp_quant_weights.to(imatrix.dtype) |
| 311 | + imatrix[replace_index, :] = tmp_quant_weights[replace_index, :] |
| 312 | + mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) |
| 313 | + if torch.sum(mean_replace_index) > 0: |
| 314 | + ## use mean values to fill zero values |
| 315 | + tmp_quant_weights = torch.sum(imatrix, dim=-1) / (imatrix.shape[1] - zero_cnt) |
| 316 | + tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, imatrix.shape[1]) |
| 317 | + replace_idx = imatrix == 0 |
| 318 | + imatrix[replace_idx] = tmp_quant_weights[replace_idx] |
| 319 | + return imatrix.reshape(weight.shape) |
| 320 | + |
| 321 | + |
288 | 322 | @torch.no_grad() |
289 | 323 | def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None): |
290 | 324 | super_bits = 4 if bits == 2 else 6 |
@@ -337,30 +371,7 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri |
337 | 371 | weights = weights.expand(tensor.numel() // weights.numel(), -1) |
338 | 372 | quant_weights = weights.reshape(tensor.shape) |
339 | 373 |
|
340 | | - if torch.min(quant_weights) == 0: |
341 | | - logger.warning_once( |
342 | | - "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0" |
343 | | - ) |
344 | | - |
345 | | - zero_cnt = torch.sum(quant_weights == 0, dim=-1) |
346 | | - replace_index = zero_cnt > group_size // 2 |
347 | | - if torch.sum(replace_index) > 0: |
348 | | - ## fallback to no imatrix |
349 | | - if bits == 2: |
350 | | - tmp_quant_weights = torch.abs(tensor) |
351 | | - elif bits == 4 or bits == 5: |
352 | | - sigma2 = ( |
353 | | - torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / 32 |
354 | | - ) ## Note 32 is different from QK_K |
355 | | - av_x = torch.sqrt(sigma2) |
356 | | - tmp_quant_weights = torch.abs(tensor) + av_x |
357 | | - quant_weights[replace_index, :] = tmp_quant_weights[replace_index, :] |
358 | | - mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) |
359 | | - if torch.sum(mean_replace_index) > 0: |
360 | | - ## use mean values to fill zero values |
361 | | - tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[1] - zero_cnt) |
362 | | - tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]) |
363 | | - quant_weights[mean_replace_index, :] = tmp_quant_weights[mean_replace_index, :] |
| 374 | + quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits) |
364 | 375 |
|
365 | 376 | # sigma2 = torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K |
366 | 377 | # if imatrix is None: |
@@ -532,27 +543,8 @@ def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype): |
532 | 543 | weights = imatrix.reshape(1, -1) |
533 | 544 | weights = weights.expand(tensor.numel() // weights.numel(), -1) |
534 | 545 | quant_weights = weights.reshape(tensor.shape) |
535 | | - if torch.min(quant_weights) == 0: |
536 | | - logger.warning_once( |
537 | | - "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0" |
538 | | - ) |
539 | | - zero_cnt = torch.sum(quant_weights == 0, dim=-1) |
540 | | - replace_index = zero_cnt > group_size // 2 |
541 | | - if torch.sum(replace_index) > 0: |
542 | | - if bits == 6: |
543 | | - quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index] |
544 | | - else: |
545 | | - sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K |
546 | | - tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor) |
547 | | - quant_weights[replace_index] = tmp_quant_weights[replace_index] |
548 | | - mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) |
549 | | - if torch.sum(mean_replace_index) > 0: |
550 | | - ## use mean values to fill zero values |
551 | | - tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt) |
552 | | - tmp_quant_weights = ( |
553 | | - tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape) |
554 | | - ) |
555 | | - quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index] |
| 546 | + |
| 547 | + quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits) |
556 | 548 |
|
557 | 549 | scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights) |
558 | 550 | return scale |
|
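For reference, here is a minimal standalone sketch of the zero-handling behavior that the new `_imatrix_handle_zero` helper centralizes. The tensor names and toy shapes are illustrative only (they are not taken from the library), and the sketch paraphrases the logic shown in the diff rather than importing it: rows of the importance matrix where more than half of the entries are zero fall back to weight-derived importance, while rows with only a few zeros have their zero entries filled with the mean of the row's nonzero entries.

```python
import torch

# Toy shapes only: 2 rows of 8 columns instead of the real GGUF group sizes.
toy_weight = torch.randn(2, 8)
toy_imatrix = torch.ones(2, 8)
toy_imatrix[0, :6] = 0.0  # row 0: more than half zeros -> weight-based fallback
toy_imatrix[1, 3] = 0.0   # row 1: a single zero -> filled with the row mean

group_size = 8  # stands in for the 16 (2-bit) / 32 (other bits) used in the helper
zero_cnt = torch.sum(toy_imatrix == 0, dim=-1)

# Case 1: too many zeros in a row -> ignore the imatrix for that row and
# derive importance from the weights themselves (2-bit style: |w|).
fallback_rows = zero_cnt > group_size // 2
fallback = torch.abs(toy_weight)
toy_imatrix[fallback_rows, :] = fallback[fallback_rows, :]

# Case 2: only a few zeros -> replace just the zero entries with the mean of
# the row's nonzero entries, as the refactored helper does.
mean_rows = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
if torch.any(mean_rows):
    row_mean = torch.sum(toy_imatrix, dim=-1) / (toy_imatrix.shape[1] - zero_cnt)
    row_mean = row_mean.view(-1, 1).expand(-1, toy_imatrix.shape[1])
    zero_mask = toy_imatrix == 0
    toy_imatrix[zero_mask] = row_mean[zero_mask]

print(toy_imatrix)  # no zero entries remain (barring an exactly-zero random weight)
```

In the actual helper the 4/5-bit fallback is `torch.abs(weight) + torch.sqrt(torch.sum(weight ** 2, dim=-1, keepdim=True) / 32)` rather than plain `|w|`, and the result is reshaped back to the weight's shape before being used by `search_gguf_scale_min_asym` / `search_gguf_scale_min_sym`.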