Skip to content

Commit 35b0e88

Browse files
Nexesenex and ikawrakow committed
Improved IQ1_M quantization #327
Co-Authored-By: Kawrakow <iwankawrakow@gmail.com>
1 parent 045ce78 commit 35b0e88

File tree

1 file changed

+83
-85
lines changed

1 file changed

+83
-85
lines changed

ggml/src/ggml-quants.c

Lines changed: 83 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -4515,10 +4515,13 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45154515
const float x_p[3] = {-1 + IQ1M_DELTA, IQ1M_DELTA, 1 + IQ1M_DELTA};
45164516
const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA};
45174517
const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88};
4518+
float all_sigma2[QK_K/32];
45184519

45194520
int * idx = (int *)(pairs + 1);
45204521

45214522
float sumqx[4], sumq2[4];
4523+
float sumw1[IQ1M_BLOCK_SIZE+1], sumw2[IQ1M_BLOCK_SIZE+1];
4524+
float sumx1[IQ1M_BLOCK_SIZE+1], sumx2[IQ1M_BLOCK_SIZE+1];
45224525

45234526
iq1m_scale_t s;
45244527
const float * xx;
@@ -4531,11 +4534,18 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45314534
float max_scale = 0;
45324535

45334536
const float * xbl = x + QK_K*ibl;
4534-
float sumx2 = 0;
4535-
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
4536-
float sigma2 = 2*sumx2/QK_K;
4537+
for (int ib = 0; ib < QK_K/32; ++ib) {
4538+
const float * xb = xbl + 32*ib;
4539+
float sumx2 = 0;
4540+
for (int i = 0; i < 32; ++i) sumx2 += xb[i]*xb[i];
4541+
all_sigma2[ib] = 1.5f*sumx2/32;
4542+
}
4543+
//float sumx2 = 0;
4544+
//for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
4545+
//float sigma2 = 1.5f*sumx2/QK_K;
45374546

45384547
for (int ib = 0; ib < QK_K/block_size; ++ib) {
4548+
float sigma2 = all_sigma2[ib/2];
45394549
const float * xb = xbl + block_size*ib;
45404550
if (quant_weights) {
45414551
const float * qw = quant_weights + QK_K*ibl + block_size*ib;
@@ -4544,12 +4554,21 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45444554
for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
45454555
}
45464556
float max = fabsf(xb[0]);
4547-
for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
4557+
float sumwx = 0;
4558+
for (int i = 1; i < block_size; ++i) {
4559+
float ax = fabsf(xb[i]);
4560+
max = MAX(max, ax);
4561+
sumwx += weight[i]*ax;
4562+
}
45484563
if (max < GROUP_MAX_EPS_IQ1_M) {
45494564
scales[ib] = 0;
45504565
memset(L, 1, block_size);
45514566
continue;
45524567
}
4568+
if (sumwx == 0) {
4569+
// weight is zero everywhere where xb is not zero => ignore
4570+
for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
4571+
}
45534572
// Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
45544573
// With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
45554574
// boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
@@ -4561,82 +4580,45 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45614580
idx[2*j] = j;
45624581
}
45634582
qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
4564-
float best_score = -FLT_MIN, scale = max;
4583+
sumw1[0] = sumw2[0] = sumx1[0] = sumx2[0] = 0;
4584+
for (int j = 0; j < block_size; ++j) {
4585+
int i = idx[2*j];
4586+
if (i < block_size/2) {
4587+
sumw1[j+1] = sumw1[j] + weight[i];
4588+
sumx1[j+1] = sumx1[j] + weight[i]*xb[i];
4589+
sumw2[j+1] = sumw2[j];
4590+
sumx2[j+1] = sumx2[j];
4591+
} else {
4592+
sumw2[j+1] = sumw2[j] + weight[i];
4593+
sumx2[j+1] = sumx2[j] + weight[i]*xb[i];
4594+
sumw1[j+1] = sumw1[j];
4595+
sumx1[j+1] = sumx1[j];
4596+
}
4597+
}
4598+
float best_score = 0, scale = 0.f;
45654599
int besti1 = -1, besti2 = -1, best_k = -1;
45664600
// 0: +, +
45674601
// 1: +, -
45684602
// 2: -, +
45694603
// 3: -, -
45704604
for (int i1 = 0; i1 <= block_size; ++i1) {
45714605
for (int i2 = i1; i2 <= block_size; ++i2) {
4572-
memset(sumqx, 0, 4*sizeof(float));
4573-
memset(sumq2, 0, 4*sizeof(float));
4574-
for (int j = 0; j < i1; ++j) {
4575-
int i = idx[2*j];
4576-
if (i < block_size/2) {
4577-
sumqx[0] += weight[i]*x_p[0]*xb[i];
4578-
sumqx[1] += weight[i]*x_p[0]*xb[i];
4579-
sumqx[2] += weight[i]*x_m[0]*xb[i];
4580-
sumqx[3] += weight[i]*x_m[0]*xb[i];
4581-
sumq2[0] += weight[i]*x_p[0]*x_p[0];
4582-
sumq2[1] += weight[i]*x_p[0]*x_p[0];
4583-
sumq2[2] += weight[i]*x_m[0]*x_m[0];
4584-
sumq2[3] += weight[i]*x_m[0]*x_m[0];
4585-
} else {
4586-
sumqx[0] += weight[i]*x_p[0]*xb[i];
4587-
sumqx[2] += weight[i]*x_p[0]*xb[i];
4588-
sumqx[1] += weight[i]*x_m[0]*xb[i];
4589-
sumqx[3] += weight[i]*x_m[0]*xb[i];
4590-
sumq2[0] += weight[i]*x_p[0]*x_p[0];
4591-
sumq2[2] += weight[i]*x_p[0]*x_p[0];
4592-
sumq2[1] += weight[i]*x_m[0]*x_m[0];
4593-
sumq2[3] += weight[i]*x_m[0]*x_m[0];
4594-
}
4595-
}
4596-
for (int j = i1; j < i2; ++j) {
4597-
int i = idx[2*j];
4598-
if (i < block_size/2) {
4599-
sumqx[0] += weight[i]*x_p[1]*xb[i];
4600-
sumqx[1] += weight[i]*x_p[1]*xb[i];
4601-
sumqx[2] += weight[i]*x_m[1]*xb[i];
4602-
sumqx[3] += weight[i]*x_m[1]*xb[i];
4603-
sumq2[0] += weight[i]*x_p[1]*x_p[1];
4604-
sumq2[1] += weight[i]*x_p[1]*x_p[1];
4605-
sumq2[2] += weight[i]*x_m[1]*x_m[1];
4606-
sumq2[3] += weight[i]*x_m[1]*x_m[1];
4607-
} else {
4608-
sumqx[0] += weight[i]*x_p[1]*xb[i];
4609-
sumqx[2] += weight[i]*x_p[1]*xb[i];
4610-
sumqx[1] += weight[i]*x_m[1]*xb[i];
4611-
sumqx[3] += weight[i]*x_m[1]*xb[i];
4612-
sumq2[0] += weight[i]*x_p[1]*x_p[1];
4613-
sumq2[2] += weight[i]*x_p[1]*x_p[1];
4614-
sumq2[1] += weight[i]*x_m[1]*x_m[1];
4615-
sumq2[3] += weight[i]*x_m[1]*x_m[1];
4616-
}
4617-
}
4618-
for (int j = i2; j < block_size; ++j) {
4619-
int i = idx[2*j];
4620-
if (i < block_size/2) {
4621-
sumqx[0] += weight[i]*x_p[2]*xb[i];
4622-
sumqx[1] += weight[i]*x_p[2]*xb[i];
4623-
sumqx[2] += weight[i]*x_m[2]*xb[i];
4624-
sumqx[3] += weight[i]*x_m[2]*xb[i];
4625-
sumq2[0] += weight[i]*x_p[2]*x_p[2];
4626-
sumq2[1] += weight[i]*x_p[2]*x_p[2];
4627-
sumq2[2] += weight[i]*x_m[2]*x_m[2];
4628-
sumq2[3] += weight[i]*x_m[2]*x_m[2];
4629-
} else {
4630-
sumqx[0] += weight[i]*x_p[2]*xb[i];
4631-
sumqx[2] += weight[i]*x_p[2]*xb[i];
4632-
sumqx[1] += weight[i]*x_m[2]*xb[i];
4633-
sumqx[3] += weight[i]*x_m[2]*xb[i];
4634-
sumq2[0] += weight[i]*x_p[2]*x_p[2];
4635-
sumq2[2] += weight[i]*x_p[2]*x_p[2];
4636-
sumq2[1] += weight[i]*x_m[2]*x_m[2];
4637-
sumq2[3] += weight[i]*x_m[2]*x_m[2];
4638-
}
4639-
}
4606+
sumqx[0] = (sumx1[i1] - sumx1[0])*x_p[0] + (sumx1[i2] - sumx1[i1])*x_p[1] + (sumx1[block_size]-sumx1[i2])*x_p[2] +
4607+
(sumx2[i1] - sumx2[0])*x_p[0] + (sumx2[i2] - sumx2[i1])*x_p[1] + (sumx2[block_size]-sumx2[i2])*x_p[2];
4608+
sumqx[1] = (sumx1[i1] - sumx1[0])*x_p[0] + (sumx1[i2] - sumx1[i1])*x_p[1] + (sumx1[block_size]-sumx1[i2])*x_p[2] +
4609+
(sumx2[i1] - sumx2[0])*x_m[0] + (sumx2[i2] - sumx2[i1])*x_m[1] + (sumx2[block_size]-sumx2[i2])*x_m[2];
4610+
sumqx[2] = (sumx1[i1] - sumx1[0])*x_m[0] + (sumx1[i2] - sumx1[i1])*x_m[1] + (sumx1[block_size]-sumx1[i2])*x_m[2] +
4611+
(sumx2[i1] - sumx2[0])*x_p[0] + (sumx2[i2] - sumx2[i1])*x_p[1] + (sumx2[block_size]-sumx2[i2])*x_p[2];
4612+
sumqx[3] = (sumx1[i1] - sumx1[0])*x_m[0] + (sumx1[i2] - sumx1[i1])*x_m[1] + (sumx1[block_size]-sumx1[i2])*x_m[2] +
4613+
(sumx2[i1] - sumx2[0])*x_m[0] + (sumx2[i2] - sumx2[i1])*x_m[1] + (sumx2[block_size]-sumx2[i2])*x_m[2];
4614+
sumq2[0] = (sumw1[i1] - sumw1[0])*x_p[0]*x_p[0] + (sumw1[i2] - sumw1[i1])*x_p[1]*x_p[1] + (sumw1[block_size]-sumw1[i2])*x_p[2]*x_p[2] +
4615+
(sumw2[i1] - sumw2[0])*x_p[0]*x_p[0] + (sumw2[i2] - sumw2[i1])*x_p[1]*x_p[1] + (sumw2[block_size]-sumw2[i2])*x_p[2]*x_p[2];
4616+
sumq2[1] = (sumw1[i1] - sumw1[0])*x_p[0]*x_p[0] + (sumw1[i2] - sumw1[i1])*x_p[1]*x_p[1] + (sumw1[block_size]-sumw1[i2])*x_p[2]*x_p[2] +
4617+
(sumw2[i1] - sumw2[0])*x_m[0]*x_m[0] + (sumw2[i2] - sumw2[i1])*x_m[1]*x_m[1] + (sumw2[block_size]-sumw2[i2])*x_m[2]*x_m[2];
4618+
sumq2[2] = (sumw1[i1] - sumw1[0])*x_m[0]*x_m[0] + (sumw1[i2] - sumw1[i1])*x_m[1]*x_m[1] + (sumw1[block_size]-sumw1[i2])*x_m[2]*x_m[2] +
4619+
(sumw2[i1] - sumw2[0])*x_p[0]*x_p[0] + (sumw2[i2] - sumw2[i1])*x_p[1]*x_p[1] + (sumw2[block_size]-sumw2[i2])*x_p[2]*x_p[2];
4620+
sumq2[3] = (sumw1[i1] - sumw1[0])*x_m[0]*x_m[0] + (sumw1[i2] - sumw1[i1])*x_m[1]*x_m[1] + (sumw1[block_size]-sumw1[i2])*x_m[2]*x_m[2] +
4621+
(sumw2[i1] - sumw2[0])*x_m[0]*x_m[0] + (sumw2[i2] - sumw2[i1])*x_m[1]*x_m[1] + (sumw2[block_size]-sumw2[i2])*x_m[2]*x_m[2];
46404622
for (int k = 0; k < 4; ++k) {
46414623
if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
46424624
scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k];
@@ -4670,19 +4652,34 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
46704652
index[k] = grid_index;
46714653
}
46724654
if (!all_on_grid) {
4673-
float sumqx_f = 0, sumq2_f = 0;
4674-
for (int k = 0; k < block_size/8; ++k) {
4675-
if (k == 0) xx = best_k < 2 ? x_p : x_m;
4676-
else xx = best_k%2 == 0 ? x_p : x_m;
4677-
const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
4678-
for (int j = 0; j < 8; ++j) {
4679-
float w = weight[8*k + j];
4680-
float q = xx[(pg[j] - 1)/2];
4681-
sumqx_f += w*q*xb[8*k+j];
4682-
sumq2_f += w*q*q;
4655+
sumqx[0] = sumqx[1] = sumqx[2] = sumqx[3] = 0;
4656+
sumq2[0] = sumq2[1] = sumq2[2] = sumq2[3] = 0;
4657+
for (int j = 0; j < block_size; ++j) {
4658+
float w = weight[j];
4659+
float qp = x_p[L[j]];
4660+
float qm = x_m[L[j]];
4661+
sumqx[0] += w*xb[j]*qp;
4662+
sumq2[0] += w*qp*qp;
4663+
sumqx[3] += w*xb[j]*qm;
4664+
sumq2[3] += w*qm*qm;
4665+
if (j < 8) {
4666+
sumqx[1] += w*xb[j]*qp;
4667+
sumq2[1] += w*qp*qp;
4668+
sumqx[2] += w*xb[j]*qm;
4669+
sumq2[2] += w*qm*qm;
4670+
} else {
4671+
sumqx[2] += w*xb[j]*qp;
4672+
sumq2[2] += w*qp*qp;
4673+
sumqx[1] += w*xb[j]*qm;
4674+
sumq2[1] += w*qm*qm;
4675+
}
4676+
}
4677+
best_score = 0;
4678+
for (int k = 0; k < 4; ++k) {
4679+
if (sumqx[k] > 0 && sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
4680+
scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; best_k = k;
46834681
}
46844682
}
4685-
if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f;
46864683
}
46874684
y[ibl].qs[2*ib + 0] = index[0] & 255;
46884685
y[ibl].qs[2*ib + 1] = index[1] & 255;
@@ -4702,6 +4699,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
47024699
float id = 1/d;
47034700
float sumqx_f = 0, sumq2_f = 0;
47044701
for (int ib = 0; ib < QK_K/block_size; ++ib) {
4702+
float sigma2 = all_sigma2[ib/2];
47054703
int l = nearest_int(0.5f*(id*scales[ib+0]-1));
47064704
l = MAX(0, MIN(7, l));
47074705
sc[ib/4] |= (l << 3*(ib%4));
@@ -4726,7 +4724,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
47264724
}
47274725
}
47284726
if (sumq2_f > 0) d = sumqx_f/sumq2_f;
4729-
s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
4727+
s.f16 = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed.
47304728
sc[0] |= ((s.u16 & 0x000f) << 12);
47314729
sc[1] |= ((s.u16 & 0x00f0) << 8);
47324730
sc[2] |= ((s.u16 & 0x0f00) << 4);

0 commit comments

Comments
 (0)