@@ -4516,10 +4516,13 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45164516 const float x_p [3 ] = {-1 + IQ1M_DELTA , IQ1M_DELTA , 1 + IQ1M_DELTA };
45174517 const float x_m [3 ] = {-1 - IQ1M_DELTA , - IQ1M_DELTA , 1 - IQ1M_DELTA };
45184518 const uint8_t masks [4 ] = {0x00 , 0x80 , 0x08 , 0x88 };
4519+ float all_sigma2 [QK_K /32 ];
45194520
45204521 int * idx = (int * )(pairs + 1 );
45214522
45224523 float sumqx [4 ], sumq2 [4 ];
4524+ float sumw1 [IQ1M_BLOCK_SIZE + 1 ], sumw2 [IQ1M_BLOCK_SIZE + 1 ];
4525+ float sumx1 [IQ1M_BLOCK_SIZE + 1 ], sumx2 [IQ1M_BLOCK_SIZE + 1 ];
45234526
45244527 iq1m_scale_t s ;
45254528 const float * xx ;
@@ -4532,11 +4535,18 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45324535 float max_scale = 0 ;
45334536
45344537 const float * xbl = x + QK_K * ibl ;
4535- float sumx2 = 0 ;
4536- for (int i = 0 ; i < QK_K ; ++ i ) sumx2 += xbl [i ]* xbl [i ];
4537- float sigma2 = 2 * sumx2 /QK_K ;
4538+ for (int ib = 0 ; ib < QK_K /32 ; ++ ib ) {
4539+ const float * xb = xbl + 32 * ib ;
4540+ float sumx2 = 0 ;
4541+ for (int i = 0 ; i < 32 ; ++ i ) sumx2 += xb [i ]* xb [i ];
4542+ all_sigma2 [ib ] = 1.5f * sumx2 /32 ;
4543+ }
4544+ //float sumx2 = 0;
4545+ //for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
4546+ //float sigma2 = 1.5f*sumx2/QK_K;
45384547
45394548 for (int ib = 0 ; ib < QK_K /block_size ; ++ ib ) {
4549+ float sigma2 = all_sigma2 [ib /2 ];
45404550 const float * xb = xbl + block_size * ib ;
45414551 if (quant_weights ) {
45424552 const float * qw = quant_weights + QK_K * ibl + block_size * ib ;
@@ -4545,12 +4555,21 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45454555 for (int i = 0 ; i < block_size ; ++ i ) weight [i ] = xb [i ]* xb [i ];
45464556 }
45474557 float max = fabsf (xb [0 ]);
4548- for (int i = 1 ; i < block_size ; ++ i ) max = MAX (max , fabsf (xb [i ]));
4558+ float sumwx = 0 ;
4559+ for (int i = 1 ; i < block_size ; ++ i ) {
4560+ float ax = fabsf (xb [i ]);
4561+ max = MAX (max , ax );
4562+ sumwx += weight [i ]* ax ;
4563+ }
45494564 if (max < GROUP_MAX_EPS_IQ1_M ) {
45504565 scales [ib ] = 0 ;
45514566 memset (L , 1 , block_size );
45524567 continue ;
45534568 }
4569+ if (sumwx == 0 ) {
4570+ // weight is zero everywhere where xb is not zero => ignore
4571+ for (int i = 0 ; i < block_size ; ++ i ) weight [i ] = xb [i ]* xb [i ];
4572+ }
45544573 // Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
45554574 // With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
45564575 // boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
@@ -4562,82 +4581,45 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45624581 idx [2 * j ] = j ;
45634582 }
45644583 qsort (pairs , block_size , 2 * sizeof (float ), iq1_sort_helper );
4565- float best_score = - FLT_MAX , scale = max ;
4584+ sumw1 [0 ] = sumw2 [0 ] = sumx1 [0 ] = sumx2 [0 ] = 0 ;
4585+ for (int j = 0 ; j < block_size ; ++ j ) {
4586+ int i = idx [2 * j ];
4587+ if (i < block_size /2 ) {
4588+ sumw1 [j + 1 ] = sumw1 [j ] + weight [i ];
4589+ sumx1 [j + 1 ] = sumx1 [j ] + weight [i ]* xb [i ];
4590+ sumw2 [j + 1 ] = sumw2 [j ];
4591+ sumx2 [j + 1 ] = sumx2 [j ];
4592+ } else {
4593+ sumw2 [j + 1 ] = sumw2 [j ] + weight [i ];
4594+ sumx2 [j + 1 ] = sumx2 [j ] + weight [i ]* xb [i ];
4595+ sumw1 [j + 1 ] = sumw1 [j ];
4596+ sumx1 [j + 1 ] = sumx1 [j ];
4597+ }
4598+ }
4599+ float best_score = 0 , scale = 0.f ;
45664600 int besti1 = -1 , besti2 = -1 , best_k = -1 ;
45674601 // 0: +, +
45684602 // 1: +, -
45694603 // 2: -, +
45704604 // 3: -, -
45714605 for (int i1 = 0 ; i1 <= block_size ; ++ i1 ) {
45724606 for (int i2 = i1 ; i2 <= block_size ; ++ i2 ) {
4573- memset (sumqx , 0 , 4 * sizeof (float ));
4574- memset (sumq2 , 0 , 4 * sizeof (float ));
4575- for (int j = 0 ; j < i1 ; ++ j ) {
4576- int i = idx [2 * j ];
4577- if (i < block_size /2 ) {
4578- sumqx [0 ] += weight [i ]* x_p [0 ]* xb [i ];
4579- sumqx [1 ] += weight [i ]* x_p [0 ]* xb [i ];
4580- sumqx [2 ] += weight [i ]* x_m [0 ]* xb [i ];
4581- sumqx [3 ] += weight [i ]* x_m [0 ]* xb [i ];
4582- sumq2 [0 ] += weight [i ]* x_p [0 ]* x_p [0 ];
4583- sumq2 [1 ] += weight [i ]* x_p [0 ]* x_p [0 ];
4584- sumq2 [2 ] += weight [i ]* x_m [0 ]* x_m [0 ];
4585- sumq2 [3 ] += weight [i ]* x_m [0 ]* x_m [0 ];
4586- } else {
4587- sumqx [0 ] += weight [i ]* x_p [0 ]* xb [i ];
4588- sumqx [2 ] += weight [i ]* x_p [0 ]* xb [i ];
4589- sumqx [1 ] += weight [i ]* x_m [0 ]* xb [i ];
4590- sumqx [3 ] += weight [i ]* x_m [0 ]* xb [i ];
4591- sumq2 [0 ] += weight [i ]* x_p [0 ]* x_p [0 ];
4592- sumq2 [2 ] += weight [i ]* x_p [0 ]* x_p [0 ];
4593- sumq2 [1 ] += weight [i ]* x_m [0 ]* x_m [0 ];
4594- sumq2 [3 ] += weight [i ]* x_m [0 ]* x_m [0 ];
4595- }
4596- }
4597- for (int j = i1 ; j < i2 ; ++ j ) {
4598- int i = idx [2 * j ];
4599- if (i < block_size /2 ) {
4600- sumqx [0 ] += weight [i ]* x_p [1 ]* xb [i ];
4601- sumqx [1 ] += weight [i ]* x_p [1 ]* xb [i ];
4602- sumqx [2 ] += weight [i ]* x_m [1 ]* xb [i ];
4603- sumqx [3 ] += weight [i ]* x_m [1 ]* xb [i ];
4604- sumq2 [0 ] += weight [i ]* x_p [1 ]* x_p [1 ];
4605- sumq2 [1 ] += weight [i ]* x_p [1 ]* x_p [1 ];
4606- sumq2 [2 ] += weight [i ]* x_m [1 ]* x_m [1 ];
4607- sumq2 [3 ] += weight [i ]* x_m [1 ]* x_m [1 ];
4608- } else {
4609- sumqx [0 ] += weight [i ]* x_p [1 ]* xb [i ];
4610- sumqx [2 ] += weight [i ]* x_p [1 ]* xb [i ];
4611- sumqx [1 ] += weight [i ]* x_m [1 ]* xb [i ];
4612- sumqx [3 ] += weight [i ]* x_m [1 ]* xb [i ];
4613- sumq2 [0 ] += weight [i ]* x_p [1 ]* x_p [1 ];
4614- sumq2 [2 ] += weight [i ]* x_p [1 ]* x_p [1 ];
4615- sumq2 [1 ] += weight [i ]* x_m [1 ]* x_m [1 ];
4616- sumq2 [3 ] += weight [i ]* x_m [1 ]* x_m [1 ];
4617- }
4618- }
4619- for (int j = i2 ; j < block_size ; ++ j ) {
4620- int i = idx [2 * j ];
4621- if (i < block_size /2 ) {
4622- sumqx [0 ] += weight [i ]* x_p [2 ]* xb [i ];
4623- sumqx [1 ] += weight [i ]* x_p [2 ]* xb [i ];
4624- sumqx [2 ] += weight [i ]* x_m [2 ]* xb [i ];
4625- sumqx [3 ] += weight [i ]* x_m [2 ]* xb [i ];
4626- sumq2 [0 ] += weight [i ]* x_p [2 ]* x_p [2 ];
4627- sumq2 [1 ] += weight [i ]* x_p [2 ]* x_p [2 ];
4628- sumq2 [2 ] += weight [i ]* x_m [2 ]* x_m [2 ];
4629- sumq2 [3 ] += weight [i ]* x_m [2 ]* x_m [2 ];
4630- } else {
4631- sumqx [0 ] += weight [i ]* x_p [2 ]* xb [i ];
4632- sumqx [2 ] += weight [i ]* x_p [2 ]* xb [i ];
4633- sumqx [1 ] += weight [i ]* x_m [2 ]* xb [i ];
4634- sumqx [3 ] += weight [i ]* x_m [2 ]* xb [i ];
4635- sumq2 [0 ] += weight [i ]* x_p [2 ]* x_p [2 ];
4636- sumq2 [2 ] += weight [i ]* x_p [2 ]* x_p [2 ];
4637- sumq2 [1 ] += weight [i ]* x_m [2 ]* x_m [2 ];
4638- sumq2 [3 ] += weight [i ]* x_m [2 ]* x_m [2 ];
4639- }
4640- }
4607+ sumqx [0 ] = (sumx1 [i1 ] - sumx1 [0 ])* x_p [0 ] + (sumx1 [i2 ] - sumx1 [i1 ])* x_p [1 ] + (sumx1 [block_size ]- sumx1 [i2 ])* x_p [2 ] +
4608+ (sumx2 [i1 ] - sumx2 [0 ])* x_p [0 ] + (sumx2 [i2 ] - sumx2 [i1 ])* x_p [1 ] + (sumx2 [block_size ]- sumx2 [i2 ])* x_p [2 ];
4609+ sumqx [1 ] = (sumx1 [i1 ] - sumx1 [0 ])* x_p [0 ] + (sumx1 [i2 ] - sumx1 [i1 ])* x_p [1 ] + (sumx1 [block_size ]- sumx1 [i2 ])* x_p [2 ] +
4610+ (sumx2 [i1 ] - sumx2 [0 ])* x_m [0 ] + (sumx2 [i2 ] - sumx2 [i1 ])* x_m [1 ] + (sumx2 [block_size ]- sumx2 [i2 ])* x_m [2 ];
4611+ sumqx [2 ] = (sumx1 [i1 ] - sumx1 [0 ])* x_m [0 ] + (sumx1 [i2 ] - sumx1 [i1 ])* x_m [1 ] + (sumx1 [block_size ]- sumx1 [i2 ])* x_m [2 ] +
4612+ (sumx2 [i1 ] - sumx2 [0 ])* x_p [0 ] + (sumx2 [i2 ] - sumx2 [i1 ])* x_p [1 ] + (sumx2 [block_size ]- sumx2 [i2 ])* x_p [2 ];
4613+ sumqx [3 ] = (sumx1 [i1 ] - sumx1 [0 ])* x_m [0 ] + (sumx1 [i2 ] - sumx1 [i1 ])* x_m [1 ] + (sumx1 [block_size ]- sumx1 [i2 ])* x_m [2 ] +
4614+ (sumx2 [i1 ] - sumx2 [0 ])* x_m [0 ] + (sumx2 [i2 ] - sumx2 [i1 ])* x_m [1 ] + (sumx2 [block_size ]- sumx2 [i2 ])* x_m [2 ];
4615+ sumq2 [0 ] = (sumw1 [i1 ] - sumw1 [0 ])* x_p [0 ]* x_p [0 ] + (sumw1 [i2 ] - sumw1 [i1 ])* x_p [1 ]* x_p [1 ] + (sumw1 [block_size ]- sumw1 [i2 ])* x_p [2 ]* x_p [2 ] +
4616+ (sumw2 [i1 ] - sumw2 [0 ])* x_p [0 ]* x_p [0 ] + (sumw2 [i2 ] - sumw2 [i1 ])* x_p [1 ]* x_p [1 ] + (sumw2 [block_size ]- sumw2 [i2 ])* x_p [2 ]* x_p [2 ];
4617+ sumq2 [1 ] = (sumw1 [i1 ] - sumw1 [0 ])* x_p [0 ]* x_p [0 ] + (sumw1 [i2 ] - sumw1 [i1 ])* x_p [1 ]* x_p [1 ] + (sumw1 [block_size ]- sumw1 [i2 ])* x_p [2 ]* x_p [2 ] +
4618+ (sumw2 [i1 ] - sumw2 [0 ])* x_m [0 ]* x_m [0 ] + (sumw2 [i2 ] - sumw2 [i1 ])* x_m [1 ]* x_m [1 ] + (sumw2 [block_size ]- sumw2 [i2 ])* x_m [2 ]* x_m [2 ];
4619+ sumq2 [2 ] = (sumw1 [i1 ] - sumw1 [0 ])* x_m [0 ]* x_m [0 ] + (sumw1 [i2 ] - sumw1 [i1 ])* x_m [1 ]* x_m [1 ] + (sumw1 [block_size ]- sumw1 [i2 ])* x_m [2 ]* x_m [2 ] +
4620+ (sumw2 [i1 ] - sumw2 [0 ])* x_p [0 ]* x_p [0 ] + (sumw2 [i2 ] - sumw2 [i1 ])* x_p [1 ]* x_p [1 ] + (sumw2 [block_size ]- sumw2 [i2 ])* x_p [2 ]* x_p [2 ];
4621+ sumq2 [3 ] = (sumw1 [i1 ] - sumw1 [0 ])* x_m [0 ]* x_m [0 ] + (sumw1 [i2 ] - sumw1 [i1 ])* x_m [1 ]* x_m [1 ] + (sumw1 [block_size ]- sumw1 [i2 ])* x_m [2 ]* x_m [2 ] +
4622+ (sumw2 [i1 ] - sumw2 [0 ])* x_m [0 ]* x_m [0 ] + (sumw2 [i2 ] - sumw2 [i1 ])* x_m [1 ]* x_m [1 ] + (sumw2 [block_size ]- sumw2 [i2 ])* x_m [2 ]* x_m [2 ];
46414623 for (int k = 0 ; k < 4 ; ++ k ) {
46424624 if (sumq2 [k ] > 0 && sumqx [k ]* sumqx [k ] > best_score * sumq2 [k ]) {
46434625 scale = sumqx [k ]/sumq2 [k ]; best_score = scale * sumqx [k ];
@@ -4671,19 +4653,34 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
46714653 index [k ] = grid_index ;
46724654 }
46734655 if (!all_on_grid ) {
4674- float sumqx_f = 0 , sumq2_f = 0 ;
4675- for (int k = 0 ; k < block_size /8 ; ++ k ) {
4676- if (k == 0 ) xx = best_k < 2 ? x_p : x_m ;
4677- else xx = best_k %2 == 0 ? x_p : x_m ;
4678- const int8_t * pg = (const int8_t * )(kgrid_q2xs + index [k ]);
4679- for (int j = 0 ; j < 8 ; ++ j ) {
4680- float w = weight [8 * k + j ];
4681- float q = xx [(pg [j ] - 1 )/2 ];
4682- sumqx_f += w * q * xb [8 * k + j ];
4683- sumq2_f += w * q * q ;
4656+ sumqx [0 ] = sumqx [1 ] = sumqx [2 ] = sumqx [3 ] = 0 ;
4657+ sumq2 [0 ] = sumq2 [1 ] = sumq2 [2 ] = sumq2 [3 ] = 0 ;
4658+ for (int j = 0 ; j < block_size ; ++ j ) {
4659+ float w = weight [j ];
4660+ float qp = x_p [L [j ]];
4661+ float qm = x_m [L [j ]];
4662+ sumqx [0 ] += w * xb [j ]* qp ;
4663+ sumq2 [0 ] += w * qp * qp ;
4664+ sumqx [3 ] += w * xb [j ]* qm ;
4665+ sumq2 [3 ] += w * qm * qm ;
4666+ if (j < 8 ) {
4667+ sumqx [1 ] += w * xb [j ]* qp ;
4668+ sumq2 [1 ] += w * qp * qp ;
4669+ sumqx [2 ] += w * xb [j ]* qm ;
4670+ sumq2 [2 ] += w * qm * qm ;
4671+ } else {
4672+ sumqx [2 ] += w * xb [j ]* qp ;
4673+ sumq2 [2 ] += w * qp * qp ;
4674+ sumqx [1 ] += w * xb [j ]* qm ;
4675+ sumq2 [1 ] += w * qm * qm ;
4676+ }
4677+ }
4678+ best_score = 0 ;
4679+ for (int k = 0 ; k < 4 ; ++ k ) {
4680+ if (sumqx [k ] > 0 && sumq2 [k ] > 0 && sumqx [k ]* sumqx [k ] > best_score * sumq2 [k ]) {
4681+ scale = sumqx [k ]/sumq2 [k ]; best_score = scale * sumqx [k ]; best_k = k ;
46844682 }
46854683 }
4686- if (sumqx_f > 0 && sumq2_f > 0 ) scale = sumqx_f /sumq2_f ;
46874684 }
46884685 y [ibl ].qs [2 * ib + 0 ] = index [0 ] & 255 ;
46894686 y [ibl ].qs [2 * ib + 1 ] = index [1 ] & 255 ;
@@ -4703,6 +4700,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
47034700 float id = 1 /d ;
47044701 float sumqx_f = 0 , sumq2_f = 0 ;
47054702 for (int ib = 0 ; ib < QK_K /block_size ; ++ ib ) {
4703+ float sigma2 = all_sigma2 [ib /2 ];
47064704 int l = nearest_int (0.5f * (id * scales [ib + 0 ]- 1 ));
47074705 l = MAX (0 , MIN (7 , l ));
47084706 sc [ib /4 ] |= (l << 3 * (ib %4 ));
@@ -4727,7 +4725,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
47274725 }
47284726 }
47294727 if (sumq2_f > 0 ) d = sumqx_f /sumq2_f ;
4730- s .f16 = GGML_FP32_TO_FP16 (d * 1.1125f ); // 1.1125f is another fudge factor. Don't ask me why it is needed.
4728+ s .f16 = GGML_FP32_TO_FP16 (d * 1.085f ); // 1.085f is another fudge factor. Don't ask me why it is needed.
47314729 sc [0 ] |= ((s .u16 & 0x000f ) << 12 );
47324730 sc [1 ] |= ((s .u16 & 0x00f0 ) << 8 );
47334731 sc [2 ] |= ((s .u16 & 0x0f00 ) << 4 );
0 commit comments