@@ -3319,6 +3319,12 @@ static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML
33193319    }
33203320}
33213321
3322+ static  int  iq1_sort_helper (const  void  *  left , const  void  *  right ) {
3323+     const  float  *  l  =  left ;
3324+     const  float  *  r  =  right ;
3325+     return  * l  <  * r  ? -1  : * l  >  * r  ? 1  : 0 ;
3326+ }
3327+ 
33223328static  void  quantize_row_iq2_xs_impl (const  float  *  GGML_RESTRICT  x , void  *  GGML_RESTRICT  vy , int64_t  n , const  float  *  GGML_RESTRICT  quant_weights ) {
33233329
33243330    const  int  gindex  =  iq2_data_index (GGML_TYPE_IQ2_XS );
@@ -3349,6 +3355,9 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
33493355    bool    is_on_grid_aux [2 ];
33503356    uint8_t  block_signs [2 ];
33513357    uint16_t  q2 [2 * (QK_K /16 )];
3358+     uint16_t  index [2 ], aux_index [2 ];
3359+     float  sumx [17 ], sumw [17 ], pairs [32 ];
3360+     int  *  int_pairs  =  (int  * )(pairs  +  1 );
33523361
33533362    for  (int  ibl  =  0 ; ibl  <  nbl ; ++ ibl ) {
33543363
@@ -3397,11 +3406,35 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
33973406                memset (L , 0 , 16 );
33983407                continue ;
33993408            }
3400-             float  best  =  0 ;
3401-             float  scale  =  max /(2 * kMaxQ - 1 );
3409+             for  (int  j  =  0 ; j  <  16 ; ++ j ) {
3410+                 pairs [2 * j ] =  xval [j ];
3411+                 int_pairs [2 * j ] =  j ;
3412+             }
3413+             qsort (pairs , 16 , 2 * sizeof (float ), iq1_sort_helper );
3414+             {
3415+                 sumx [0 ] =  sumw [0 ] =  0 ;
3416+                 for  (int  j  =  0 ; j  <  16 ; ++ j ) {
3417+                     int  i  =  int_pairs [2 * j ];
3418+                     sumx [j + 1 ] =  sumx [j ] +  weight [i ]* xval [i ];
3419+                     sumw [j + 1 ] =  sumw [j ] +  weight [i ];
3420+                 }
3421+             }
3422+             float  best  =  0 , scale  =  0 ;
3423+             for  (int  i1  =  0 ; i1  <= 16 ; ++ i1 ) {
3424+                 for  (int  i2  =  i1 ; i2  <= 16 ; ++ i2 ) {
3425+                     float  sumqx  =  (sumx [i1 ] -  sumx [0 ])* 1  +  (sumx [i2 ] -  sumx [i1 ])* 3  +  (sumx [16 ] -  sumx [i2 ])* 5 ;
3426+                     float  sumq2  =  (sumw [i1 ] -  sumw [0 ])* 1  +  (sumw [i2 ] -  sumw [i1 ])* 9  +  (sumw [16 ] -  sumw [i2 ])* 25 ;
3427+                     if  (sumq2  >  0  &&  sumqx * sumqx  >  best * sumq2 ) {
3428+                         scale  =  sumqx /sumq2 ; best  =  scale * sumqx ;
3429+                     }
3430+                 }
3431+             }
3432+             best  =  0 ;
3433+             float  eff_max  =  scale * (2 * kMaxQ  -  1 );
34023434            is_on_grid [0 ] =  is_on_grid [1 ] =  true;
3403-             for  (int  is  =  -9 ; is  <= 9 ; ++ is ) {
3404-                 float  id  =  (2 * kMaxQ - 1 + is * 0.1f )/max ;
3435+             index [0 ] =  index [1 ] =  0 ;
3436+             for  (int  is  =  -7 ; is  <= 7 ; ++ is ) {
3437+                 float  id  =  (2 * kMaxQ - 1 + is * 0.1f )/eff_max ;
34053438                float  this_scale  =  1 /id ;
34063439                for  (int  k  =  0 ; k  <  2 ; ++ k ) {
34073440                    for  (int  i  =  0 ; i  <  8 ; ++ i ) {
@@ -3417,6 +3450,7 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
34173450                        const  uint16_t  *  neighbours  =  kneighbors_q2xs  -  kmap_q2xs [u ] -  1 ;
34183451                        grid_index  =  iq2_find_best_neighbour (neighbours , kgrid_q2xs , xval  +  8 * k , waux  +  8 * k , this_scale , Laux  +  8 * k );
34193452                    }
3453+                     aux_index [k ] =  grid_index ;
34203454                }
34213455                float  sumqx  =  0 , sumq2  =  0 ;
34223456                for  (int  i  =  0 ; i  <  16 ; ++ i ) {
@@ -3429,35 +3463,45 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
34293463                    scale  =  sumqx /sumq2 ; best  =  scale * sumqx ;
34303464                    for  (int  i  =  0 ; i  <  16 ; ++ i ) L [i ] =  Laux [i ];
34313465                    for  (int  k  =  0 ; k  <   2 ; ++ k ) is_on_grid [k ] =  is_on_grid_aux [k ];
3466+                     for  (int  k  =  0 ; k  <   2 ; ++ k ) index [k ] =  aux_index [k ];
34323467                }
34333468            }
3434-             int  n_not_ongrid  =  0 ;
3435-             for  (int  k  =  0 ; k  <  2 ; ++ k ) if  (!is_on_grid [k ]) ++ n_not_ongrid ;
3436-             if  (n_not_ongrid  >  0  &&  scale  >  0 ) {
3437-                 float  id  =  1 /scale ;
3438-                 for  (int  k  =  0 ; k  <  2 ; ++ k ) {
3439-                     if  (is_on_grid [k ]) continue ;
3440-                     uint16_t  u  =  0 ;
3441-                     for  (int  i  =  0 ; i  <  8 ; ++ i ) {
3442-                         int  l  =  nearest_int (0.5f * (id * xval [8 * k + i ]- 1 ));
3443-                         l  =  MAX (0 , MIN (kMaxQ - 1 , l ));
3444-                         u  |= (l  << 2 * i );
3445-                         L [8 * k  +  i ] =  l ;
3469+             if  (scale ) {
3470+                 for  (int  iter  =  0 ; iter  <  3 ; ++ iter ) {
3471+                     float  id  =  1 /scale ;
3472+                     bool  changed  =  false;
3473+                     for  (int  k  =  0 ; k  <  2 ; ++ k ) {
3474+                         uint16_t  u  =  0 ;
3475+                         for  (int  i  =  0 ; i  <  8 ; ++ i ) {
3476+                             int  l  =  nearest_int (0.5f * (id * xval [8 * k + i ]- 1 ));
3477+                             l  =  MAX (0 , MIN (kMaxQ - 1 , l ));
3478+                             u  |= (l  << 2 * i );
3479+                             Laux [8 * k  +  i ] =  l ;
3480+                         }
3481+                         int  grid_index  =  kmap_q2xs [u ];
3482+                         if  (grid_index  <  0 ) {
3483+                             const  uint16_t  *  neighbours  =  kneighbors_q2xs  -  kmap_q2xs [u ] -  1 ;
3484+                             grid_index  =  iq2_find_best_neighbour (neighbours , kgrid_q2xs , xval  +  8 * k , waux  +  8 * k , scale , Laux  +  8 * k );
3485+                         }
3486+                         aux_index [k ] =  grid_index ;
3487+                         if  (grid_index  !=  index [k ]) changed  =  true;
34463488                    }
3447-                     int  grid_index  =  kmap_q2xs [u ];
3448-                     if  (grid_index  <  0 ) {
3449-                         const  uint16_t  *  neighbours  =  kneighbors_q2xs  -  kmap_q2xs [u ] -  1 ;
3450-                         grid_index  =  iq2_find_best_neighbour (neighbours , kgrid_q2xs , xval  +  8 * k , waux  +  8 * k , scale , L  +  8 * k );
3489+                     if  (!changed ) break ;
3490+                     float  sumqx  =  0 , sumq2  =  0 ;
3491+                     for  (int  i  =  0 ; i  <  16 ; ++ i ) {
3492+                         float  w  =  weight [i ];
3493+                         float  q  =  2 * Laux [i ] +  1 ;
3494+                         sumqx  +=  w * xval [i ]* q ;
3495+                         sumq2  +=  w * q * q ;
34513496                    }
3497+                     if  (sumq2  >  0  &&  sumqx * sumqx  >  best * sumq2 ) {
3498+                         scale  =  sumqx /sumq2 ;
3499+                         best  =  scale * sumqx ;
3500+                         memcpy (L , Laux , 16 );
3501+                         for  (int  k  =  0 ; k  <  2 ; ++ k ) index [k ] =  aux_index [k ];
3502+                     }
3503+                     else  break ;
34523504                }
3453-                 float  sumqx  =  0 , sumq2  =  0 ;
3454-                 for  (int  i  =  0 ; i  <  16 ; ++ i ) {
3455-                     float  w  =  weight [i ];
3456-                     float  q  =  2 * L [i ] +  1 ;
3457-                     sumqx  +=  w * xval [i ]* q ;
3458-                     sumq2  +=  w * q * q ;
3459-                 }
3460-                 if  (sumq2  >  0 ) scale  =  sumqx /sumq2 ;
34613505            }
34623506            if  (scale  <  0 ) {
34633507                scale  =  - scale ;
@@ -3488,13 +3532,34 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
34883532        float  d  =  max_scale /31 ;
34893533        y [ibl ].d  =  GGML_FP32_TO_FP16 (d );
34903534        float  id  =  1 /d ;
3535+         float  sumqx  =  0 , sumq2  =  0 ;
34913536        for  (int  ib  =  0 ; ib  <  QK_K /16 ; ++ ib ) {
34923537            int  l  =  nearest_int (0.5f * (id * scales [ib ]- 1 ));
34933538            l  =  MAX (0 , MIN (15 , l ));
34943539            if  (ib %2  ==  0 ) y [ibl ].scales [ib /2 ] =  l ;
34953540            else  y [ibl ].scales [ib /2 ] |= (l  << 4 );
3541+             l  =  2 * l  +  1 ;
3542+             const  float  *  xb  =  xbl  +  16 * ib ;
3543+             if  (quant_weights ) {
3544+                 const  float  *  qw  =  quant_weights  +  QK_K * ibl  +  16 * ib ;
3545+                 for  (int  i  =  0 ; i  <  16 ; ++ i ) weight [i ] =  qw [i ] *  sqrtf (sigma2  +  xb [i ]* xb [i ]);
3546+             } else  {
3547+                 for  (int  i  =  0 ; i  <  16 ; ++ i ) weight [i ] =  0.25f * sigma2  +  xb [i ]* xb [i ];
3548+             }
3549+             for  (int  k  =  0 ; k  <  2 ; ++ k ) {
3550+                 int  grid_index  =  q2 [2 * ib + k ] &  511 ;
3551+                 const  int8_t  *  grid  =  (const  int8_t  * )(iq2xs_grid  +  grid_index );
3552+                 const  uint8_t   signs  =  ksigns_iq2xs [q2 [2 * ib + k ] >> 9 ];
3553+                 for  (int  j  =  0 ; j  <  8 ; ++ j ) {
3554+                     float  w  =  weight [8 * k + j ];
3555+                     float  q  =  0.125f * l * grid [j ]* (signs  &  kmask_iq2xs [j ] ? -1.f  : 1.f );
3556+                     sumqx  +=  w * q * xb [8 * k + j ];
3557+                     sumq2  +=  w * q * q ;
3558+                 }
3559+             }
34963560        }
34973561        memcpy (y [ibl ].qs , q2 , QK_K /4 );
3562+         if  (sumq2  >  0 ) y [ibl ].d  =  GGML_FP32_TO_FP16 (1.05f * sumqx /sumq2 );
34983563
34993564    }
35003565}
@@ -4300,12 +4365,6 @@ static int iq1_find_best_neighbour2(const uint16_t * GGML_RESTRICT neighbours, c
43004365    return  grid_index ;
43014366}
43024367
4303- static  int  iq1_sort_helper (const  void  *  left , const  void  *  right ) {
4304-     const  float  *  l  =  left ;
4305-     const  float  *  r  =  right ;
4306-     return  * l  <  * r  ? -1  : * l  >  * r  ? 1  : 0 ;
4307- }
4308- 
43094368#define  IQ1S_BLOCK_SIZE  32
43104369#define  IQ1M_BLOCK_SIZE  16
43114370static  void  quantize_row_iq1_s_impl (const  float  *  GGML_RESTRICT  x , void  *  GGML_RESTRICT  vy , int64_t  n , const  float  *  GGML_RESTRICT  quant_weights ,
0 commit comments