@@ -4515,10 +4515,13 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45154515 const float x_p [3 ] = {-1 + IQ1M_DELTA , IQ1M_DELTA , 1 + IQ1M_DELTA };
45164516 const float x_m [3 ] = {-1 - IQ1M_DELTA , - IQ1M_DELTA , 1 - IQ1M_DELTA };
45174517 const uint8_t masks [4 ] = {0x00 , 0x80 , 0x08 , 0x88 };
4518+ float all_sigma2 [QK_K /32 ];
45184519
45194520 int * idx = (int * )(pairs + 1 );
45204521
45214522 float sumqx [4 ], sumq2 [4 ];
4523+ float sumw1 [IQ1M_BLOCK_SIZE + 1 ], sumw2 [IQ1M_BLOCK_SIZE + 1 ];
4524+ float sumx1 [IQ1M_BLOCK_SIZE + 1 ], sumx2 [IQ1M_BLOCK_SIZE + 1 ];
45224525
45234526 iq1m_scale_t s ;
45244527 const float * xx ;
@@ -4531,11 +4534,18 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45314534 float max_scale = 0 ;
45324535
45334536 const float * xbl = x + QK_K * ibl ;
4534- float sumx2 = 0 ;
4535- for (int i = 0 ; i < QK_K ; ++ i ) sumx2 += xbl [i ]* xbl [i ];
4536- float sigma2 = 2 * sumx2 /QK_K ;
4537+ for (int ib = 0 ; ib < QK_K /32 ; ++ ib ) {
4538+ const float * xb = xbl + 32 * ib ;
4539+ float sumx2 = 0 ;
4540+ for (int i = 0 ; i < 32 ; ++ i ) sumx2 += xb [i ]* xb [i ];
4541+ all_sigma2 [ib ] = 1.5f * sumx2 /32 ;
4542+ }
4543+ //float sumx2 = 0;
4544+ //for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
4545+ //float sigma2 = 1.5f*sumx2/QK_K;
45374546
45384547 for (int ib = 0 ; ib < QK_K /block_size ; ++ ib ) {
4548+ float sigma2 = all_sigma2 [ib /2 ];
45394549 const float * xb = xbl + block_size * ib ;
45404550 if (quant_weights ) {
45414551 const float * qw = quant_weights + QK_K * ibl + block_size * ib ;
@@ -4544,12 +4554,21 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45444554 for (int i = 0 ; i < block_size ; ++ i ) weight [i ] = xb [i ]* xb [i ];
45454555 }
45464556 float max = fabsf (xb [0 ]);
4547- for (int i = 1 ; i < block_size ; ++ i ) max = MAX (max , fabsf (xb [i ]));
4557+ float sumwx = 0 ;
4558+ for (int i = 1 ; i < block_size ; ++ i ) {
4559+ float ax = fabsf (xb [i ]);
4560+ max = MAX (max , ax );
4561+ sumwx += weight [i ]* ax ;
4562+ }
45484563 if (max < GROUP_MAX_EPS_IQ1_M ) {
45494564 scales [ib ] = 0 ;
45504565 memset (L , 1 , block_size );
45514566 continue ;
45524567 }
4568+ if (sumwx == 0 ) {
4569+ // weight is zero wherever xb is non-zero => ignore the given weights and use xb[i]^2 instead
4570+ for (int i = 0 ; i < block_size ; ++ i ) weight [i ] = xb [i ]* xb [i ];
4571+ }
45534572 // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
45544573 // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
45554574 // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
@@ -4561,82 +4580,45 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45614580 idx [2 * j ] = j ;
45624581 }
45634582 qsort (pairs , block_size , 2 * sizeof (float ), iq1_sort_helper );
4564- float best_score = - FLT_MIN , scale = max ;
4583+ sumw1 [0 ] = sumw2 [0 ] = sumx1 [0 ] = sumx2 [0 ] = 0 ;
4584+ for (int j = 0 ; j < block_size ; ++ j ) {
4585+ int i = idx [2 * j ];
4586+ if (i < block_size /2 ) {
4587+ sumw1 [j + 1 ] = sumw1 [j ] + weight [i ];
4588+ sumx1 [j + 1 ] = sumx1 [j ] + weight [i ]* xb [i ];
4589+ sumw2 [j + 1 ] = sumw2 [j ];
4590+ sumx2 [j + 1 ] = sumx2 [j ];
4591+ } else {
4592+ sumw2 [j + 1 ] = sumw2 [j ] + weight [i ];
4593+ sumx2 [j + 1 ] = sumx2 [j ] + weight [i ]* xb [i ];
4594+ sumw1 [j + 1 ] = sumw1 [j ];
4595+ sumx1 [j + 1 ] = sumx1 [j ];
4596+ }
4597+ }
4598+ float best_score = 0 , scale = 0.f ;
45654599 int besti1 = -1 , besti2 = -1 , best_k = -1 ;
45664600 // 0: +, +
45674601 // 1: +, -
45684602 // 2: -, +
45694603 // 3: -, -
45704604 for (int i1 = 0 ; i1 <= block_size ; ++ i1 ) {
45714605 for (int i2 = i1 ; i2 <= block_size ; ++ i2 ) {
4572- memset (sumqx , 0 , 4 * sizeof (float ));
4573- memset (sumq2 , 0 , 4 * sizeof (float ));
4574- for (int j = 0 ; j < i1 ; ++ j ) {
4575- int i = idx [2 * j ];
4576- if (i < block_size /2 ) {
4577- sumqx [0 ] += weight [i ]* x_p [0 ]* xb [i ];
4578- sumqx [1 ] += weight [i ]* x_p [0 ]* xb [i ];
4579- sumqx [2 ] += weight [i ]* x_m [0 ]* xb [i ];
4580- sumqx [3 ] += weight [i ]* x_m [0 ]* xb [i ];
4581- sumq2 [0 ] += weight [i ]* x_p [0 ]* x_p [0 ];
4582- sumq2 [1 ] += weight [i ]* x_p [0 ]* x_p [0 ];
4583- sumq2 [2 ] += weight [i ]* x_m [0 ]* x_m [0 ];
4584- sumq2 [3 ] += weight [i ]* x_m [0 ]* x_m [0 ];
4585- } else {
4586- sumqx [0 ] += weight [i ]* x_p [0 ]* xb [i ];
4587- sumqx [2 ] += weight [i ]* x_p [0 ]* xb [i ];
4588- sumqx [1 ] += weight [i ]* x_m [0 ]* xb [i ];
4589- sumqx [3 ] += weight [i ]* x_m [0 ]* xb [i ];
4590- sumq2 [0 ] += weight [i ]* x_p [0 ]* x_p [0 ];
4591- sumq2 [2 ] += weight [i ]* x_p [0 ]* x_p [0 ];
4592- sumq2 [1 ] += weight [i ]* x_m [0 ]* x_m [0 ];
4593- sumq2 [3 ] += weight [i ]* x_m [0 ]* x_m [0 ];
4594- }
4595- }
4596- for (int j = i1 ; j < i2 ; ++ j ) {
4597- int i = idx [2 * j ];
4598- if (i < block_size /2 ) {
4599- sumqx [0 ] += weight [i ]* x_p [1 ]* xb [i ];
4600- sumqx [1 ] += weight [i ]* x_p [1 ]* xb [i ];
4601- sumqx [2 ] += weight [i ]* x_m [1 ]* xb [i ];
4602- sumqx [3 ] += weight [i ]* x_m [1 ]* xb [i ];
4603- sumq2 [0 ] += weight [i ]* x_p [1 ]* x_p [1 ];
4604- sumq2 [1 ] += weight [i ]* x_p [1 ]* x_p [1 ];
4605- sumq2 [2 ] += weight [i ]* x_m [1 ]* x_m [1 ];
4606- sumq2 [3 ] += weight [i ]* x_m [1 ]* x_m [1 ];
4607- } else {
4608- sumqx [0 ] += weight [i ]* x_p [1 ]* xb [i ];
4609- sumqx [2 ] += weight [i ]* x_p [1 ]* xb [i ];
4610- sumqx [1 ] += weight [i ]* x_m [1 ]* xb [i ];
4611- sumqx [3 ] += weight [i ]* x_m [1 ]* xb [i ];
4612- sumq2 [0 ] += weight [i ]* x_p [1 ]* x_p [1 ];
4613- sumq2 [2 ] += weight [i ]* x_p [1 ]* x_p [1 ];
4614- sumq2 [1 ] += weight [i ]* x_m [1 ]* x_m [1 ];
4615- sumq2 [3 ] += weight [i ]* x_m [1 ]* x_m [1 ];
4616- }
4617- }
4618- for (int j = i2 ; j < block_size ; ++ j ) {
4619- int i = idx [2 * j ];
4620- if (i < block_size /2 ) {
4621- sumqx [0 ] += weight [i ]* x_p [2 ]* xb [i ];
4622- sumqx [1 ] += weight [i ]* x_p [2 ]* xb [i ];
4623- sumqx [2 ] += weight [i ]* x_m [2 ]* xb [i ];
4624- sumqx [3 ] += weight [i ]* x_m [2 ]* xb [i ];
4625- sumq2 [0 ] += weight [i ]* x_p [2 ]* x_p [2 ];
4626- sumq2 [1 ] += weight [i ]* x_p [2 ]* x_p [2 ];
4627- sumq2 [2 ] += weight [i ]* x_m [2 ]* x_m [2 ];
4628- sumq2 [3 ] += weight [i ]* x_m [2 ]* x_m [2 ];
4629- } else {
4630- sumqx [0 ] += weight [i ]* x_p [2 ]* xb [i ];
4631- sumqx [2 ] += weight [i ]* x_p [2 ]* xb [i ];
4632- sumqx [1 ] += weight [i ]* x_m [2 ]* xb [i ];
4633- sumqx [3 ] += weight [i ]* x_m [2 ]* xb [i ];
4634- sumq2 [0 ] += weight [i ]* x_p [2 ]* x_p [2 ];
4635- sumq2 [2 ] += weight [i ]* x_p [2 ]* x_p [2 ];
4636- sumq2 [1 ] += weight [i ]* x_m [2 ]* x_m [2 ];
4637- sumq2 [3 ] += weight [i ]* x_m [2 ]* x_m [2 ];
4638- }
4639- }
4606+ sumqx [0 ] = (sumx1 [i1 ] - sumx1 [0 ])* x_p [0 ] + (sumx1 [i2 ] - sumx1 [i1 ])* x_p [1 ] + (sumx1 [block_size ]- sumx1 [i2 ])* x_p [2 ] +
4607+ (sumx2 [i1 ] - sumx2 [0 ])* x_p [0 ] + (sumx2 [i2 ] - sumx2 [i1 ])* x_p [1 ] + (sumx2 [block_size ]- sumx2 [i2 ])* x_p [2 ];
4608+ sumqx [1 ] = (sumx1 [i1 ] - sumx1 [0 ])* x_p [0 ] + (sumx1 [i2 ] - sumx1 [i1 ])* x_p [1 ] + (sumx1 [block_size ]- sumx1 [i2 ])* x_p [2 ] +
4609+ (sumx2 [i1 ] - sumx2 [0 ])* x_m [0 ] + (sumx2 [i2 ] - sumx2 [i1 ])* x_m [1 ] + (sumx2 [block_size ]- sumx2 [i2 ])* x_m [2 ];
4610+ sumqx [2 ] = (sumx1 [i1 ] - sumx1 [0 ])* x_m [0 ] + (sumx1 [i2 ] - sumx1 [i1 ])* x_m [1 ] + (sumx1 [block_size ]- sumx1 [i2 ])* x_m [2 ] +
4611+ (sumx2 [i1 ] - sumx2 [0 ])* x_p [0 ] + (sumx2 [i2 ] - sumx2 [i1 ])* x_p [1 ] + (sumx2 [block_size ]- sumx2 [i2 ])* x_p [2 ];
4612+ sumqx [3 ] = (sumx1 [i1 ] - sumx1 [0 ])* x_m [0 ] + (sumx1 [i2 ] - sumx1 [i1 ])* x_m [1 ] + (sumx1 [block_size ]- sumx1 [i2 ])* x_m [2 ] +
4613+ (sumx2 [i1 ] - sumx2 [0 ])* x_m [0 ] + (sumx2 [i2 ] - sumx2 [i1 ])* x_m [1 ] + (sumx2 [block_size ]- sumx2 [i2 ])* x_m [2 ];
4614+ sumq2 [0 ] = (sumw1 [i1 ] - sumw1 [0 ])* x_p [0 ]* x_p [0 ] + (sumw1 [i2 ] - sumw1 [i1 ])* x_p [1 ]* x_p [1 ] + (sumw1 [block_size ]- sumw1 [i2 ])* x_p [2 ]* x_p [2 ] +
4615+ (sumw2 [i1 ] - sumw2 [0 ])* x_p [0 ]* x_p [0 ] + (sumw2 [i2 ] - sumw2 [i1 ])* x_p [1 ]* x_p [1 ] + (sumw2 [block_size ]- sumw2 [i2 ])* x_p [2 ]* x_p [2 ];
4616+ sumq2 [1 ] = (sumw1 [i1 ] - sumw1 [0 ])* x_p [0 ]* x_p [0 ] + (sumw1 [i2 ] - sumw1 [i1 ])* x_p [1 ]* x_p [1 ] + (sumw1 [block_size ]- sumw1 [i2 ])* x_p [2 ]* x_p [2 ] +
4617+ (sumw2 [i1 ] - sumw2 [0 ])* x_m [0 ]* x_m [0 ] + (sumw2 [i2 ] - sumw2 [i1 ])* x_m [1 ]* x_m [1 ] + (sumw2 [block_size ]- sumw2 [i2 ])* x_m [2 ]* x_m [2 ];
4618+ sumq2 [2 ] = (sumw1 [i1 ] - sumw1 [0 ])* x_m [0 ]* x_m [0 ] + (sumw1 [i2 ] - sumw1 [i1 ])* x_m [1 ]* x_m [1 ] + (sumw1 [block_size ]- sumw1 [i2 ])* x_m [2 ]* x_m [2 ] +
4619+ (sumw2 [i1 ] - sumw2 [0 ])* x_p [0 ]* x_p [0 ] + (sumw2 [i2 ] - sumw2 [i1 ])* x_p [1 ]* x_p [1 ] + (sumw2 [block_size ]- sumw2 [i2 ])* x_p [2 ]* x_p [2 ];
4620+ sumq2 [3 ] = (sumw1 [i1 ] - sumw1 [0 ])* x_m [0 ]* x_m [0 ] + (sumw1 [i2 ] - sumw1 [i1 ])* x_m [1 ]* x_m [1 ] + (sumw1 [block_size ]- sumw1 [i2 ])* x_m [2 ]* x_m [2 ] +
4621+ (sumw2 [i1 ] - sumw2 [0 ])* x_m [0 ]* x_m [0 ] + (sumw2 [i2 ] - sumw2 [i1 ])* x_m [1 ]* x_m [1 ] + (sumw2 [block_size ]- sumw2 [i2 ])* x_m [2 ]* x_m [2 ];
46404622 for (int k = 0 ; k < 4 ; ++ k ) {
46414623 if (sumq2 [k ] > 0 && sumqx [k ]* sumqx [k ] > best_score * sumq2 [k ]) {
46424624 scale = sumqx [k ]/sumq2 [k ]; best_score = scale * sumqx [k ];
@@ -4670,19 +4652,34 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
46704652 index [k ] = grid_index ;
46714653 }
46724654 if (!all_on_grid ) {
4673- float sumqx_f = 0 , sumq2_f = 0 ;
4674- for (int k = 0 ; k < block_size /8 ; ++ k ) {
4675- if (k == 0 ) xx = best_k < 2 ? x_p : x_m ;
4676- else xx = best_k %2 == 0 ? x_p : x_m ;
4677- const int8_t * pg = (const int8_t * )(kgrid_q2xs + index [k ]);
4678- for (int j = 0 ; j < 8 ; ++ j ) {
4679- float w = weight [8 * k + j ];
4680- float q = xx [(pg [j ] - 1 )/2 ];
4681- sumqx_f += w * q * xb [8 * k + j ];
4682- sumq2_f += w * q * q ;
4655+ sumqx [0 ] = sumqx [1 ] = sumqx [2 ] = sumqx [3 ] = 0 ;
4656+ sumq2 [0 ] = sumq2 [1 ] = sumq2 [2 ] = sumq2 [3 ] = 0 ;
4657+ for (int j = 0 ; j < block_size ; ++ j ) {
4658+ float w = weight [j ];
4659+ float qp = x_p [L [j ]];
4660+ float qm = x_m [L [j ]];
4661+ sumqx [0 ] += w * xb [j ]* qp ;
4662+ sumq2 [0 ] += w * qp * qp ;
4663+ sumqx [3 ] += w * xb [j ]* qm ;
4664+ sumq2 [3 ] += w * qm * qm ;
4665+ if (j < 8 ) {
4666+ sumqx [1 ] += w * xb [j ]* qp ;
4667+ sumq2 [1 ] += w * qp * qp ;
4668+ sumqx [2 ] += w * xb [j ]* qm ;
4669+ sumq2 [2 ] += w * qm * qm ;
4670+ } else {
4671+ sumqx [2 ] += w * xb [j ]* qp ;
4672+ sumq2 [2 ] += w * qp * qp ;
4673+ sumqx [1 ] += w * xb [j ]* qm ;
4674+ sumq2 [1 ] += w * qm * qm ;
4675+ }
4676+ }
4677+ best_score = 0 ;
4678+ for (int k = 0 ; k < 4 ; ++ k ) {
4679+ if (sumqx [k ] > 0 && sumq2 [k ] > 0 && sumqx [k ]* sumqx [k ] > best_score * sumq2 [k ]) {
4680+ scale = sumqx [k ]/sumq2 [k ]; best_score = scale * sumqx [k ]; best_k = k ;
46834681 }
46844682 }
4685- if (sumqx_f > 0 && sumq2_f > 0 ) scale = sumqx_f /sumq2_f ;
46864683 }
46874684 y [ibl ].qs [2 * ib + 0 ] = index [0 ] & 255 ;
46884685 y [ibl ].qs [2 * ib + 1 ] = index [1 ] & 255 ;
@@ -4702,6 +4699,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
47024699 float id = 1 /d ;
47034700 float sumqx_f = 0 , sumq2_f = 0 ;
47044701 for (int ib = 0 ; ib < QK_K /block_size ; ++ ib ) {
4702+ float sigma2 = all_sigma2 [ib /2 ];
47054703 int l = nearest_int (0.5f * (id * scales [ib + 0 ]- 1 ));
47064704 l = MAX (0 , MIN (7 , l ));
47074705 sc [ib /4 ] |= (l << 3 * (ib %4 ));
@@ -4726,7 +4724,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
47264724 }
47274725 }
47284726 if (sumq2_f > 0 ) d = sumqx_f /sumq2_f ;
4729- s .f16 = GGML_FP32_TO_FP16 (d * 1.1125f ); // 1.1125f is another fudge factor. Don't ask me why it is needed.
4727+ s .f16 = GGML_FP32_TO_FP16 (d * 1.085f ); // 1.085f is another fudge factor. Don't ask me why it is needed.
47304728 sc [0 ] |= ((s .u16 & 0x000f ) << 12 );
47314729 sc [1 ] |= ((s .u16 & 0x00f0 ) << 8 );
47324730 sc [2 ] |= ((s .u16 & 0x0f00 ) << 4 );
0 commit comments