@@ -3319,6 +3319,12 @@ static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML
33193319 }
33203320}
33213321
3322+ static int iq1_sort_helper (const void * left , const void * right ) {
3323+ const float * l = left ;
3324+ const float * r = right ;
3325+ return * l < * r ? -1 : * l > * r ? 1 : 0 ;
3326+ }
3327+
33223328static void quantize_row_iq2_xs_impl (const float * GGML_RESTRICT x , void * GGML_RESTRICT vy , int64_t n , const float * GGML_RESTRICT quant_weights ) {
33233329
33243330 const int gindex = iq2_data_index (GGML_TYPE_IQ2_XS );
@@ -3349,6 +3355,9 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
33493355 bool is_on_grid_aux [2 ];
33503356 uint8_t block_signs [2 ];
33513357 uint16_t q2 [2 * (QK_K /16 )];
3358+ uint16_t index [2 ], aux_index [2 ];
3359+ float sumx [17 ], sumw [17 ], pairs [32 ];
3360+ int * int_pairs = (int * )(pairs + 1 );
33523361
33533362 for (int ibl = 0 ; ibl < nbl ; ++ ibl ) {
33543363
@@ -3397,11 +3406,35 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
33973406 memset (L , 0 , 16 );
33983407 continue ;
33993408 }
3400- float best = 0 ;
3401- float scale = max /(2 * kMaxQ - 1 );
3409+ for (int j = 0 ; j < 16 ; ++ j ) {
3410+ pairs [2 * j ] = xval [j ];
3411+ int_pairs [2 * j ] = j ;
3412+ }
3413+ qsort (pairs , 16 , 2 * sizeof (float ), iq1_sort_helper );
3414+ {
3415+ sumx [0 ] = sumw [0 ] = 0 ;
3416+ for (int j = 0 ; j < 16 ; ++ j ) {
3417+ int i = int_pairs [2 * j ];
3418+ sumx [j + 1 ] = sumx [j ] + weight [i ]* xval [i ];
3419+ sumw [j + 1 ] = sumw [j ] + weight [i ];
3420+ }
3421+ }
3422+ float best = 0 , scale = 0 ;
3423+ for (int i1 = 0 ; i1 <= 16 ; ++ i1 ) {
3424+ for (int i2 = i1 ; i2 <= 16 ; ++ i2 ) {
3425+ float sumqx = (sumx [i1 ] - sumx [0 ])* 1 + (sumx [i2 ] - sumx [i1 ])* 3 + (sumx [16 ] - sumx [i2 ])* 5 ;
3426+ float sumq2 = (sumw [i1 ] - sumw [0 ])* 1 + (sumw [i2 ] - sumw [i1 ])* 9 + (sumw [16 ] - sumw [i2 ])* 25 ;
3427+ if (sumq2 > 0 && sumqx * sumqx > best * sumq2 ) {
3428+ scale = sumqx /sumq2 ; best = scale * sumqx ;
3429+ }
3430+ }
3431+ }
3432+ best = 0 ;
3433+ float eff_max = scale * (2 * kMaxQ - 1 );
34023434 is_on_grid [0 ] = is_on_grid [1 ] = true;
3403- for (int is = -9 ; is <= 9 ; ++ is ) {
3404- float id = (2 * kMaxQ - 1 + is * 0.1f )/max ;
3435+ index [0 ] = index [1 ] = 0 ;
3436+ for (int is = -7 ; is <= 7 ; ++ is ) {
3437+ float id = (2 * kMaxQ - 1 + is * 0.1f )/eff_max ;
34053438 float this_scale = 1 /id ;
34063439 for (int k = 0 ; k < 2 ; ++ k ) {
34073440 for (int i = 0 ; i < 8 ; ++ i ) {
@@ -3417,6 +3450,7 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
34173450 const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs [u ] - 1 ;
34183451 grid_index = iq2_find_best_neighbour (neighbours , kgrid_q2xs , xval + 8 * k , waux + 8 * k , this_scale , Laux + 8 * k );
34193452 }
3453+ aux_index [k ] = grid_index ;
34203454 }
34213455 float sumqx = 0 , sumq2 = 0 ;
34223456 for (int i = 0 ; i < 16 ; ++ i ) {
@@ -3429,35 +3463,45 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
34293463 scale = sumqx /sumq2 ; best = scale * sumqx ;
34303464 for (int i = 0 ; i < 16 ; ++ i ) L [i ] = Laux [i ];
34313465 for (int k = 0 ; k < 2 ; ++ k ) is_on_grid [k ] = is_on_grid_aux [k ];
3466+ for (int k = 0 ; k < 2 ; ++ k ) index [k ] = aux_index [k ];
34323467 }
34333468 }
3434- int n_not_ongrid = 0 ;
3435- for (int k = 0 ; k < 2 ; ++ k ) if (!is_on_grid [k ]) ++ n_not_ongrid ;
3436- if (n_not_ongrid > 0 && scale > 0 ) {
3437- float id = 1 /scale ;
3438- for (int k = 0 ; k < 2 ; ++ k ) {
3439- if (is_on_grid [k ]) continue ;
3440- uint16_t u = 0 ;
3441- for (int i = 0 ; i < 8 ; ++ i ) {
3442- int l = nearest_int (0.5f * (id * xval [8 * k + i ]- 1 ));
3443- l = MAX (0 , MIN (kMaxQ - 1 , l ));
3444- u |= (l << 2 * i );
3445- L [8 * k + i ] = l ;
3469+ if (scale ) {
3470+ for (int iter = 0 ; iter < 3 ; ++ iter ) {
3471+ float id = 1 /scale ;
3472+ bool changed = false;
3473+ for (int k = 0 ; k < 2 ; ++ k ) {
3474+ uint16_t u = 0 ;
3475+ for (int i = 0 ; i < 8 ; ++ i ) {
3476+ int l = nearest_int (0.5f * (id * xval [8 * k + i ]- 1 ));
3477+ l = MAX (0 , MIN (kMaxQ - 1 , l ));
3478+ u |= (l << 2 * i );
3479+ Laux [8 * k + i ] = l ;
3480+ }
3481+ int grid_index = kmap_q2xs [u ];
3482+ if (grid_index < 0 ) {
3483+ const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs [u ] - 1 ;
3484+ grid_index = iq2_find_best_neighbour (neighbours , kgrid_q2xs , xval + 8 * k , waux + 8 * k , scale , Laux + 8 * k );
3485+ }
3486+ aux_index [k ] = grid_index ;
3487+ if (grid_index != index [k ]) changed = true;
34463488 }
3447- int grid_index = kmap_q2xs [u ];
3448- if (grid_index < 0 ) {
3449- const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs [u ] - 1 ;
3450- grid_index = iq2_find_best_neighbour (neighbours , kgrid_q2xs , xval + 8 * k , waux + 8 * k , scale , L + 8 * k );
3489+ if (!changed ) break ;
3490+ float sumqx = 0 , sumq2 = 0 ;
3491+ for (int i = 0 ; i < 16 ; ++ i ) {
3492+ float w = weight [i ];
3493+ float q = 2 * Laux [i ] + 1 ;
3494+ sumqx += w * xval [i ]* q ;
3495+ sumq2 += w * q * q ;
34513496 }
3497+ if (sumq2 > 0 && sumqx * sumqx > best * sumq2 ) {
3498+ scale = sumqx /sumq2 ;
3499+ best = scale * sumqx ;
3500+ memcpy (L , Laux , 16 );
3501+ for (int k = 0 ; k < 2 ; ++ k ) index [k ] = aux_index [k ];
3502+ }
3503+ else break ;
34523504 }
3453- float sumqx = 0 , sumq2 = 0 ;
3454- for (int i = 0 ; i < 16 ; ++ i ) {
3455- float w = weight [i ];
3456- float q = 2 * L [i ] + 1 ;
3457- sumqx += w * xval [i ]* q ;
3458- sumq2 += w * q * q ;
3459- }
3460- if (sumq2 > 0 ) scale = sumqx /sumq2 ;
34613505 }
34623506 if (scale < 0 ) {
34633507 scale = - scale ;
@@ -3488,13 +3532,34 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_
34883532 float d = max_scale /31 ;
34893533 y [ibl ].d = GGML_FP32_TO_FP16 (d );
34903534 float id = 1 /d ;
3535+ float sumqx = 0 , sumq2 = 0 ;
34913536 for (int ib = 0 ; ib < QK_K /16 ; ++ ib ) {
34923537 int l = nearest_int (0.5f * (id * scales [ib ]- 1 ));
34933538 l = MAX (0 , MIN (15 , l ));
34943539 if (ib %2 == 0 ) y [ibl ].scales [ib /2 ] = l ;
34953540 else y [ibl ].scales [ib /2 ] |= (l << 4 );
3541+ l = 2 * l + 1 ;
3542+ const float * xb = xbl + 16 * ib ;
3543+ if (quant_weights ) {
3544+ const float * qw = quant_weights + QK_K * ibl + 16 * ib ;
3545+ for (int i = 0 ; i < 16 ; ++ i ) weight [i ] = qw [i ] * sqrtf (sigma2 + xb [i ]* xb [i ]);
3546+ } else {
3547+ for (int i = 0 ; i < 16 ; ++ i ) weight [i ] = 0.25f * sigma2 + xb [i ]* xb [i ];
3548+ }
3549+ for (int k = 0 ; k < 2 ; ++ k ) {
3550+ int grid_index = q2 [2 * ib + k ] & 511 ;
3551+ const int8_t * grid = (const int8_t * )(iq2xs_grid + grid_index );
3552+ const uint8_t signs = ksigns_iq2xs [q2 [2 * ib + k ] >> 9 ];
3553+ for (int j = 0 ; j < 8 ; ++ j ) {
3554+ float w = weight [8 * k + j ];
3555+ float q = 0.125f * l * grid [j ]* (signs & kmask_iq2xs [j ] ? -1.f : 1.f );
3556+ sumqx += w * q * xb [8 * k + j ];
3557+ sumq2 += w * q * q ;
3558+ }
3559+ }
34963560 }
34973561 memcpy (y [ibl ].qs , q2 , QK_K /4 );
3562+ if (sumq2 > 0 ) y [ibl ].d = GGML_FP32_TO_FP16 (1.05f * sumqx /sumq2 );
34983563
34993564 }
35003565}
@@ -4300,12 +4365,6 @@ static int iq1_find_best_neighbour2(const uint16_t * GGML_RESTRICT neighbours, c
43004365 return grid_index ;
43014366}
43024367
4303- static int iq1_sort_helper (const void * left , const void * right ) {
4304- const float * l = left ;
4305- const float * r = right ;
4306- return * l < * r ? -1 : * l > * r ? 1 : 0 ;
4307- }
4308-
43094368#define IQ1S_BLOCK_SIZE 32
43104369#define IQ1M_BLOCK_SIZE 16
43114370static void quantize_row_iq1_s_impl (const float * GGML_RESTRICT x , void * GGML_RESTRICT vy , int64_t n , const float * GGML_RESTRICT quant_weights ,
0 commit comments