Skip to content

Commit 60f5ca2

Browse files
Nexesenex and ikawrakow committed
Improved IQ1_M quantization #327
Co-Authored-By: Kawrakow <iwankawrakow@gmail.com>
1 parent d494dba commit 60f5ca2

File tree

1 file changed

+83
-85
lines changed

1 file changed

+83
-85
lines changed

ggml/src/ggml-quants.c

Lines changed: 83 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -4481,10 +4481,13 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
44814481
const float x_p[3] = {-1 + IQ1M_DELTA, IQ1M_DELTA, 1 + IQ1M_DELTA};
44824482
const float x_m[3] = {-1 - IQ1M_DELTA, -IQ1M_DELTA, 1 - IQ1M_DELTA};
44834483
const uint8_t masks[4] = {0x00, 0x80, 0x08, 0x88};
4484+
float all_sigma2[QK_K/32];
44844485

44854486
int * idx = (int *)(pairs + 1);
44864487

44874488
float sumqx[4], sumq2[4];
4489+
float sumw1[IQ1M_BLOCK_SIZE+1], sumw2[IQ1M_BLOCK_SIZE+1];
4490+
float sumx1[IQ1M_BLOCK_SIZE+1], sumx2[IQ1M_BLOCK_SIZE+1];
44884491

44894492
iq1m_scale_t s;
44904493
const float * xx;
@@ -4497,11 +4500,18 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
44974500
float max_scale = 0;
44984501

44994502
const float * xbl = x + QK_K*ibl;
4500-
float sumx2 = 0;
4501-
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
4502-
float sigma2 = 2*sumx2/QK_K;
4503+
for (int ib = 0; ib < QK_K/32; ++ib) {
4504+
const float * xb = xbl + 32*ib;
4505+
float sumx2 = 0;
4506+
for (int i = 0; i < 32; ++i) sumx2 += xb[i]*xb[i];
4507+
all_sigma2[ib] = 1.5f*sumx2/32;
4508+
}
4509+
//float sumx2 = 0;
4510+
//for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
4511+
//float sigma2 = 1.5f*sumx2/QK_K;
45034512

45044513
for (int ib = 0; ib < QK_K/block_size; ++ib) {
4514+
float sigma2 = all_sigma2[ib/2];
45054515
const float * xb = xbl + block_size*ib;
45064516
if (quant_weights) {
45074517
const float * qw = quant_weights + QK_K*ibl + block_size*ib;
@@ -4510,12 +4520,21 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45104520
for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
45114521
}
45124522
float max = fabsf(xb[0]);
4513-
for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i]));
4523+
float sumwx = 0;
4524+
for (int i = 1; i < block_size; ++i) {
4525+
float ax = fabsf(xb[i]);
4526+
max = MAX(max, ax);
4527+
sumwx += weight[i]*ax;
4528+
}
45144529
if (max < GROUP_MAX_EPS_IQ1_M) {
45154530
scales[ib] = 0;
45164531
memset(L, 1, block_size);
45174532
continue;
45184533
}
4534+
if (sumwx == 0) {
4535+
// weight is zero everywhere where xb is not zero => ignore
4536+
for (int i = 0; i < block_size; ++i) weight[i] = xb[i]*xb[i];
4537+
}
45194538
// Here we solve exactly the sum of squared difference (SSD) weighted minimization problem.
45204539
// With just 3 allowed quant values (-1, 0, 1), we can search exhaustively for the two
45214540
// boundaries that split the weights xb[i] into 3 groups. To do so, we sort the weights
@@ -4527,82 +4546,45 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
45274546
idx[2*j] = j;
45284547
}
45294548
qsort(pairs, block_size, 2*sizeof(float), iq1_sort_helper);
4530-
float best_score = -FLT_MIN, scale = max;
4549+
sumw1[0] = sumw2[0] = sumx1[0] = sumx2[0] = 0;
4550+
for (int j = 0; j < block_size; ++j) {
4551+
int i = idx[2*j];
4552+
if (i < block_size/2) {
4553+
sumw1[j+1] = sumw1[j] + weight[i];
4554+
sumx1[j+1] = sumx1[j] + weight[i]*xb[i];
4555+
sumw2[j+1] = sumw2[j];
4556+
sumx2[j+1] = sumx2[j];
4557+
} else {
4558+
sumw2[j+1] = sumw2[j] + weight[i];
4559+
sumx2[j+1] = sumx2[j] + weight[i]*xb[i];
4560+
sumw1[j+1] = sumw1[j];
4561+
sumx1[j+1] = sumx1[j];
4562+
}
4563+
}
4564+
float best_score = 0, scale = 0.f;
45314565
int besti1 = -1, besti2 = -1, best_k = -1;
45324566
// 0: +, +
45334567
// 1: +, -
45344568
// 2: -, +
45354569
// 3: -, -
45364570
for (int i1 = 0; i1 <= block_size; ++i1) {
45374571
for (int i2 = i1; i2 <= block_size; ++i2) {
4538-
memset(sumqx, 0, 4*sizeof(float));
4539-
memset(sumq2, 0, 4*sizeof(float));
4540-
for (int j = 0; j < i1; ++j) {
4541-
int i = idx[2*j];
4542-
if (i < block_size/2) {
4543-
sumqx[0] += weight[i]*x_p[0]*xb[i];
4544-
sumqx[1] += weight[i]*x_p[0]*xb[i];
4545-
sumqx[2] += weight[i]*x_m[0]*xb[i];
4546-
sumqx[3] += weight[i]*x_m[0]*xb[i];
4547-
sumq2[0] += weight[i]*x_p[0]*x_p[0];
4548-
sumq2[1] += weight[i]*x_p[0]*x_p[0];
4549-
sumq2[2] += weight[i]*x_m[0]*x_m[0];
4550-
sumq2[3] += weight[i]*x_m[0]*x_m[0];
4551-
} else {
4552-
sumqx[0] += weight[i]*x_p[0]*xb[i];
4553-
sumqx[2] += weight[i]*x_p[0]*xb[i];
4554-
sumqx[1] += weight[i]*x_m[0]*xb[i];
4555-
sumqx[3] += weight[i]*x_m[0]*xb[i];
4556-
sumq2[0] += weight[i]*x_p[0]*x_p[0];
4557-
sumq2[2] += weight[i]*x_p[0]*x_p[0];
4558-
sumq2[1] += weight[i]*x_m[0]*x_m[0];
4559-
sumq2[3] += weight[i]*x_m[0]*x_m[0];
4560-
}
4561-
}
4562-
for (int j = i1; j < i2; ++j) {
4563-
int i = idx[2*j];
4564-
if (i < block_size/2) {
4565-
sumqx[0] += weight[i]*x_p[1]*xb[i];
4566-
sumqx[1] += weight[i]*x_p[1]*xb[i];
4567-
sumqx[2] += weight[i]*x_m[1]*xb[i];
4568-
sumqx[3] += weight[i]*x_m[1]*xb[i];
4569-
sumq2[0] += weight[i]*x_p[1]*x_p[1];
4570-
sumq2[1] += weight[i]*x_p[1]*x_p[1];
4571-
sumq2[2] += weight[i]*x_m[1]*x_m[1];
4572-
sumq2[3] += weight[i]*x_m[1]*x_m[1];
4573-
} else {
4574-
sumqx[0] += weight[i]*x_p[1]*xb[i];
4575-
sumqx[2] += weight[i]*x_p[1]*xb[i];
4576-
sumqx[1] += weight[i]*x_m[1]*xb[i];
4577-
sumqx[3] += weight[i]*x_m[1]*xb[i];
4578-
sumq2[0] += weight[i]*x_p[1]*x_p[1];
4579-
sumq2[2] += weight[i]*x_p[1]*x_p[1];
4580-
sumq2[1] += weight[i]*x_m[1]*x_m[1];
4581-
sumq2[3] += weight[i]*x_m[1]*x_m[1];
4582-
}
4583-
}
4584-
for (int j = i2; j < block_size; ++j) {
4585-
int i = idx[2*j];
4586-
if (i < block_size/2) {
4587-
sumqx[0] += weight[i]*x_p[2]*xb[i];
4588-
sumqx[1] += weight[i]*x_p[2]*xb[i];
4589-
sumqx[2] += weight[i]*x_m[2]*xb[i];
4590-
sumqx[3] += weight[i]*x_m[2]*xb[i];
4591-
sumq2[0] += weight[i]*x_p[2]*x_p[2];
4592-
sumq2[1] += weight[i]*x_p[2]*x_p[2];
4593-
sumq2[2] += weight[i]*x_m[2]*x_m[2];
4594-
sumq2[3] += weight[i]*x_m[2]*x_m[2];
4595-
} else {
4596-
sumqx[0] += weight[i]*x_p[2]*xb[i];
4597-
sumqx[2] += weight[i]*x_p[2]*xb[i];
4598-
sumqx[1] += weight[i]*x_m[2]*xb[i];
4599-
sumqx[3] += weight[i]*x_m[2]*xb[i];
4600-
sumq2[0] += weight[i]*x_p[2]*x_p[2];
4601-
sumq2[2] += weight[i]*x_p[2]*x_p[2];
4602-
sumq2[1] += weight[i]*x_m[2]*x_m[2];
4603-
sumq2[3] += weight[i]*x_m[2]*x_m[2];
4604-
}
4605-
}
4572+
sumqx[0] = (sumx1[i1] - sumx1[0])*x_p[0] + (sumx1[i2] - sumx1[i1])*x_p[1] + (sumx1[block_size]-sumx1[i2])*x_p[2] +
4573+
(sumx2[i1] - sumx2[0])*x_p[0] + (sumx2[i2] - sumx2[i1])*x_p[1] + (sumx2[block_size]-sumx2[i2])*x_p[2];
4574+
sumqx[1] = (sumx1[i1] - sumx1[0])*x_p[0] + (sumx1[i2] - sumx1[i1])*x_p[1] + (sumx1[block_size]-sumx1[i2])*x_p[2] +
4575+
(sumx2[i1] - sumx2[0])*x_m[0] + (sumx2[i2] - sumx2[i1])*x_m[1] + (sumx2[block_size]-sumx2[i2])*x_m[2];
4576+
sumqx[2] = (sumx1[i1] - sumx1[0])*x_m[0] + (sumx1[i2] - sumx1[i1])*x_m[1] + (sumx1[block_size]-sumx1[i2])*x_m[2] +
4577+
(sumx2[i1] - sumx2[0])*x_p[0] + (sumx2[i2] - sumx2[i1])*x_p[1] + (sumx2[block_size]-sumx2[i2])*x_p[2];
4578+
sumqx[3] = (sumx1[i1] - sumx1[0])*x_m[0] + (sumx1[i2] - sumx1[i1])*x_m[1] + (sumx1[block_size]-sumx1[i2])*x_m[2] +
4579+
(sumx2[i1] - sumx2[0])*x_m[0] + (sumx2[i2] - sumx2[i1])*x_m[1] + (sumx2[block_size]-sumx2[i2])*x_m[2];
4580+
sumq2[0] = (sumw1[i1] - sumw1[0])*x_p[0]*x_p[0] + (sumw1[i2] - sumw1[i1])*x_p[1]*x_p[1] + (sumw1[block_size]-sumw1[i2])*x_p[2]*x_p[2] +
4581+
(sumw2[i1] - sumw2[0])*x_p[0]*x_p[0] + (sumw2[i2] - sumw2[i1])*x_p[1]*x_p[1] + (sumw2[block_size]-sumw2[i2])*x_p[2]*x_p[2];
4582+
sumq2[1] = (sumw1[i1] - sumw1[0])*x_p[0]*x_p[0] + (sumw1[i2] - sumw1[i1])*x_p[1]*x_p[1] + (sumw1[block_size]-sumw1[i2])*x_p[2]*x_p[2] +
4583+
(sumw2[i1] - sumw2[0])*x_m[0]*x_m[0] + (sumw2[i2] - sumw2[i1])*x_m[1]*x_m[1] + (sumw2[block_size]-sumw2[i2])*x_m[2]*x_m[2];
4584+
sumq2[2] = (sumw1[i1] - sumw1[0])*x_m[0]*x_m[0] + (sumw1[i2] - sumw1[i1])*x_m[1]*x_m[1] + (sumw1[block_size]-sumw1[i2])*x_m[2]*x_m[2] +
4585+
(sumw2[i1] - sumw2[0])*x_p[0]*x_p[0] + (sumw2[i2] - sumw2[i1])*x_p[1]*x_p[1] + (sumw2[block_size]-sumw2[i2])*x_p[2]*x_p[2];
4586+
sumq2[3] = (sumw1[i1] - sumw1[0])*x_m[0]*x_m[0] + (sumw1[i2] - sumw1[i1])*x_m[1]*x_m[1] + (sumw1[block_size]-sumw1[i2])*x_m[2]*x_m[2] +
4587+
(sumw2[i1] - sumw2[0])*x_m[0]*x_m[0] + (sumw2[i2] - sumw2[i1])*x_m[1]*x_m[1] + (sumw2[block_size]-sumw2[i2])*x_m[2]*x_m[2];
46064588
for (int k = 0; k < 4; ++k) {
46074589
if (sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
46084590
scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k];
@@ -4636,19 +4618,34 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
46364618
index[k] = grid_index;
46374619
}
46384620
if (!all_on_grid) {
4639-
float sumqx_f = 0, sumq2_f = 0;
4640-
for (int k = 0; k < block_size/8; ++k) {
4641-
if (k == 0) xx = best_k < 2 ? x_p : x_m;
4642-
else xx = best_k%2 == 0 ? x_p : x_m;
4643-
const int8_t * pg = (const int8_t *)(kgrid_q2xs + index[k]);
4644-
for (int j = 0; j < 8; ++j) {
4645-
float w = weight[8*k + j];
4646-
float q = xx[(pg[j] - 1)/2];
4647-
sumqx_f += w*q*xb[8*k+j];
4648-
sumq2_f += w*q*q;
4621+
sumqx[0] = sumqx[1] = sumqx[2] = sumqx[3] = 0;
4622+
sumq2[0] = sumq2[1] = sumq2[2] = sumq2[3] = 0;
4623+
for (int j = 0; j < block_size; ++j) {
4624+
float w = weight[j];
4625+
float qp = x_p[L[j]];
4626+
float qm = x_m[L[j]];
4627+
sumqx[0] += w*xb[j]*qp;
4628+
sumq2[0] += w*qp*qp;
4629+
sumqx[3] += w*xb[j]*qm;
4630+
sumq2[3] += w*qm*qm;
4631+
if (j < 8) {
4632+
sumqx[1] += w*xb[j]*qp;
4633+
sumq2[1] += w*qp*qp;
4634+
sumqx[2] += w*xb[j]*qm;
4635+
sumq2[2] += w*qm*qm;
4636+
} else {
4637+
sumqx[2] += w*xb[j]*qp;
4638+
sumq2[2] += w*qp*qp;
4639+
sumqx[1] += w*xb[j]*qm;
4640+
sumq2[1] += w*qm*qm;
4641+
}
4642+
}
4643+
best_score = 0;
4644+
for (int k = 0; k < 4; ++k) {
4645+
if (sumqx[k] > 0 && sumq2[k] > 0 && sumqx[k]*sumqx[k] > best_score*sumq2[k]) {
4646+
scale = sumqx[k]/sumq2[k]; best_score = scale*sumqx[k]; best_k = k;
46494647
}
46504648
}
4651-
if (sumqx_f > 0 && sumq2_f > 0) scale = sumqx_f/sumq2_f;
46524649
}
46534650
y[ibl].qs[2*ib + 0] = index[0] & 255;
46544651
y[ibl].qs[2*ib + 1] = index[1] & 255;
@@ -4668,6 +4665,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
46684665
float id = 1/d;
46694666
float sumqx_f = 0, sumq2_f = 0;
46704667
for (int ib = 0; ib < QK_K/block_size; ++ib) {
4668+
float sigma2 = all_sigma2[ib/2];
46714669
int l = nearest_int(0.5f*(id*scales[ib+0]-1));
46724670
l = MAX(0, MIN(7, l));
46734671
sc[ib/4] |= (l << 3*(ib%4));
@@ -4692,7 +4690,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R
46924690
}
46934691
}
46944692
if (sumq2_f > 0) d = sumqx_f/sumq2_f;
4695-
s.f16 = GGML_FP32_TO_FP16(d*1.1125f); // 1.1125f is another fudge factor. Don't ask me why it is needed.
4693+
s.f16 = GGML_FP32_TO_FP16(d*1.085f); // 1.085f is another fudge factor. Don't ask me why it is needed.
46964694
sc[0] |= ((s.u16 & 0x000f) << 12);
46974695
sc[1] |= ((s.u16 & 0x00f0) << 8);
46984696
sc[2] |= ((s.u16 & 0x0f00) << 4);

0 commit comments

Comments
 (0)