 
 from auto_round.data_type.register import register_dtype
 from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
+from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES
+from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants
 from auto_round.logger import logger
 from auto_round.utils import get_reciprocal
 
@@ -283,48 +285,11 @@ def quant_tensor_asym_dq(
     return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin}
 
 
-@register_dtype("rtn_int_asym_dq")
-def quant_tensor_gguf_asym_dq(
-    tensor,
-    bits=4,
-    v=0,
-    min_scale=1.0,
-    max_scale=1.0,
-    scale_dtype=torch.float16,
-    tensor_min=None,
-    tensor_max=None,
-    q_scale_thresh=1e-5,
-    imatrix=None,
-    **kwargs,
-):
-    """Quantizes and dequantizes a tensor using asymmetric integer quantization for formats like Q2_K, Q4_K, and Q5_K.
-    Only fit for iters 0
-
-    Args:
-        tensor (torch.Tensor): Input tensor to quantize.
-        bits (int): Number of bits for quantization.
-        group_size (int): Group size for per-group quantization.
-        v (float): Perturbation added before rounding.
-        min_scale (float): Minimum allowed scale value.
-        max_scale (float): Maximum allowed scale value.
-        scale_dtype (torch.dtype): Data type for quantized scale.
-        tensor_min (torch.Tensor, optional): Minimum values for the tensor groups.
-        tensor_max (torch.Tensor, optional): Maximum values for the tensor groups.
-        q_scale_thresh (float): Threshold to clamp the quantized scale.
-        super_group_size (int): Number of groups to bundle for secondary quantization.
-        super_bits (int): Number of bits used in secondary quantization.
-        imatrix (torch.Tensor, optional): Importance matrix for weighted quantization.
-
-    Returns:
-        Tuple: (Quantized-dequantized tensor, scale dictionary, zero-point dictionary)
-    """
-    orig_dtype = tensor.dtype
-    maxq = 2**bits - 1
-    group_size = 16 if bits == 2 else 32
+@torch.no_grad()
+def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None):
     super_bits = 4 if bits == 2 else 6
     super_group_size = 16 if bits == 2 else 8
-    tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
-    tensor = tensor.to(torch.float32)
+    group_size = 16 if bits == 2 else 32
     if bits not in [2, 4, 5]:
         raise ValueError(f"bits={bits} not supported by rtn_int_asym_dq")
     quant_weights = None
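
For orientation, the constants above encode the GGUF K-quant block geometry: a super-block spans QK_K = 256 weights, tiled into super_group_size sub-groups of group_size weights, and each sub-group's scale/min is itself stored in super_bits bits. A minimal sketch (helper name ours):

QK_K = 256  # weights per K-quant super-block

def kquant_geometry(bits: int):
    group_size = 16 if bits == 2 else 32       # weights per sub-group
    super_group_size = 16 if bits == 2 else 8  # sub-groups per super-block
    super_bits = 4 if bits == 2 else 6         # bits for each quantized scale/min
    assert group_size * super_group_size == QK_K
    return group_size, super_group_size, super_bits

print(kquant_geometry(2))  # (16, 16, 4) -> Q2_K
print(kquant_geometry(4))  # (32, 8, 6)  -> Q4_K / Q5_K
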
@@ -430,8 +395,52 @@ def quant_tensor_gguf_asym_dq(
         d_wmin = d_wmin.unsqueeze(-1)
     scale = (d_scale * q_scale).view(-1, 1)
     wmin = (d_wmin * q_wmin).view(-1, 1)
-    inverse_scale = get_reciprocal(scale)
+    return scale, wmin, d_scale, d_wmin
+
 
+@register_dtype("rtn_int_asym_dq")
+def quant_tensor_gguf_asym_dq(
+    tensor: torch.Tensor,
+    bits: int = 4,
+    v=0,
+    scale_dtype=torch.float16,
+    imatrix=None,
+    scale=None,
+    wmin=None,
+    d_scale=None,
+    d_wmin=None,
+    **kwargs,
+):
414+ """Quantizes and dequantizes a tensor using asymmetric integer quantization for formats like Q2_K, Q4_K, and Q5_K.
415+ Only fit for iters 0
416+
417+ Args:
418+ tensor (torch.Tensor): Input tensor to quantize.
419+ bits (int): Number of bits for quantization.
420+ group_size (int): Group size for per-group quantization.
421+ v (float): Perturbation added before rounding.
422+ min_scale (float): Minimum allowed scale value.
423+ max_scale (float): Maximum allowed scale value.
424+ scale_dtype (torch.dtype): Data type for quantized scale.
425+ tensor_min (torch.Tensor, optional): Minimum values for the tensor groups.
426+ tensor_max (torch.Tensor, optional): Maximum values for the tensor groups.
427+ q_scale_thresh (float): Threshold to clamp the quantized scale.
428+ super_group_size (int): Number of groups to bundle for secondary quantization.
429+ super_bits (int): Number of bits used in secondary quantization.
430+ imatrix (torch.Tensor, optional): Importance matrix for weighted quantization.
431+
432+ Returns:
433+ Tuple: (Quantized-dequantized tensor, scale dictionary, zero-point dictionary)
434+ """
+    orig_dtype = tensor.dtype
+    maxq = 2**bits - 1
+    group_size = 16 if bits == 2 else 32
+    tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
+    tensor = tensor.to(torch.float32)
+    if scale is None:
+        scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym(tensor, bits, scale_dtype, imatrix)
+
+    inverse_scale = get_reciprocal(scale)
     int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq)
     qdq_result = (scale * int_w - wmin).to(orig_dtype)
     qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
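
The quant-dequant step at the end of quant_tensor_gguf_asym_dq is plain asymmetric rounding; a minimal self-contained sketch with per-row scale/wmin (the real code uses per-group values from search_gguf_scale_min_asym plus double quantization):

import torch

def asym_qdq(w, scale, wmin, bits=4, v=0.0):
    # wmin holds the negated group minimum, so w + wmin >= 0
    maxq = 2**bits - 1
    inverse_scale = torch.where(scale == 0, torch.zeros_like(scale), 1.0 / scale)
    int_w = torch.clamp(torch.round((w + wmin) * inverse_scale + v), 0, maxq)
    return scale * int_w - wmin  # dequantize

w = torch.randn(4, 32)
wmin = -w.amin(dim=-1, keepdim=True)
scale = (w.amax(dim=-1, keepdim=True) + wmin) / 15  # range / maxq for bits=4
print((asym_qdq(w, scale, wmin) - w).abs().max())   # small reconstruction error
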
@@ -506,18 +515,58 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
     return scale.to(torch.float32), -rmin.to(torch.float32)
 
 
+@torch.no_grad()
+def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
+    from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K
+
+    group_size = 16
+
+    if imatrix is None or torch.sum(imatrix) == 0:
+        if bits == 3:
+            scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True)
+            # scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
+        elif bits == 6:
+            scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
+    else:
+        imatrix = imatrix.to(tensor.device)
+        weights = imatrix.reshape(1, -1)
+        weights = weights.expand(tensor.numel() // weights.numel(), -1)
+        quant_weights = weights.reshape(tensor.shape)
+        if torch.min(quant_weights) == 0:
+            logger.warning_once(
+                "calibration activations contain zeros; please provide more data via `nsamples` to improve accuracy"
+            )
+        zero_cnt = torch.sum(quant_weights == 0, dim=-1)
+        replace_index = zero_cnt > group_size // 2
+        if torch.sum(replace_index) > 0:
+            if bits == 6:
+                quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index]
+            else:
+                sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
+                tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor)
+                quant_weights[replace_index] = tmp_quant_weights[replace_index]
+        mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
+        if torch.sum(mean_replace_index) > 0:
+            # use mean values to fill zero values
+            tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt)
+            tmp_quant_weights = (
+                tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape)
+            )
+            quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index]
+
+        scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
+    return scale
+
+
 @register_dtype("rtn_int_sym_dq")
 def quant_tensor_gguf_sym_dq(
     tensor,
     bits=3,
-    v=0,
-    min_scale=1.0,
-    max_scale=1.0,
-    scale_dtype=torch.float16,
-    tensor_min=None,
-    tensor_max=None,
-    q_scale_thresh=1e-5,
     imatrix=None,
+    scale=None,
+    d_scale=None,
+    scale_dtype=torch.float16,
     **kwargs,
 ):
523572 """Quantize and de-quantize tensor asymmetrically. For Q3_K, Q6_K.
@@ -537,80 +586,36 @@ def quant_tensor_gguf_sym_dq(
     Returns:
         Quantized and de-quantized tensor, scale, zero-point
     """
-    from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, K_SCALE_SIZE, QK_K
-    from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants
+
+    from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K
 
     if bits not in [3, 6]:
         raise KeyError(f"bits={bits} is not supported by gguf_int_sym_dq, please check.")
 
     maxq = 2 ** (bits - 1)
     group_size = 16
+    tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
+    orig_dtype = tensor.dtype
     super_bits = 6 if bits == 3 else 8
     super_group_size = 16
-
-    tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
     ggml_type = f"q{bits}_k"
     block_size, type_size = GGML_QUANT_SIZES[ggml_type]
-    orig_dtype = tensor.dtype
-
     tensor = tensor.to(torch.float32)
     n_blocks = tensor.nelement() // block_size
     # (nb, 16, 16)
     tensor = tensor.reshape(n_blocks, super_group_size, QK_K // super_group_size)
+    if scale is None and d_scale is None:
+        scale = search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype)
 
-    if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0):
-        if bits == 3:
-            scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True)
-            ##scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
-        elif bits == 6:
-            scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
-    else:
-        imatrix = imatrix.to(tensor.device)
-
-        # if bits == 3:
-        #     # sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
-        #     # imatrix = imatrix.reshape(1, -1).expand(tensor.numel() // imatrix.numel(), -1).reshape(tensor.shape)
-        #     # quant_weights = imatrix * torch.sqrt(sigma2 + tensor * tensor)
-        #     # scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
-        #     weights = imatrix.reshape(1, -1)
-        #     weights = weights.expand(tensor.numel() // weights.numel(), -1)
-        #     quant_weights = weights.reshape(tensor.shape)
-        # elif bits == 6:
-
-        weights = imatrix.reshape(1, -1)
-        weights = weights.expand(tensor.numel() // weights.numel(), -1)
-        quant_weights = weights.reshape(tensor.shape)
-        if torch.min(quant_weights) == 0:
-            logger.warning_once(
-                "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
-            )
-        zero_cnt = torch.sum(quant_weights == 0, dim=-1)
-        replace_index = zero_cnt > group_size // 2
-        if torch.sum(replace_index) > 0:
-            if bits == 6:
-                quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index]
-            else:
-                sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
-                tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor)
-                quant_weights[replace_index] = tmp_quant_weights[replace_index]
-        mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
-        if torch.sum(mean_replace_index) > 0:
-            ## use mean values to fill zero values
-            tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt)
-            tmp_quant_weights = (
-                tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape)
-            )
-            quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index]
-
-        scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
+    scale = scale.to(scale_dtype)
     scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale)
     # conduct double quant
     scale, d_scale = double_quant_tensor_sym(scale, super_bits)
 
     scale = scale.unsqueeze(-1)
     zp = torch.full_like(scale, maxq)  # pylint: disable=E1130
     inverse_scale = get_reciprocal(scale)
-    int_w = torch.round(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq
+    int_w = round_ste(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq
     qdq_result = (scale * (int_w - zp)).to(orig_dtype)
     qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
 
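
Taken together, the symmetric path double-quantizes the per-group scales to super_bits with one floating super-scale per block, then rounds weights symmetrically around zp = maxq. A minimal end-to-end sketch, with a stand-in for double_quant_tensor_sym (assumed here to quantize the scales symmetrically):

import torch

def sym_dq_sketch(w, scale, bits=3, super_bits=6):
    # w: (n_blocks, 16, 16) weights; scale: (n_blocks, 16) per-group scales
    maxq = 2 ** (bits - 1)
    smax = scale.abs().amax(dim=-1, keepdim=True)
    d_scale = smax / (2 ** (super_bits - 1) - 1)  # one super-scale per block
    inv_d = torch.where(d_scale == 0, torch.zeros_like(d_scale), 1.0 / d_scale)
    scale = (d_scale * torch.round(scale * inv_d)).unsqueeze(-1)  # double-quantized scales
    inv_s = torch.where(scale == 0, torch.zeros_like(scale), 1.0 / scale)
    int_w = torch.round(w * inv_s).clip(-maxq, maxq - 1) + maxq  # zp = maxq
    return scale * (int_w - maxq)  # dequantized weights
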