Skip to content

Commit ee4c57b

Browse files
Nexesenex and ikawrakow committed
Quantization improvements #295 and #302, GGML part only
Co-Authored-By: Kawrakow <iwankawrakow@gmail.com>
1 parent eab3894 commit ee4c57b

File tree

1 file changed

+226
-46
lines changed

1 file changed

+226
-46
lines changed

ggml/src/ggml-quants.c

Lines changed: 226 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -524,10 +524,8 @@ static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8
524524
float scale = suml2 ? sumlx/suml2 : 0.0f;
525525
if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
526526
float best = scale * sumlx;
527+
float best_sumlx = sumlx, best_suml2 = suml2;
527528
for (int is = -9; is <= 9; ++is) {
528-
if (is == 0) {
529-
continue;
530-
}
531529
iscale = -(nmax + 0.1f*is) / max;
532530
sumlx = suml2 = 0;
533531
for (int i = 0; i < n; ++i) {
@@ -543,7 +541,66 @@ static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8
543541
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
544542
}
545543
scale = sumlx/suml2; best = scale*sumlx;
544+
best_sumlx = sumlx; best_suml2 = suml2;
545+
}
546+
iscale = (nmax-1 + 0.1f*is) / max;
547+
sumlx = suml2 = 0;
548+
for (int i = 0; i < n; ++i) {
549+
int l = nearest_int(iscale * x[i]);
550+
l = MAX(-nmax, MIN(nmax-1, l));
551+
float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
552+
sumlx += w*x[i]*l;
553+
suml2 += w*l*l;
554+
}
555+
if (suml2 > 0 && sumlx*sumlx > best*suml2) {
556+
for (int i = 0; i < n; ++i) {
557+
int l = nearest_int(iscale * x[i]);
558+
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
559+
}
560+
scale = sumlx/suml2; best = scale*sumlx;
561+
best_sumlx = sumlx; best_suml2 = suml2;
562+
}
563+
}
564+
565+
sumlx = best_sumlx; suml2 = best_suml2;
566+
for (int iter = 0; iter < n*(2*nmax-1); ++iter) {
567+
float abs_gmax = 0, gmax = 0;
568+
int best_j = -1;
569+
for (int j = 0; j < n; ++j) {
570+
float w = qw ? qw[j] : rmse_type == 1 ? x[j] * x[j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[j]) : sqrtf(fabsf(x[j]));
571+
int l = L[j] - nmax;
572+
float g = scale * w * (x[j] - scale*l);
573+
if ((g > 0 && l < nmax-1) || (g < 0 && l > -nmax)) {
574+
float ag = fabsf(g);
575+
if (ag > abs_gmax) {
576+
abs_gmax = ag; gmax = g; best_j = j;
577+
}
578+
}
579+
}
580+
if (best_j < 0) break;
581+
582+
float new_sumlx = sumlx, new_suml2 = suml2;
583+
float w = qw ? qw[best_j] : rmse_type == 1 ? x[best_j] * x[best_j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[best_j]) : sqrtf(fabsf(x[best_j]));
584+
int l = L[best_j] - nmax;
585+
if (gmax > 0) {
586+
new_sumlx += w*x[best_j];
587+
new_suml2 += w*(2*l + 1);
588+
l += 1;
589+
} else {
590+
new_sumlx -= w*x[best_j];
591+
new_suml2 -= w*(2*l - 1);
592+
l -= 1;
593+
}
594+
if (new_suml2 > 0 && new_sumlx*new_sumlx > best*new_suml2) {
595+
sumlx = new_sumlx; suml2 = new_suml2;
596+
scale = sumlx/suml2; best = scale*sumlx;
597+
L[best_j] = l + nmax;
598+
GGML_ASSERT(L[best_j] >= 0 && L[best_j] <= 2*nmax-1);
599+
}
600+
else {
601+
break;
546602
}
603+
547604
}
548605
return scale;
549606
}
@@ -849,8 +906,9 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
849906
float rmin, float rdelta, int nstep, bool use_mad) {
850907
float min = x[0];
851908
float max = x[0];
852-
float sum_w = weights ? weights[0] : x[0]*x[0];
853-
float sum_x = sum_w * x[0];
909+
double sum_w = weights ? (double)weights[0] : (double)(x[0]*x[0]);
910+
double sum_x = sum_w * (double)x[0];
911+
double sum_x2 = sum_w * (double)x[0] * (double)x[0];
854912
#ifdef HAVE_BUGGY_APPLE_LINKER
855913
// use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
856914
for (volatile int i = 1; i < n; ++i) {
@@ -860,8 +918,9 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
860918
if (x[i] < min) min = x[i];
861919
if (x[i] > max) max = x[i];
862920
float w = weights ? weights[i] : x[i]*x[i];
863-
sum_w += w;
864-
sum_x += w * x[i];
921+
sum_w += (double)w;
922+
sum_x += (double)w * (double)x[i];
923+
sum_x2 += (double)w * (double)x[i] * (double)x[i];
865924
}
866925
if (min > 0) {
867926
min = 0;
@@ -873,13 +932,13 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
873932
}
874933
float iscale = nmax/(max - min);
875934
float scale = 1/iscale;
876-
float best_mad = 0;
935+
double best_mad = 0;
877936
for (int i = 0; i < n; ++i) {
878937
int l = nearest_int(iscale*(x[i] - min));
879938
L[i] = MAX(0, MIN(nmax, l));
880-
float diff = scale * L[i] + min - x[i];
881-
diff = use_mad ? fabsf(diff) : diff*diff;
882-
float w = weights ? weights[i] : x[i]*x[i];
939+
double diff = (double)scale * L[i] + (double)min - (double)x[i];
940+
diff = use_mad ? fabs(diff) : diff*diff;
941+
double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
883942
best_mad += w * diff;
884943
}
885944
if (nstep < 1) {
@@ -888,30 +947,35 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
888947
}
889948
for (int is = 0; is <= nstep; ++is) {
890949
iscale = (rmin + rdelta*is + nmax)/(max - min);
891-
float sum_l = 0, sum_l2 = 0, sum_xl = 0;
950+
double sum_l = 0, sum_l2 = 0, sum_xl = 0;
892951
for (int i = 0; i < n; ++i) {
893952
int l = nearest_int(iscale*(x[i] - min));
894953
l = MAX(0, MIN(nmax, l));
895954
Laux[i] = l;
896955
float w = weights ? weights[i] : x[i]*x[i];
897-
sum_l += w*l;
898-
sum_l2 += w*l*l;
899-
sum_xl += w*l*x[i];
956+
sum_l += (double)w*l;
957+
sum_l2 += (double)w*l*l;
958+
sum_xl += (double)w*l*(double)x[i];
900959
}
901-
float D = sum_w * sum_l2 - sum_l * sum_l;
960+
double D = sum_w * sum_l2 - sum_l * sum_l;
902961
if (D > 0) {
903-
float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
904-
float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
962+
double this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
963+
double this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
905964
if (this_min > 0) {
906965
this_min = 0;
907966
this_scale = sum_xl / sum_l2;
908967
}
909-
float mad = 0;
910-
for (int i = 0; i < n; ++i) {
911-
float diff = this_scale * Laux[i] + this_min - x[i];
912-
diff = use_mad ? fabsf(diff) : diff*diff;
913-
float w = weights ? weights[i] : x[i]*x[i];
914-
mad += w * diff;
968+
double mad = 0;
969+
if (use_mad) {
970+
for (int i = 0; i < n; ++i) {
971+
double diff = (double)this_scale * Laux[i] + (double)this_min - (double)x[i];
972+
diff = fabs(diff);
973+
double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
974+
mad += w * diff;
975+
}
976+
} else {
977+
mad = sum_x2 - 2*this_scale*sum_xl - 2*this_min*sum_x + 2*this_scale*this_min*sum_l
978+
+ this_scale*this_scale*sum_l2 + this_min*this_min*sum_w;
915979
}
916980
if (mad < best_mad) {
917981
for (int i = 0; i < n; ++i) {
@@ -923,6 +987,57 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
923987
}
924988
}
925989
}
990+
if (use_mad) {
991+
*the_min = -min;
992+
return scale;
993+
}
994+
995+
double sum_l = 0, sum_l2 = 0, sum_xl = 0;
996+
for (int i = 0; i < n; ++i) {
997+
int l = L[i];
998+
double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
999+
sum_l += w*l;
1000+
sum_l2 += w*l*l;
1001+
sum_xl += w*l*(double)x[i];
1002+
}
1003+
double best = 2*(double)scale*sum_xl + 2*(double)min*sum_x - 2*(double)scale*(double)min*sum_l
1004+
- (double)scale*(double)scale*sum_l2 - (double)min*(double)min*sum_w;
1005+
int last_j = -1, last_dir = 0;
1006+
for (int itry = 0; itry < nmax*n; ++itry) {
1007+
float gmax = 0;
1008+
int best_j = -1, dir = 0;
1009+
for (int j = 0; j < n; ++j) {
1010+
float g = x[j] - scale*L[j] - min;
1011+
if (g > 0 && L[j] < nmax && g > gmax) {
1012+
gmax = g; best_j = j; dir = 1;
1013+
}
1014+
else if (g < 0 && L[j] > 0 && -g > gmax) {
1015+
gmax = -g; best_j = j; dir = -1;
1016+
}
1017+
}
1018+
if (best_j < 0 || (best_j == last_j && dir == -last_dir)) break;
1019+
double w = weights ? (double)weights[best_j] : (double)(x[best_j]*x[best_j]);
1020+
sum_l += w*dir;
1021+
sum_l2 += w*(2*L[best_j]*dir + 1);
1022+
sum_xl += w*(double)x[best_j]*dir;
1023+
double D = (double)sum_w * sum_l2 - sum_l * sum_l;
1024+
if (D <= 0) break;
1025+
double this_scale = ((double)sum_w * sum_xl - (double)sum_x * sum_l)/D;
1026+
double this_min = (sum_l2 * (double)sum_x - sum_l * sum_xl)/D;
1027+
if (this_min > 0) {
1028+
this_min = 0;
1029+
this_scale = sum_xl / sum_l2;
1030+
}
1031+
if (this_scale < 0) break;
1032+
double score = 2*this_scale*sum_xl + 2*this_min*(double)sum_x - 2*this_scale*this_min*sum_l
1033+
- this_scale*this_scale*sum_l2 - this_min*this_min*(double)sum_w;
1034+
if (score <= best) break;
1035+
best = score;
1036+
scale = this_scale;
1037+
min = this_min;
1038+
L[best_j] += dir;
1039+
last_j = best_j; last_dir = dir;
1040+
}
9261041
*the_min = -min;
9271042
return scale;
9281043
}
@@ -1004,7 +1119,7 @@ static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * G
10041119
GGML_ASSERT(quant_weights);
10051120
assert(k % QK_K == 0);
10061121
const int nb = k / QK_K;
1007-
const bool requantize = true;
1122+
// const bool requantize = true;
10081123

10091124
uint8_t L[QK_K];
10101125
uint8_t Laux[16];
@@ -1018,39 +1133,33 @@ static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * G
10181133
memset(sw, 0, QK_K/16*sizeof(float));
10191134
float sumx2 = 0;
10201135
for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
1021-
float sigma2 = sumx2/QK_K;
1136+
float sigma2 = 0.75f*sumx2/QK_K;
10221137
for (int j = 0; j < QK_K/16; ++j) {
10231138
const float * GGML_RESTRICT qw = quant_weights + QK_K * i + 16*j;
10241139
for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
10251140
for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
10261141
scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
10271142
}
10281143

1029-
float dm, mm;
1030-
dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
1031-
mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
1144+
float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
1145+
float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
10321146

10331147
y[i].d = GGML_FP32_TO_FP16(dm);
10341148
y[i].dmin = GGML_FP32_TO_FP16(mm);
1035-
dm = GGML_FP16_TO_FP32(y[i].d);
1036-
mm = GGML_FP16_TO_FP32(y[i].dmin);
10371149

10381150
for (int j = 0; j < QK_K/16; ++j) {
1039-
y[i].scales[j] = Ls[j] | (Lm[j] << 4);
1040-
}
1041-
1042-
if (requantize) {
1043-
for (int j = 0; j < QK_K/16; ++j) {
1044-
const float d = dm * (y[i].scales[j] & 0xF);
1045-
if (!d) continue;
1046-
const float m = mm * (y[i].scales[j] >> 4);
1047-
for (int ii = 0; ii < 16; ++ii) {
1048-
int l = nearest_int((x[16*j + ii] + m)/d);
1049-
l = MAX(0, MIN(3, l));
1050-
L[16*j + ii] = l;
1051-
}
1151+
float d = dm*Ls[j];
1152+
float m = mm*Lm[j];
1153+
float id = d ? 1/d : 0.f;
1154+
for (int l = 0; l < QK_K/16; ++l) {
1155+
int q = nearest_int((x[16*j + l] + m)*id);
1156+
q = MAX(0, MIN(3, q));
1157+
L[16*j + l] = q;
10521158
}
10531159
}
1160+
for (int j = 0; j < QK_K/16; ++j) {
1161+
y[i].scales[j] = Ls[j] | (Lm[j] << 4);
1162+
}
10541163

10551164
for (int j = 0; j < QK_K; j += 128) {
10561165
for (int l = 0; l < 32; ++l) {
@@ -1939,8 +2048,12 @@ static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * G
19392048
const int64_t nb = n_per_row/QK4_0;
19402049
for (int ib = 0; ib < nb; ++ib) {
19412050
const float * xb = x + QK4_0 * ib;
1942-
const float * qw = quant_weights + QK4_0 * ib;
1943-
for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
2051+
if (quant_weights) {
2052+
const float * qw = quant_weights + QK4_0 * ib;
2053+
for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
2054+
} else {
2055+
for (int j = 0; j < QK4_0; ++j) weight[j] = xb[j]*xb[j];
2056+
}
19442057
float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);
19452058
y[ib].d = GGML_FP32_TO_FP16(d);
19462059
for (int j = 0; j < 16; ++j) {
@@ -4331,6 +4444,12 @@ static int iq1_find_best_neighbour2(const uint16_t * GGML_RESTRICT neighbours, c
43314444
return grid_index;
43324445
}
43334446

4447+
// Three-way comparison of the float stored at `left` against the float stored at `right`.
// Only the first float behind each pointer is read.
// Returns -1 if *left < *right, 1 if *left > *right, and 0 otherwise.
// NOTE(review): the (const void *, const void *) -> int shape matches a qsort-style
// comparator, but no caller is visible in this hunk; confirm the usage and element stride.
static int iq1_sort_helper(const void * left, const void * right) {
4448+
const float * l = left;
4449+
const float * r = right;
4450+
return *l < *r ? -1 : *l > *r ? 1 : 0; // 0 when neither < nor > holds (equal values, or an unordered comparison such as NaN)
4451+
}
4452+
43344453
#define IQ1S_BLOCK_SIZE 32
43354454
#define IQ1M_BLOCK_SIZE 16
43364455
static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t n, const float * GGML_RESTRICT quant_weights,
@@ -4843,6 +4962,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
48434962
}
48444963
d = sumqx/sumq2;
48454964
float best = d*sumqx;
4965+
float best_sumqx = sumqx, best_sumq2 = sumq2;
48464966
for (int itry = -ntry; itry <= ntry; ++itry) {
48474967
id = (itry + values[0])/max;
48484968
sumqx = sumq2 = 0;
@@ -4856,8 +4976,68 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
48564976
}
48574977
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
48584978
d = sumqx/sumq2; best = d * sumqx;
4979+
best_sumqx = sumqx; best_sumq2 = sumq2;
4980+
for (int j = 0; j < block_size; ++j) {
4981+
float al = id*xb[j];
4982+
Lb[j] = best_index_iq4nl(values, al);
4983+
}
4984+
}
4985+
id = (itry + values[15])/max;
4986+
sumqx = sumq2 = 0;
4987+
for (int j = 0; j < block_size; ++j) {
4988+
float al = id*xb[j];
4989+
int l = best_index_iq4nl(values, al);
4990+
float q = values[l];
4991+
float w = weight[j];
4992+
sumqx += w*q*xb[j];
4993+
sumq2 += w*q*q;
4994+
}
4995+
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
4996+
d = sumqx/sumq2; best = d * sumqx;
4997+
best_sumqx = sumqx; best_sumq2 = sumq2;
4998+
for (int j = 0; j < block_size; ++j) {
4999+
float al = id*xb[j];
5000+
Lb[j] = best_index_iq4nl(values, al);
5001+
}
5002+
}
5003+
}
5004+
sumqx = best_sumqx; sumq2 = best_sumq2;
5005+
best_sumqx = sumqx; best_sumq2 = sumq2;
5006+
for (int iter = 0; iter < 32*block_size; ++iter) {
5007+
float min_step = INFINITY;
5008+
int best_j = -1; int dir = 0;
5009+
for (int j = 0; j < block_size; ++j) {
5010+
float w = weight[j];
5011+
float g = d * w * (xb[j] - d*values[Lb[j]]);
5012+
if (g > 0 && Lb[j] < 15) {
5013+
float step = (values[Lb[j]+1] - values[Lb[j]])/g;
5014+
if (step < min_step) {
5015+
min_step = step; best_j = j; dir = 1;
5016+
}
5017+
}
5018+
else if (g < 0 && Lb[j] > 0) {
5019+
float step = (values[Lb[j]-1] - values[Lb[j]])/g;
5020+
if (step < min_step) {
5021+
min_step = step; best_j = j; dir = -1;
5022+
}
5023+
}
5024+
}
5025+
if (best_j < 0) break;
5026+
5027+
float new_sumqx = sumqx, new_sumq2 = sumq2;
5028+
float w = weight[best_j];
5029+
new_sumqx += w*xb[best_j]*(values[Lb[best_j]+dir] - values[Lb[best_j]]);
5030+
new_sumq2 += w*(values[Lb[best_j]+dir]*values[Lb[best_j]+dir] - values[Lb[best_j]]*values[Lb[best_j]]);
5031+
if (new_sumq2 > 0 && new_sumqx*new_sumqx > best*new_sumq2) {
5032+
sumqx = new_sumqx; sumq2 = new_sumq2;
5033+
d = sumqx/sumq2; best = d*sumqx;
5034+
Lb[best_j] += dir;
5035+
}
5036+
else {
5037+
break;
48595038
}
48605039
}
5040+
48615041
scales[ib] = d;
48625042
float abs_d = fabsf(d);
48635043
if (abs_d > amax_scale) {

0 commit comments

Comments
 (0)