Skip to content

Commit cb7a76e

Browse files
Nexesenexikawrakow
andcommitted
Improved IQ1_M quantization #327
Co-Authored-By: Kawrakow <iwankawrakow@gmail.com>
1 parent b52558e commit cb7a76e

File tree

1 file changed

+83
-85
lines changed

1 file changed

+83
-85
lines changed

ggml/src/ggml-quants.c

Lines changed: 83 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -4516,10 +4516,13 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45164516
const float x_p[3] = {-1 + IQ1M_DELTA, IQ1M_DELTA, 1 + IQ1M_DELTA};
45174517
const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA};
45184518
const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88};
4519+
float all_sigma2[QK_K/32];
45194520

45204521
int * idx = (int *)(pairs + 1);
45214522

45224523
float sumqx[4], sumq2[4];
4524+
float sumw1[IQ1M_BLOCK_SIZE+1], sumw2[IQ1M_BLOCK_SIZE+1];
4525+
float sumx1[IQ1M_BLOCK_SIZE+1], sumx2[IQ1M_BLOCK_SIZE+1];
45234526

45244527
iq1m_scale_t s;
45254528
const float * xx;
@@ -4532,11 +4535,18 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45324535
float max_scale = 0;
45334536

45344537
const float * xbl = x + QK_K*ibl;
4535-
float sumx2 = 0;
4536-
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
4537-
float sigma2 = 2*sumx2/QK_K;
4538+
for (int ib = 0; ib < QK_K/32; ++ib) {
4539+
const float * xb = xbl + 32*ib;
4540+
float sumx2 = 0;
4541+
for (int i = 0; i < 32; ++i) sumx2 += xb[i]*xb[i];
4542+
all_sigma2[ib] = 1.5f*sumx2/32;
4543+
}
4544+
//float sumx2 = 0;
4545+
//for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
4546+
//float sigma2 = 1.5f*sumx2/QK_K;
45384547

45394548
for (int ib = 0; ib < QK_K/block_size; ++ib) {
4549+
float sigma2 = all_sigma2[ib/2];
45404550
const float * xb = xbl + block_size*ib;
45414551
if (quant_weights) {
45424552
const float * qw = quant_weights + QK_K*ibl + block_size*ib;
@@ -4545,12 +4555,21 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45454555
for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
45464556
}
45474557
float max = fabsf(xb[0]);
4548-
for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
4558+
float sumwx = 0;
4559+
for (int i = 1; i < block_size; ++i) {
4560+
float ax = fabsf(xb[i]);
4561+
max = MAX(max, ax);
4562+
sumwx += weight[i]*ax;
4563+
}
45494564
if (max < GROUP_MAX_EPS_IQ1_M) {
45504565
scales[ib] = 0;
45514566
memset(L, 1, block_size);
45524567
continue;
45534568
}
4569+
if (sumwx == 0) {
4570+
// weight is zero everywhere where xb is not zero => ignore
4571+
for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
4572+
}
45544573
// Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
45554574
// With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
45564575
// boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
@@ -4562,82 +4581,45 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45624581
idx[2*j] = j;
45634582
}
45644583
qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
4565-
float best_score = -FLT_MAX, scale = max;
4584+
sumw1[0] = sumw2[0] = sumx1[0] = sumx2[0] = 0;
4585+
for (int j = 0; j < block_size; ++j) {
4586+
int i = idx[2*j];
4587+
if (i < block_size/2) {
4588+
sumw1[j+1] = sumw1[j] + weight[i];
4589+
sumx1[j+1] = sumx1[j] + weight[i]*xb[i];
4590+
sumw2[j+1] = sumw2[j];
4591+
sumx2[j+1] = sumx2[j];
4592+
} else {
4593+
sumw2[j+1] = sumw2[j] + weight[i];
4594+
sumx2[j+1] = sumx2[j] + weight[i]*xb[i];
4595+
sumw1[j+1] = sumw1[j];
4596+
sumx1[j+1] = sumx1[j];
4597+
}
4598+
}
4599+
float best_score = 0, scale = 0.f;
45664600
int besti1 = -1, besti2 = -1, best_k = -1;
45674601
// 0: +, +
45684602
// 1: +, -
45694603
// 2: -, +
45704604
// 3: -, -
45714605
for (int i1 = 0; i1 <= block_size; ++i1) {
45724606
for (int i2 = i1; i2 <= block_size; ++i2) {
4573-
memset(sumqx, 0, 4*sizeof(float));
4574-
memset(sumq2, 0, 4*sizeof(float));
4575-
for (int j = 0; j < i1; ++j) {
4576-
int i = idx[2*j];
4577-
if (i < block_size/2) {
4578-
sumqx[0] += weight[i]*x_p[0]*xb[i];
4579-
sumqx[1] += weight[i]*x_p[0]*xb[i];
4580-
sumqx[2] += weight[i]*x_m[0]*xb[i];
4581-
sumqx[3] += weight[i]*x_m[0]*xb[i];
4582-
sumq2[0] += weight[i]*x_p[0]*x_p[0];
4583-
sumq2[1] += weight[i]*x_p[0]*x_p[0];
4584-
sumq2[2] += weight[i]*x_m[0]*x_m[0];
4585-
sumq2[3] += weight[i]*x_m[0]*x_m[0];
4586-
} else {
4587-
sumqx[0] += weight[i]*x_p[0]*xb[i];
4588-
sumqx[2] += weight[i]*x_p[0]*xb[i];
4589-
sumqx[1] += weight[i]*x_m[0]*xb[i];
4590-
sumqx[3] += weight[i]*x_m[0]*xb[i];
4591-
sumq2[0] += weight[i]*x_p[0]*x_p[0];
4592-
sumq2[2] += weight[i]*x_p[0]*x_p[0];
4593-
sumq2[1] += weight[i]*x_m[0]*x_m[0];
4594-
sumq2[3] += weight[i]*x_m[0]*x_m[0];
4595-
}
4596-
}
4597-
for (int j = i1; j < i2; ++j) {
4598-
int i = idx[2*j];
4599-
if (i < block_size/2) {
4600-
sumqx[0] += weight[i]*x_p[1]*xb[i];
4601-
sumqx[1] += weight[i]*x_p[1]*xb[i];
4602-
sumqx[2] += weight[i]*x_m[1]*xb[i];
4603-
sumqx[3] += weight[i]*x_m[1]*xb[i];
4604-
sumq2[0] += weight[i]*x_p[1]*x_p[1];
4605-
sumq2[1] += weight[i]*x_p[1]*x_p[1];
4606-
sumq2[2] += weight[i]*x_m[1]*x_m[1];
4607-
sumq2[3] += weight[i]*x_m[1]*x_m[1];
4608-
} else {
4609-
sumqx[0] += weight[i]*x_p[1]*xb[i];
4610-
sumqx[2] += weight[i]*x_p[1]*xb[i];
4611-
sumqx[1] += weight[i]*x_m[1]*xb[i];
4612-
sumqx[3] += weight[i]*x_m[1]*xb[i];
4613-
sumq2[0] += weight[i]*x_p[1]*x_p[1];
4614-
sumq2[2] += weight[i]*x_p[1]*x_p[1];
4615-
sumq2[1] += weight[i]*x_m[1]*x_m[1];
4616-
sumq2[3] += weight[i]*x_m[1]*x_m[1];
4617-
}
4618-
}
4619-
for (int j = i2; j < block_size; ++j) {
4620-
int i = idx[2*j];
4621-
if (i < block_size/2) {
4622-
sumqx[0] += weight[i]*x_p[2]*xb[i];
4623-
sumqx[1] += weight[i]*x_p[2]*xb[i];
4624-
sumqx[2] += weight[i]*x_m[2]*xb[i];
4625-
sumqx[3] += weight[i]*x_m[2]*xb[i];
4626-
sumq2[0] += weight[i]*x_p[2]*x_p[2];
4627-
sumq2[1] += weight[i]*x_p[2]*x_p[2];
4628-
sumq2[2] += weight[i]*x_m[2]*x_m[2];
4629-
sumq2[3] += weight[i]*x_m[2]*x_m[2];
4630-
} else {
4631-
sumqx[0] += weight[i]*x_p[2]*xb[i];
4632-
sumqx[2] += weight[i]*x_p[2]*xb[i];
4633-
sumqx[1] += weight[i]*x_m[2]*xb[i];
4634-
sumqx[3] += weight[i]*x_m[2]*xb[i];
4635-
sumq2[0] += weight[i]*x_p[2]*x_p[2];
4636-
sumq2[2] += weight[i]*x_p[2]*x_p[2];
4637-
sumq2[1] += weight[i]*x_m[2]*x_m[2];
4638-
sumq2[3] += weight[i]*x_m[2]*x_m[2];
4639-
}
4640-
}
4607+
sumqx[0] = (sumx1[i1] - sumx1[0])*x_p[0] + (sumx1[i2] - sumx1[i1])*x_p[1] + (sumx1[block_size]-sumx1[i2])*x_p[2] +
4608+
(sumx2[i1] - sumx2[0])*x_p[0] + (sumx2[i2] - sumx2[i1])*x_p[1] + (sumx2[block_size]-sumx2[i2])*x_p[2];
4609+
sumqx[1] = (sumx1[i1] - sumx1[0])*x_p[0] + (sumx1[i2] - sumx1[i1])*x_p[1] + (sumx1[block_size]-sumx1[i2])*x_p[2] +
4610+
(sumx2[i1] - sumx2[0])*x_m[0] + (sumx2[i2] - sumx2[i1])*x_m[1] + (sumx2[block_size]-sumx2[i2])*x_m[2];
4611+
sumqx[2] = (sumx1[i1] - sumx1[0])*x_m[0] + (sumx1[i2] - sumx1[i1])*x_m[1] + (sumx1[block_size]-sumx1[i2])*x_m[2] +
4612+
(sumx2[i1] - sumx2[0])*x_p[0] + (sumx2[i2] - sumx2[i1])*x_p[1] + (sumx2[block_size]-sumx2[i2])*x_p[2];
4613+
sumqx[3] = (sumx1[i1] - sumx1[0])*x_m[0] + (sumx1[i2] - sumx1[i1])*x_m[1] + (sumx1[block_size]-sumx1[i2])*x_m[2] +
4614+
(sumx2[i1] - sumx2[0])*x_m[0] + (sumx2[i2] - sumx2[i1])*x_m[1] + (sumx2[block_size]-sumx2[i2])*x_m[2];
4615+
sumq2[0] = (sumw1[i1] - sumw1[0])*x_p[0]*x_p[0] + (sumw1[i2] - sumw1[i1])*x_p[1]*x_p[1] + (sumw1[block_size]-sumw1[i2])*x_p[2]*x_p[2] +
4616+
(sumw2[i1] - sumw2[0])*x_p[0]*x_p[0] + (sumw2[i2] - sumw2[i1])*x_p[1]*x_p[1] + (sumw2[block_size]-sumw2[i2])*x_p[2]*x_p[2];
4617+
sumq2[1] = (sumw1[i1] - sumw1[0])*x_p[0]*x_p[0] + (sumw1[i2] - sumw1[i1])*x_p[1]*x_p[1] + (sumw1[block_size]-sumw1[i2])*x_p[2]*x_p[2] +
4618+
(sumw2[i1] - sumw2[0])*x_m[0]*x_m[0] + (sumw2[i2] - sumw2[i1])*x_m[1]*x_m[1] + (sumw2[block_size]-sumw2[i2])*x_m[2]*x_m[2];
4619+
sumq2[2] = (sumw1[i1] - sumw1[0])*x_m[0]*x_m[0] + (sumw1[i2] - sumw1[i1])*x_m[1]*x_m[1] + (sumw1[block_size]-sumw1[i2])*x_m[2]*x_m[2] +
4620+
(sumw2[i1] - sumw2[0])*x_p[0]*x_p[0] + (sumw2[i2] - sumw2[i1])*x_p[1]*x_p[1] + (sumw2[block_size]-sumw2[i2])*x_p[2]*x_p[2];
4621+
sumq2[3] = (sumw1[i1] - sumw1[0])*x_m[0]*x_m[0] + (sumw1[i2] - sumw1[i1])*x_m[1]*x_m[1] + (sumw1[block_size]-sumw1[i2])*x_m[2]*x_m[2] +
4622+
(sumw2[i1] - sumw2[0])*x_m[0]*x_m[0] + (sumw2[i2] - sumw2[i1])*x_m[1]*x_m[1] + (sumw2[block_size]-sumw2[i2])*x_m[2]*x_m[2];
46414623
for (int k = 0; k < 4; ++k) {
46424624
if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
46434625
scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k];
@@ -4671,19 +4653,34 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
46714653
index[k] = grid_index;
46724654
}
46734655
if (!all_on_grid) {
4674-
float sumqx_f = 0, sumq2_f = 0;
4675-
for (int k = 0; k < block_size/8; ++k) {
4676-
if (k == 0) xx = best_k < 2 ? x_p : x_m;
4677-
else xx = best_k%2 == 0 ? x_p : x_m;
4678-
const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
4679-
for (int j = 0; j < 8; ++j) {
4680-
float w = weight[8*k + j];
4681-
float q = xx[(pg[j] - 1)/2];
4682-
sumqx_f += w*q*xb[8*k+j];
4683-
sumq2_f += w*q*q;
4656+
sumqx[0] = sumqx[1] = sumqx[2] = sumqx[3] = 0;
4657+
sumq2[0] = sumq2[1] = sumq2[2] = sumq2[3] = 0;
4658+
for (int j = 0; j < block_size; ++j) {
4659+
float w = weight[j];
4660+
float qp = x_p[L[j]];
4661+
float qm = x_m[L[j]];
4662+
sumqx[0] += w*xb[j]*qp;
4663+
sumq2[0] += w*qp*qp;
4664+
sumqx[3] += w*xb[j]*qm;
4665+
sumq2[3] += w*qm*qm;
4666+
if (j < 8) {
4667+
sumqx[1] += w*xb[j]*qp;
4668+
sumq2[1] += w*qp*qp;
4669+
sumqx[2] += w*xb[j]*qm;
4670+
sumq2[2] += w*qm*qm;
4671+
} else {
4672+
sumqx[2] += w*xb[j]*qp;
4673+
sumq2[2] += w*qp*qp;
4674+
sumqx[1] += w*xb[j]*qm;
4675+
sumq2[1] += w*qm*qm;
4676+
}
4677+
}
4678+
best_score = 0;
4679+
for (int k = 0; k < 4; ++k) {
4680+
if (sumqx[k] > 0 && sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
4681+
scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; best_k = k;
46844682
}
46854683
}
4686-
if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f;
46874684
}
46884685
y[ibl].qs[2*ib + 0] = index[0] & 255;
46894686
y[ibl].qs[2*ib + 1] = index[1] & 255;
@@ -4703,6 +4700,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
47034700
float id = 1/d;
47044701
float sumqx_f = 0, sumq2_f = 0;
47054702
for (int ib = 0; ib < QK_K/block_size; ++ib) {
4703+
float sigma2 = all_sigma2[ib/2];
47064704
int l = nearest_int(0.5f*(id*scales[ib+0]-1));
47074705
l = MAX(0, MIN(7, l));
47084706
sc[ib/4] |= (l << 3*(ib%4));
@@ -4727,7 +4725,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
47274725
}
47284726
}
47294727
if (sumq2_f > 0) d = sumqx_f/sumq2_f;
4730-
s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
4728+
s.f16 = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed.
47314729
sc[0] |= ((s.u16 & 0x000f) << 12);
47324730
sc[1] |= ((s.u16 & 0x00f0) << 8);
47334731
sc[2] |= ((s.u16 & 0x0f00) << 4);

0 commit comments

Comments
 (0)