@@ -4481,10 +4481,13 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
44814481 const float x_p [3 ] = {-1 + IQ1M_DELTA , IQ1M_DELTA , 1 + IQ1M_DELTA };
44824482 const float x_m [3 ] = {-1 - IQ1M_DELTA , - IQ1M_DELTA , 1 - IQ1M_DELTA };
44834483 const uint8_t masks [4 ] = {0x00 , 0x80 , 0x08 , 0x88 };
4484+ float all_sigma2 [QK_K /32 ];
44844485
44854486 int * idx = (int * )(pairs + 1 );
44864487
44874488 float sumqx [4 ], sumq2 [4 ];
4489+ float sumw1 [IQ1M_BLOCK_SIZE + 1 ], sumw2 [IQ1M_BLOCK_SIZE + 1 ];
4490+ float sumx1 [IQ1M_BLOCK_SIZE + 1 ], sumx2 [IQ1M_BLOCK_SIZE + 1 ];
44884491
44894492 iq1m_scale_t s ;
44904493 const float * xx ;
@@ -4497,11 +4500,18 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
44974500 float max_scale = 0 ;
44984501
44994502 const float * xbl = x + QK_K * ibl ;
4500- float sumx2 = 0 ;
4501- for (int i = 0 ; i < QK_K ; ++ i ) sumx2 += xbl [i ]* xbl [i ];
4502- float sigma2 = 2 * sumx2 /QK_K ;
4503+ for (int ib = 0 ; ib < QK_K /32 ; ++ ib ) {
4504+ const float * xb = xbl + 32 * ib ;
4505+ float sumx2 = 0 ;
4506+ for (int i = 0 ; i < 32 ; ++ i ) sumx2 += xb [i ]* xb [i ];
4507+ all_sigma2 [ib ] = 1.5f * sumx2 /32 ;
4508+ }
4509+ //float sumx2 = 0;
4510+ //for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
4511+ //float sigma2 = 1.5f*sumx2/QK_K;
45034512
45044513 for (int ib = 0 ; ib < QK_K /block_size ; ++ ib ) {
4514+ float sigma2 = all_sigma2 [ib /2 ];
45054515 const float * xb = xbl + block_size * ib ;
45064516 if (quant_weights ) {
45074517 const float * qw = quant_weights + QK_K * ibl + block_size * ib ;
@@ -4510,12 +4520,21 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45104520 for (int i = 0 ; i < block_size ; ++ i ) weight [i ] = xb [i ]* xb [i ];
45114521 }
45124522 float max = fabsf (xb [0 ]);
4513- for (int i = 1 ; i < block_size ; ++ i ) max = MAX (max , fabsf (xb [i ]));
4523+ float sumwx = 0 ;
4524+ for (int i = 1 ; i < block_size ; ++ i ) {
4525+ float ax = fabsf (xb [i ]);
4526+ max = MAX (max , ax );
4527+ sumwx += weight [i ]* ax ;
4528+ }
45144529 if (max < GROUP_MAX_EPS_IQ1_M ) {
45154530 scales [ib ] = 0 ;
45164531 memset (L , 1 , block_size );
45174532 continue ;
45184533 }
4534+ if (sumwx == 0 ) {
4535+ // weight is zero everywhere where xb is not zero => ignore
4536+ for (int i = 0 ; i < block_size ; ++ i ) weight [i ] = xb [i ]* xb [i ];
4537+ }
45194538 // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
45204539 // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
45214540 // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
@@ -4527,82 +4546,45 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45274546 idx [2 * j ] = j ;
45284547 }
45294548 qsort (pairs , block_size , 2 * sizeof (float ), iq1_sort_helper );
4530- float best_score = - FLT_MIN , scale = max ;
4549+ sumw1 [0 ] = sumw2 [0 ] = sumx1 [0 ] = sumx2 [0 ] = 0 ;
4550+ for (int j = 0 ; j < block_size ; ++ j ) {
4551+ int i = idx [2 * j ];
4552+ if (i < block_size /2 ) {
4553+ sumw1 [j + 1 ] = sumw1 [j ] + weight [i ];
4554+ sumx1 [j + 1 ] = sumx1 [j ] + weight [i ]* xb [i ];
4555+ sumw2 [j + 1 ] = sumw2 [j ];
4556+ sumx2 [j + 1 ] = sumx2 [j ];
4557+ } else {
4558+ sumw2 [j + 1 ] = sumw2 [j ] + weight [i ];
4559+ sumx2 [j + 1 ] = sumx2 [j ] + weight [i ]* xb [i ];
4560+ sumw1 [j + 1 ] = sumw1 [j ];
4561+ sumx1 [j + 1 ] = sumx1 [j ];
4562+ }
4563+ }
4564+ float best_score = 0 , scale = 0.f ;
45314565 int besti1 = -1 , besti2 = -1 , best_k = -1 ;
45324566 // 0: +, +
45334567 // 1: +, -
45344568 // 2: -, +
45354569 // 3: -, -
45364570 for (int i1 = 0 ; i1 <= block_size ; ++ i1 ) {
45374571 for (int i2 = i1 ; i2 <= block_size ; ++ i2 ) {
4538- memset (sumqx , 0 , 4 * sizeof (float ));
4539- memset (sumq2 , 0 , 4 * sizeof (float ));
4540- for (int j = 0 ; j < i1 ; ++ j ) {
4541- int i = idx [2 * j ];
4542- if (i < block_size /2 ) {
4543- sumqx [0 ] += weight [i ]* x_p [0 ]* xb [i ];
4544- sumqx [1 ] += weight [i ]* x_p [0 ]* xb [i ];
4545- sumqx [2 ] += weight [i ]* x_m [0 ]* xb [i ];
4546- sumqx [3 ] += weight [i ]* x_m [0 ]* xb [i ];
4547- sumq2 [0 ] += weight [i ]* x_p [0 ]* x_p [0 ];
4548- sumq2 [1 ] += weight [i ]* x_p [0 ]* x_p [0 ];
4549- sumq2 [2 ] += weight [i ]* x_m [0 ]* x_m [0 ];
4550- sumq2 [3 ] += weight [i ]* x_m [0 ]* x_m [0 ];
4551- } else {
4552- sumqx [0 ] += weight [i ]* x_p [0 ]* xb [i ];
4553- sumqx [2 ] += weight [i ]* x_p [0 ]* xb [i ];
4554- sumqx [1 ] += weight [i ]* x_m [0 ]* xb [i ];
4555- sumqx [3 ] += weight [i ]* x_m [0 ]* xb [i ];
4556- sumq2 [0 ] += weight [i ]* x_p [0 ]* x_p [0 ];
4557- sumq2 [2 ] += weight [i ]* x_p [0 ]* x_p [0 ];
4558- sumq2 [1 ] += weight [i ]* x_m [0 ]* x_m [0 ];
4559- sumq2 [3 ] += weight [i ]* x_m [0 ]* x_m [0 ];
4560- }
4561- }
4562- for (int j = i1 ; j < i2 ; ++ j ) {
4563- int i = idx [2 * j ];
4564- if (i < block_size /2 ) {
4565- sumqx [0 ] += weight [i ]* x_p [1 ]* xb [i ];
4566- sumqx [1 ] += weight [i ]* x_p [1 ]* xb [i ];
4567- sumqx [2 ] += weight [i ]* x_m [1 ]* xb [i ];
4568- sumqx [3 ] += weight [i ]* x_m [1 ]* xb [i ];
4569- sumq2 [0 ] += weight [i ]* x_p [1 ]* x_p [1 ];
4570- sumq2 [1 ] += weight [i ]* x_p [1 ]* x_p [1 ];
4571- sumq2 [2 ] += weight [i ]* x_m [1 ]* x_m [1 ];
4572- sumq2 [3 ] += weight [i ]* x_m [1 ]* x_m [1 ];
4573- } else {
4574- sumqx [0 ] += weight [i ]* x_p [1 ]* xb [i ];
4575- sumqx [2 ] += weight [i ]* x_p [1 ]* xb [i ];
4576- sumqx [1 ] += weight [i ]* x_m [1 ]* xb [i ];
4577- sumqx [3 ] += weight [i ]* x_m [1 ]* xb [i ];
4578- sumq2 [0 ] += weight [i ]* x_p [1 ]* x_p [1 ];
4579- sumq2 [2 ] += weight [i ]* x_p [1 ]* x_p [1 ];
4580- sumq2 [1 ] += weight [i ]* x_m [1 ]* x_m [1 ];
4581- sumq2 [3 ] += weight [i ]* x_m [1 ]* x_m [1 ];
4582- }
4583- }
4584- for (int j = i2 ; j < block_size ; ++ j ) {
4585- int i = idx [2 * j ];
4586- if (i < block_size /2 ) {
4587- sumqx [0 ] += weight [i ]* x_p [2 ]* xb [i ];
4588- sumqx [1 ] += weight [i ]* x_p [2 ]* xb [i ];
4589- sumqx [2 ] += weight [i ]* x_m [2 ]* xb [i ];
4590- sumqx [3 ] += weight [i ]* x_m [2 ]* xb [i ];
4591- sumq2 [0 ] += weight [i ]* x_p [2 ]* x_p [2 ];
4592- sumq2 [1 ] += weight [i ]* x_p [2 ]* x_p [2 ];
4593- sumq2 [2 ] += weight [i ]* x_m [2 ]* x_m [2 ];
4594- sumq2 [3 ] += weight [i ]* x_m [2 ]* x_m [2 ];
4595- } else {
4596- sumqx [0 ] += weight [i ]* x_p [2 ]* xb [i ];
4597- sumqx [2 ] += weight [i ]* x_p [2 ]* xb [i ];
4598- sumqx [1 ] += weight [i ]* x_m [2 ]* xb [i ];
4599- sumqx [3 ] += weight [i ]* x_m [2 ]* xb [i ];
4600- sumq2 [0 ] += weight [i ]* x_p [2 ]* x_p [2 ];
4601- sumq2 [2 ] += weight [i ]* x_p [2 ]* x_p [2 ];
4602- sumq2 [1 ] += weight [i ]* x_m [2 ]* x_m [2 ];
4603- sumq2 [3 ] += weight [i ]* x_m [2 ]* x_m [2 ];
4604- }
4605- }
4572+ sumqx [0 ] = (sumx1 [i1 ] - sumx1 [0 ])* x_p [0 ] + (sumx1 [i2 ] - sumx1 [i1 ])* x_p [1 ] + (sumx1 [block_size ]- sumx1 [i2 ])* x_p [2 ] +
4573+ (sumx2 [i1 ] - sumx2 [0 ])* x_p [0 ] + (sumx2 [i2 ] - sumx2 [i1 ])* x_p [1 ] + (sumx2 [block_size ]- sumx2 [i2 ])* x_p [2 ];
4574+ sumqx [1 ] = (sumx1 [i1 ] - sumx1 [0 ])* x_p [0 ] + (sumx1 [i2 ] - sumx1 [i1 ])* x_p [1 ] + (sumx1 [block_size ]- sumx1 [i2 ])* x_p [2 ] +
4575+ (sumx2 [i1 ] - sumx2 [0 ])* x_m [0 ] + (sumx2 [i2 ] - sumx2 [i1 ])* x_m [1 ] + (sumx2 [block_size ]- sumx2 [i2 ])* x_m [2 ];
4576+ sumqx [2 ] = (sumx1 [i1 ] - sumx1 [0 ])* x_m [0 ] + (sumx1 [i2 ] - sumx1 [i1 ])* x_m [1 ] + (sumx1 [block_size ]- sumx1 [i2 ])* x_m [2 ] +
4577+ (sumx2 [i1 ] - sumx2 [0 ])* x_p [0 ] + (sumx2 [i2 ] - sumx2 [i1 ])* x_p [1 ] + (sumx2 [block_size ]- sumx2 [i2 ])* x_p [2 ];
4578+ sumqx [3 ] = (sumx1 [i1 ] - sumx1 [0 ])* x_m [0 ] + (sumx1 [i2 ] - sumx1 [i1 ])* x_m [1 ] + (sumx1 [block_size ]- sumx1 [i2 ])* x_m [2 ] +
4579+ (sumx2 [i1 ] - sumx2 [0 ])* x_m [0 ] + (sumx2 [i2 ] - sumx2 [i1 ])* x_m [1 ] + (sumx2 [block_size ]- sumx2 [i2 ])* x_m [2 ];
4580+ sumq2 [0 ] = (sumw1 [i1 ] - sumw1 [0 ])* x_p [0 ]* x_p [0 ] + (sumw1 [i2 ] - sumw1 [i1 ])* x_p [1 ]* x_p [1 ] + (sumw1 [block_size ]- sumw1 [i2 ])* x_p [2 ]* x_p [2 ] +
4581+ (sumw2 [i1 ] - sumw2 [0 ])* x_p [0 ]* x_p [0 ] + (sumw2 [i2 ] - sumw2 [i1 ])* x_p [1 ]* x_p [1 ] + (sumw2 [block_size ]- sumw2 [i2 ])* x_p [2 ]* x_p [2 ];
4582+ sumq2 [1 ] = (sumw1 [i1 ] - sumw1 [0 ])* x_p [0 ]* x_p [0 ] + (sumw1 [i2 ] - sumw1 [i1 ])* x_p [1 ]* x_p [1 ] + (sumw1 [block_size ]- sumw1 [i2 ])* x_p [2 ]* x_p [2 ] +
4583+ (sumw2 [i1 ] - sumw2 [0 ])* x_m [0 ]* x_m [0 ] + (sumw2 [i2 ] - sumw2 [i1 ])* x_m [1 ]* x_m [1 ] + (sumw2 [block_size ]- sumw2 [i2 ])* x_m [2 ]* x_m [2 ];
4584+ sumq2 [2 ] = (sumw1 [i1 ] - sumw1 [0 ])* x_m [0 ]* x_m [0 ] + (sumw1 [i2 ] - sumw1 [i1 ])* x_m [1 ]* x_m [1 ] + (sumw1 [block_size ]- sumw1 [i2 ])* x_m [2 ]* x_m [2 ] +
4585+ (sumw2 [i1 ] - sumw2 [0 ])* x_p [0 ]* x_p [0 ] + (sumw2 [i2 ] - sumw2 [i1 ])* x_p [1 ]* x_p [1 ] + (sumw2 [block_size ]- sumw2 [i2 ])* x_p [2 ]* x_p [2 ];
4586+ sumq2 [3 ] = (sumw1 [i1 ] - sumw1 [0 ])* x_m [0 ]* x_m [0 ] + (sumw1 [i2 ] - sumw1 [i1 ])* x_m [1 ]* x_m [1 ] + (sumw1 [block_size ]- sumw1 [i2 ])* x_m [2 ]* x_m [2 ] +
4587+ (sumw2 [i1 ] - sumw2 [0 ])* x_m [0 ]* x_m [0 ] + (sumw2 [i2 ] - sumw2 [i1 ])* x_m [1 ]* x_m [1 ] + (sumw2 [block_size ]- sumw2 [i2 ])* x_m [2 ]* x_m [2 ];
46064588 for (int k = 0 ; k < 4 ; ++ k ) {
46074589 if (sumq2 [k ] > 0 && sumqx [k ]* sumqx [k ] > best_score * sumq2 [k ]) {
46084590 scale = sumqx [k ]/sumq2 [k ]; best_score = scale * sumqx [k ];
@@ -4636,19 +4618,34 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
46364618 index [k ] = grid_index ;
46374619 }
46384620 if (!all_on_grid ) {
4639- float sumqx_f = 0 , sumq2_f = 0 ;
4640- for (int k = 0 ; k < block_size /8 ; ++ k ) {
4641- if (k == 0 ) xx = best_k < 2 ? x_p : x_m ;
4642- else xx = best_k %2 == 0 ? x_p : x_m ;
4643- const int8_t * pg = (const int8_t * )(kgrid_q2xs + index [k ]);
4644- for (int j = 0 ; j < 8 ; ++ j ) {
4645- float w = weight [8 * k + j ];
4646- float q = xx [(pg [j ] - 1 )/2 ];
4647- sumqx_f += w * q * xb [8 * k + j ];
4648- sumq2_f += w * q * q ;
4621+ sumqx [0 ] = sumqx [1 ] = sumqx [2 ] = sumqx [3 ] = 0 ;
4622+ sumq2 [0 ] = sumq2 [1 ] = sumq2 [2 ] = sumq2 [3 ] = 0 ;
4623+ for (int j = 0 ; j < block_size ; ++ j ) {
4624+ float w = weight [j ];
4625+ float qp = x_p [L [j ]];
4626+ float qm = x_m [L [j ]];
4627+ sumqx [0 ] += w * xb [j ]* qp ;
4628+ sumq2 [0 ] += w * qp * qp ;
4629+ sumqx [3 ] += w * xb [j ]* qm ;
4630+ sumq2 [3 ] += w * qm * qm ;
4631+ if (j < 8 ) {
4632+ sumqx [1 ] += w * xb [j ]* qp ;
4633+ sumq2 [1 ] += w * qp * qp ;
4634+ sumqx [2 ] += w * xb [j ]* qm ;
4635+ sumq2 [2 ] += w * qm * qm ;
4636+ } else {
4637+ sumqx [2 ] += w * xb [j ]* qp ;
4638+ sumq2 [2 ] += w * qp * qp ;
4639+ sumqx [1 ] += w * xb [j ]* qm ;
4640+ sumq2 [1 ] += w * qm * qm ;
4641+ }
4642+ }
4643+ best_score = 0 ;
4644+ for (int k = 0 ; k < 4 ; ++ k ) {
4645+ if (sumqx [k ] > 0 && sumq2 [k ] > 0 && sumqx [k ]* sumqx [k ] > best_score * sumq2 [k ]) {
4646+ scale = sumqx [k ]/sumq2 [k ]; best_score = scale * sumqx [k ]; best_k = k ;
46494647 }
46504648 }
4651- if (sumqx_f > 0 && sumq2_f > 0 ) scale = sumqx_f /sumq2_f ;
46524649 }
46534650 y [ibl ].qs [2 * ib + 0 ] = index [0 ] & 255 ;
46544651 y [ibl ].qs [2 * ib + 1 ] = index [1 ] & 255 ;
@@ -4668,6 +4665,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
46684665 float id = 1 /d ;
46694666 float sumqx_f = 0 , sumq2_f = 0 ;
46704667 for (int ib = 0 ; ib < QK_K /block_size ; ++ ib ) {
4668+ float sigma2 = all_sigma2 [ib /2 ];
46714669 int l = nearest_int (0.5f * (id * scales [ib + 0 ]- 1 ));
46724670 l = MAX (0 , MIN (7 , l ));
46734671 sc [ib /4 ] |= (l << 3 * (ib %4 ));
@@ -4692,7 +4690,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
46924690 }
46934691 }
46944692 if (sumq2_f > 0 ) d = sumqx_f /sumq2_f ;
4695- s .f16 = GGML_FP32_TO_FP16 (d * 1.1125f ); // 1.1125f is another fudge factor. Don't ask me why it is needed.
4693+ s .f16 = GGML_FP32_TO_FP16 (d * 1.085f ); // 1.085f is another fudge factor. Don't ask me why it is needed.
46964694 sc [0 ] |= ((s .u16 & 0x000f ) << 12 );
46974695 sc [1 ] |= ((s .u16 & 0x00f0 ) << 8 );
46984696 sc [2 ] |= ((s .u16 & 0x0f00 ) << 4 );
0 commit comments