@@ -3323,6 +3323,12 @@ static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML
33233323 }
33243324}
33253325
// qsort() comparator that orders elements in ascending order of the float
// stored at the start of each element. Only that leading float is read, so
// the elements being sorted may be larger records keyed by their first float.
//
// Returns -1 when the left key is smaller, 1 when it is larger, and 0
// otherwise. Both comparisons are false when either key is NaN, so that
// case also yields 0.
static int iq1_sort_helper(const void * left, const void * right) {
    const float lhs = *(const float *)left;
    const float rhs = *(const float *)right;
    if (lhs < rhs) {
        return -1;
    }
    if (lhs > rhs) {
        return 1;
    }
    return 0;
}
3331+
33263332static void quantize_row_iq2_xs_impl (const float * GGML_RESTRICT x , void * GGML_RESTRICT vy , int64_t n , const float * GGML_RESTRICT quant_weights ) {
33273333
33283334 const int gindex = iq2_data_index (GGML_TYPE_IQ2_XS );
@@ -3353,6 +3359,9 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
33533359 bool is_on_grid_aux [2 ];
33543360 uint8_t block_signs [2 ];
33553361 uint16_t q2 [2 * (QK_K /16 )];
3362+ uint16_t index [2 ], aux_index [2 ];
3363+ float sumx [17 ], sumw [17 ], pairs [32 ];
3364+ int * int_pairs = (int * )(pairs + 1 );
33563365
33573366 for (int ibl = 0 ; ibl < nbl ; ++ ibl ) {
33583367
@@ -3401,11 +3410,35 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
34013410 memset (L , 0 , 16 );
34023411 continue ;
34033412 }
3404- float best = 0 ;
3405- float scale = max /(2 * kMaxQ - 1 );
3413+ for (int j = 0 ; j < 16 ; ++ j ) {
3414+ pairs [2 * j ] = xval [j ];
3415+ int_pairs [2 * j ] = j ;
3416+ }
3417+ qsort (pairs , 16 , 2 * sizeof (float ), iq1_sort_helper );
3418+ {
3419+ sumx [0 ] = sumw [0 ] = 0 ;
3420+ for (int j = 0 ; j < 16 ; ++ j ) {
3421+ int i = int_pairs [2 * j ];
3422+ sumx [j + 1 ] = sumx [j ] + weight [i ]* xval [i ];
3423+ sumw [j + 1 ] = sumw [j ] + weight [i ];
3424+ }
3425+ }
3426+ float best = 0 , scale = 0 ;
3427+ for (int i1 = 0 ; i1 <= 16 ; ++ i1 ) {
3428+ for (int i2 = i1 ; i2 <= 16 ; ++ i2 ) {
3429+ float sumqx = (sumx [i1 ] - sumx [0 ])* 1 + (sumx [i2 ] - sumx [i1 ])* 3 + (sumx [16 ] - sumx [i2 ])* 5 ;
3430+ float sumq2 = (sumw [i1 ] - sumw [0 ])* 1 + (sumw [i2 ] - sumw [i1 ])* 9 + (sumw [16 ] - sumw [i2 ])* 25 ;
3431+ if (sumq2 > 0 && sumqx * sumqx > best * sumq2 ) {
3432+ scale = sumqx /sumq2 ; best = scale * sumqx ;
3433+ }
3434+ }
3435+ }
3436+ best = 0 ;
3437+ float eff_max = scale * (2 * kMaxQ - 1 );
34063438 is_on_grid [0 ] = is_on_grid [1 ] = true;
3407- for (int is = -9 ; is <= 9 ; ++ is ) {
3408- float id = (2 * kMaxQ - 1 + is * 0.1f )/max ;
3439+ index [0 ] = index [1 ] = 0 ;
3440+ for (int is = -7 ; is <= 7 ; ++ is ) {
3441+ float id = (2 * kMaxQ - 1 + is * 0.1f )/eff_max ;
34093442 float this_scale = 1 /id ;
34103443 for (int k = 0 ; k < 2 ; ++ k ) {
34113444 for (int i = 0 ; i < 8 ; ++ i ) {
@@ -3421,6 +3454,7 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
34213454 const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs [u ] - 1 ;
34223455 grid_index = iq2_find_best_neighbour (neighbours , kgrid_q2xs , xval + 8 * k , waux + 8 * k , this_scale , Laux + 8 * k );
34233456 }
3457+ aux_index [k ] = grid_index ;
34243458 }
34253459 float sumqx = 0 , sumq2 = 0 ;
34263460 for (int i = 0 ; i < 16 ; ++ i ) {
@@ -3433,35 +3467,45 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
34333467 scale = sumqx /sumq2 ; best = scale * sumqx ;
34343468 for (int i = 0 ; i < 16 ; ++ i ) L [i ] = Laux [i ];
34353469 for (int k = 0 ; k < 2 ; ++ k ) is_on_grid [k ] = is_on_grid_aux [k ];
3470+ for (int k = 0 ; k < 2 ; ++ k ) index [k ] = aux_index [k ];
34363471 }
34373472 }
3438- int n_not_ongrid = 0 ;
3439- for (int k = 0 ; k < 2 ; ++ k ) if (!is_on_grid [k ]) ++ n_not_ongrid ;
3440- if (n_not_ongrid > 0 && scale > 0 ) {
3441- float id = 1 /scale ;
3442- for (int k = 0 ; k < 2 ; ++ k ) {
3443- if (is_on_grid [k ]) continue ;
3444- uint16_t u = 0 ;
3445- for (int i = 0 ; i < 8 ; ++ i ) {
3446- int l = nearest_int (0.5f * (id * xval [8 * k + i ]- 1 ));
3447- l = MAX (0 , MIN (kMaxQ - 1 , l ));
3448- u |= (l << 2 * i );
3449- L [8 * k + i ] = l ;
3473+ if (scale ) {
3474+ for (int iter = 0 ; iter < 3 ; ++ iter ) {
3475+ float id = 1 /scale ;
3476+ bool changed = false;
3477+ for (int k = 0 ; k < 2 ; ++ k ) {
3478+ uint16_t u = 0 ;
3479+ for (int i = 0 ; i < 8 ; ++ i ) {
3480+ int l = nearest_int (0.5f * (id * xval [8 * k + i ]- 1 ));
3481+ l = MAX (0 , MIN (kMaxQ - 1 , l ));
3482+ u |= (l << 2 * i );
3483+ Laux [8 * k + i ] = l ;
3484+ }
3485+ int grid_index = kmap_q2xs [u ];
3486+ if (grid_index < 0 ) {
3487+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs [u ] - 1 ;
3488+ grid_index = iq2_find_best_neighbour (neighbours , kgrid_q2xs , xval + 8 * k , waux + 8 * k , scale , Laux + 8 * k );
3489+ }
3490+ aux_index [k ] = grid_index ;
3491+ if (grid_index != index [k ]) changed = true;
34503492 }
3451- int grid_index = kmap_q2xs [u ];
3452- if (grid_index < 0 ) {
3453- const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs [u ] - 1 ;
3454- grid_index = iq2_find_best_neighbour (neighbours , kgrid_q2xs , xval + 8 * k , waux + 8 * k , scale , L + 8 * k );
3493+ if (!changed ) break ;
3494+ float sumqx = 0 , sumq2 = 0 ;
3495+ for (int i = 0 ; i < 16 ; ++ i ) {
3496+ float w = weight [i ];
3497+ float q = 2 * Laux [i ] + 1 ;
3498+ sumqx += w * xval [i ]* q ;
3499+ sumq2 += w * q * q ;
34553500 }
3501+ if (sumq2 > 0 && sumqx * sumqx > best * sumq2 ) {
3502+ scale = sumqx /sumq2 ;
3503+ best = scale * sumqx ;
3504+ memcpy (L , Laux , 16 );
3505+ for (int k = 0 ; k < 2 ; ++ k ) index [k ] = aux_index [k ];
3506+ }
3507+ else break ;
34563508 }
3457- float sumqx = 0 , sumq2 = 0 ;
3458- for (int i = 0 ; i < 16 ; ++ i ) {
3459- float w = weight [i ];
3460- float q = 2 * L [i ] + 1 ;
3461- sumqx += w * xval [i ]* q ;
3462- sumq2 += w * q * q ;
3463- }
3464- if (sumq2 > 0 ) scale = sumqx /sumq2 ;
34653509 }
34663510 if (scale < 0 ) {
34673511 scale = - scale ;
@@ -3492,13 +3536,34 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
34923536 float d = max_scale /31 ;
34933537 y [ibl ].d = GGML_FP32_TO_FP16 (d );
34943538 float id = 1 /d ;
3539+ float sumqx = 0 , sumq2 = 0 ;
34953540 for (int ib = 0 ; ib < QK_K /16 ; ++ ib ) {
34963541 int l = nearest_int (0.5f * (id * scales [ib ]- 1 ));
34973542 l = MAX (0 , MIN (15 , l ));
34983543 if (ib %2 == 0 ) y [ibl ].scales [ib /2 ] = l ;
34993544 else y [ibl ].scales [ib /2 ] |= (l << 4 );
3545+ l = 2 * l + 1 ;
3546+ const float * xb = xbl + 16 * ib ;
3547+ if (quant_weights ) {
3548+ const float * qw = quant_weights + QK_K * ibl + 16 * ib ;
3549+ for (int i = 0 ; i < 16 ; ++ i ) weight [i ] = qw [i ] * sqrtf (sigma2 + xb [i ]* xb [i ]);
3550+ } else {
3551+ for (int i = 0 ; i < 16 ; ++ i ) weight [i ] = 0.25f * sigma2 + xb [i ]* xb [i ];
3552+ }
3553+ for (int k = 0 ; k < 2 ; ++ k ) {
3554+ int grid_index = q2 [2 * ib + k ] & 511 ;
3555+ const int8_t * grid = (const int8_t * )(iq2xs_grid + grid_index );
3556+ const uint8_t signs = ksigns_iq2xs [q2 [2 * ib + k ] >> 9 ];
3557+ for (int j = 0 ; j < 8 ; ++ j ) {
3558+ float w = weight [8 * k + j ];
3559+ float q = 0.125f * l * grid [j ]* (signs & kmask_iq2xs [j ] ? -1.f : 1.f );
3560+ sumqx += w * q * xb [8 * k + j ];
3561+ sumq2 += w * q * q ;
3562+ }
3563+ }
35003564 }
35013565 memcpy (y [ibl ].qs , q2 , QK_K /4 );
3566+ if (sumq2 > 0 ) y [ibl ].d = GGML_FP32_TO_FP16 (1.05f * sumqx /sumq2 );
35023567
35033568 }
35043569}
@@ -4304,12 +4369,6 @@ static int iq1_find_best_neighbour2(const uint16_t * GGML_RESTRICT neighbours, c
43044369 return grid_index ;
43054370}
43064371
4307- static int iq1_sort_helper (const void * left , const void * right ) {
4308- const float * l = left ;
4309- const float * r = right ;
4310- return * l < * r ? -1 : * l > * r ? 1 : 0 ;
4311- }
4312-
43134372#define IQ1S_BLOCK_SIZE 32
43144373#define IQ1M_BLOCK_SIZE 16
43154374static void quantize_row_iq1_s_impl (const float * GGML_RESTRICT x , void * GGML_RESTRICT vy , int64_t n , const float * GGML_RESTRICT quant_weights ,
0 commit comments