Skip to content

Commit c93034c

Browse files
Nexesenex and ikawrakow committed
Quantization improvements #295 and #302, GGML part only
Co-Authored-By: Kawrakow <iwankawrakow@gmail.com>
1 parent 997f1bb commit c93034c

File tree

1 file changed

+220
-46
lines changed

1 file changed

+220
-46
lines changed

ggml/src/ggml-quants.c

Lines changed: 220 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -568,10 +568,8 @@ static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8
568568
float scale = suml2 ? sumlx/suml2 : 0.0f;
569569
if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
570570
float best = scale * sumlx;
571+
float best_sumlx = sumlx, best_suml2 = suml2;
571572
for (int is = -9; is <= 9; ++is) {
572-
if (is == 0) {
573-
continue;
574-
}
575573
iscale = -(nmax + 0.1f*is) / max;
576574
sumlx = suml2 = 0;
577575
for (int i = 0; i < n; ++i) {
@@ -587,7 +585,66 @@ static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8
587585
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
588586
}
589587
scale = sumlx/suml2; best = scale*sumlx;
588+
best_sumlx = sumlx; best_suml2 = suml2;
589+
}
590+
iscale = (nmax-1 + 0.1f*is) / max;
591+
sumlx = suml2 = 0;
592+
for (int i = 0; i < n; ++i) {
593+
int l = nearest_int(iscale * x[i]);
594+
l = MAX(-nmax, MIN(nmax-1, l));
595+
float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
596+
sumlx += w*x[i]*l;
597+
suml2 += w*l*l;
590598
}
599+
if (suml2 > 0 && sumlx*sumlx > best*suml2) {
600+
for (int i = 0; i < n; ++i) {
601+
int l = nearest_int(iscale * x[i]);
602+
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
603+
}
604+
scale = sumlx/suml2; best = scale*sumlx;
605+
best_sumlx = sumlx; best_suml2 = suml2;
606+
}
607+
}
608+
609+
sumlx = best_sumlx; suml2 = best_suml2;
610+
for (int iter = 0; iter < n*(2*nmax-1); ++iter) {
611+
float abs_gmax = 0, gmax = 0;
612+
int best_j = -1;
613+
for (int j = 0; j < n; ++j) {
614+
float w = qw ? qw[j] : rmse_type == 1 ? x[j] * x[j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[j]) : sqrtf(fabsf(x[j]));
615+
int l = L[j] - nmax;
616+
float g = scale * w * (x[j] - scale*l);
617+
if ((g > 0 && l < nmax-1) || (g < 0 && l > -nmax)) {
618+
float ag = fabsf(g);
619+
if (ag > abs_gmax) {
620+
abs_gmax = ag; gmax = g; best_j = j;
621+
}
622+
}
623+
}
624+
if (best_j < 0) break;
625+
626+
float new_sumlx = sumlx, new_suml2 = suml2;
627+
float w = qw ? qw[best_j] : rmse_type == 1 ? x[best_j] * x[best_j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[best_j]) : sqrtf(fabsf(x[best_j]));
628+
int l = L[best_j] - nmax;
629+
if (gmax > 0) {
630+
new_sumlx += w*x[best_j];
631+
new_suml2 += w*(2*l + 1);
632+
l += 1;
633+
} else {
634+
new_sumlx -= w*x[best_j];
635+
new_suml2 -= w*(2*l - 1);
636+
l -= 1;
637+
}
638+
if (new_suml2 > 0 && new_sumlx*new_sumlx > best*new_suml2) {
639+
sumlx = new_sumlx; suml2 = new_suml2;
640+
scale = sumlx/suml2; best = scale*sumlx;
641+
L[best_j] = l + nmax;
642+
GGML_ASSERT(L[best_j] >= 0 && L[best_j] <= 2*nmax-1);
643+
}
644+
else {
645+
break;
646+
}
647+
591648
}
592649
return scale;
593650
}
@@ -893,8 +950,9 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
893950
float rmin, float rdelta, int nstep, bool use_mad) {
894951
float min = x[0];
895952
float max = x[0];
896-
float sum_w = weights ? weights[0] : x[0]*x[0];
897-
float sum_x = sum_w * x[0];
953+
double sum_w = weights ? (double)weights[0] : (double)(x[0]*x[0]);
954+
double sum_x = sum_w * (double)x[0];
955+
double sum_x2 = sum_w * (double)x[0] * (double)x[0];
898956
#ifdef HAVE_BUGGY_APPLE_LINKER
899957
// use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
900958
for (volatile int i = 1; i < n; ++i) {
@@ -904,8 +962,9 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
904962
if (x[i] < min) min = x[i];
905963
if (x[i] > max) max = x[i];
906964
float w = weights ? weights[i] : x[i]*x[i];
907-
sum_w += w;
908-
sum_x += w * x[i];
965+
sum_w += (double)w;
966+
sum_x += (double)w * (double)x[i];
967+
sum_x2 += (double)w * (double)x[i] * (double)x[i];
909968
}
910969
if (min > 0) {
911970
min = 0;
@@ -917,13 +976,13 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
917976
}
918977
float iscale = nmax/(max - min);
919978
float scale = 1/iscale;
920-
float best_mad = 0;
979+
double best_mad = 0;
921980
for (int i = 0; i < n; ++i) {
922981
int l = nearest_int(iscale*(x[i] - min));
923982
L[i] = MAX(0, MIN(nmax, l));
924-
float diff = scale * L[i] + min - x[i];
925-
diff = use_mad ? fabsf(diff) : diff*diff;
926-
float w = weights ? weights[i] : x[i]*x[i];
983+
double diff = (double)scale * L[i] + (double)min - (double)x[i];
984+
diff = use_mad ? fabs(diff) : diff*diff;
985+
double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
927986
best_mad += w * diff;
928987
}
929988
if (nstep < 1) {
@@ -932,30 +991,35 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
932991
}
933992
for (int is = 0; is <= nstep; ++is) {
934993
iscale = (rmin + rdelta*is + nmax)/(max - min);
935-
float sum_l = 0, sum_l2 = 0, sum_xl = 0;
994+
double sum_l = 0, sum_l2 = 0, sum_xl = 0;
936995
for (int i = 0; i < n; ++i) {
937996
int l = nearest_int(iscale*(x[i] - min));
938997
l = MAX(0, MIN(nmax, l));
939998
Laux[i] = l;
940999
float w = weights ? weights[i] : x[i]*x[i];
941-
sum_l += w*l;
942-
sum_l2 += w*l*l;
943-
sum_xl += w*l*x[i];
1000+
sum_l += (double)w*l;
1001+
sum_l2 += (double)w*l*l;
1002+
sum_xl += (double)w*l*(double)x[i];
9441003
}
945-
float D = sum_w * sum_l2 - sum_l * sum_l;
1004+
double D = sum_w * sum_l2 - sum_l * sum_l;
9461005
if (D > 0) {
947-
float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
948-
float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
1006+
double this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
1007+
double this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
9491008
if (this_min > 0) {
9501009
this_min = 0;
9511010
this_scale = sum_xl / sum_l2;
9521011
}
953-
float mad = 0;
954-
for (int i = 0; i < n; ++i) {
955-
float diff = this_scale * Laux[i] + this_min - x[i];
956-
diff = use_mad ? fabsf(diff) : diff*diff;
957-
float w = weights ? weights[i] : x[i]*x[i];
958-
mad += w * diff;
1012+
double mad = 0;
1013+
if (use_mad) {
1014+
for (int i = 0; i < n; ++i) {
1015+
double diff = (double)this_scale * Laux[i] + (double)this_min - (double)x[i];
1016+
diff = fabs(diff);
1017+
double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
1018+
mad += w * diff;
1019+
}
1020+
} else {
1021+
mad = sum_x2 - 2*this_scale*sum_xl - 2*this_min*sum_x + 2*this_scale*this_min*sum_l
1022+
+ this_scale*this_scale*sum_l2 + this_min*this_min*sum_w;
9591023
}
9601024
if (mad < best_mad) {
9611025
for (int i = 0; i < n; ++i) {
@@ -967,6 +1031,57 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
9671031
}
9681032
}
9691033
}
1034+
if (use_mad) {
1035+
*the_min = -min;
1036+
return scale;
1037+
}
1038+
1039+
double sum_l = 0, sum_l2 = 0, sum_xl = 0;
1040+
for (int i = 0; i < n; ++i) {
1041+
int l = L[i];
1042+
double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
1043+
sum_l += w*l;
1044+
sum_l2 += w*l*l;
1045+
sum_xl += w*l*(double)x[i];
1046+
}
1047+
double best = 2*(double)scale*sum_xl + 2*(double)min*sum_x - 2*(double)scale*(double)min*sum_l
1048+
- (double)scale*(double)scale*sum_l2 - (double)min*(double)min*sum_w;
1049+
int last_j = -1, last_dir = 0;
1050+
for (int itry = 0; itry < nmax*n; ++itry) {
1051+
float gmax = 0;
1052+
int best_j = -1, dir = 0;
1053+
for (int j = 0; j < n; ++j) {
1054+
float g = x[j] - scale*L[j] - min;
1055+
if (g > 0 && L[j] < nmax && g > gmax) {
1056+
gmax = g; best_j = j; dir = 1;
1057+
}
1058+
else if (g < 0 && L[j] > 0 && -g > gmax) {
1059+
gmax = -g; best_j = j; dir = -1;
1060+
}
1061+
}
1062+
if (best_j < 0 || (best_j == last_j && dir == -last_dir)) break;
1063+
double w = weights ? (double)weights[best_j] : (double)(x[best_j]*x[best_j]);
1064+
sum_l += w*dir;
1065+
sum_l2 += w*(2*L[best_j]*dir + 1);
1066+
sum_xl += w*(double)x[best_j]*dir;
1067+
double D = (double)sum_w * sum_l2 - sum_l * sum_l;
1068+
if (D <= 0) break;
1069+
double this_scale = ((double)sum_w * sum_xl - (double)sum_x * sum_l)/D;
1070+
double this_min = (sum_l2 * (double)sum_x - sum_l * sum_xl)/D;
1071+
if (this_min > 0) {
1072+
this_min = 0;
1073+
this_scale = sum_xl / sum_l2;
1074+
}
1075+
if (this_scale < 0) break;
1076+
double score = 2*this_scale*sum_xl + 2*this_min*(double)sum_x - 2*this_scale*this_min*sum_l
1077+
- this_scale*this_scale*sum_l2 - this_min*this_min*(double)sum_w;
1078+
if (score <= best) break;
1079+
best = score;
1080+
scale = this_scale;
1081+
min = this_min;
1082+
L[best_j] += dir;
1083+
last_j = best_j; last_dir = dir;
1084+
}
9701085
*the_min = -min;
9711086
return scale;
9721087
}
@@ -1048,7 +1163,7 @@ static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * G
10481163
GGML_ASSERT(quant_weights);
10491164
assert(k % QK_K == 0);
10501165
const int nb = k / QK_K;
1051-
const bool requantize = true;
1166+
// const bool requantize = true;
10521167

10531168
uint8_t L[QK_K];
10541169
uint8_t Laux[16];
@@ -1062,39 +1177,33 @@ static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * G
10621177
memset(sw, 0, QK_K/16*sizeof(float));
10631178
float sumx2 = 0;
10641179
for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
1065-
float sigma2 = sumx2/QK_K;
1180+
float sigma2 = 0.75f*sumx2/QK_K;
10661181
for (int j = 0; j < QK_K/16; ++j) {
10671182
const float * GGML_RESTRICT qw = quant_weights + QK_K * i + 16*j;
10681183
for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
10691184
for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
10701185
scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
10711186
}
10721187

1073-
float dm, mm;
1074-
dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
1075-
mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
1188+
float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
1189+
float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
10761190

10771191
y[i].d = GGML_FP32_TO_FP16(dm);
10781192
y[i].dmin = GGML_FP32_TO_FP16(mm);
1079-
dm = GGML_FP16_TO_FP32(y[i].d);
1080-
mm = GGML_FP16_TO_FP32(y[i].dmin);
10811193

10821194
for (int j = 0; j < QK_K/16; ++j) {
1083-
y[i].scales[j] = Ls[j] | (Lm[j] << 4);
1084-
}
1085-
1086-
if (requantize) {
1087-
for (int j = 0; j < QK_K/16; ++j) {
1088-
const float d = dm * (y[i].scales[j] & 0xF);
1089-
if (!d) continue;
1090-
const float m = mm * (y[i].scales[j] >> 4);
1091-
for (int ii = 0; ii < 16; ++ii) {
1092-
int l = nearest_int((x[16*j + ii] + m)/d);
1093-
l = MAX(0, MIN(3, l));
1094-
L[16*j + ii] = l;
1095-
}
1195+
float d = dm*Ls[j];
1196+
float m = mm*Lm[j];
1197+
float id = d ? 1/d : 0.f;
1198+
for (int l = 0; l < QK_K/16; ++l) {
1199+
int q = nearest_int((x[16*j + l] + m)*id);
1200+
q = MAX(0, MIN(3, q));
1201+
L[16*j + l] = q;
10961202
}
10971203
}
1204+
for (int j = 0; j < QK_K/16; ++j) {
1205+
y[i].scales[j] = Ls[j] | (Lm[j] << 4);
1206+
}
10981207

10991208
for (int j = 0; j < QK_K; j += 128) {
11001209
for (int l = 0; l < 32; ++l) {
@@ -1983,8 +2092,12 @@ static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * G
19832092
const int64_t nb = n_per_row/QK4_0;
19842093
for (int ib = 0; ib < nb; ++ib) {
19852094
const float * xb = x + QK4_0 * ib;
1986-
const float * qw = quant_weights + QK4_0 * ib;
1987-
for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
2095+
if (quant_weights) {
2096+
const float * qw = quant_weights + QK4_0 * ib;
2097+
for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
2098+
} else {
2099+
for (int j = 0; j < QK4_0; ++j) weight[j] = xb[j]*xb[j];
2100+
}
19882101
float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);
19892102
y[ib].d = GGML_FP32_TO_FP16(d);
19902103
for (int j = 0; j < 16; ++j) {
@@ -4881,6 +4994,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
48814994
}
48824995
d = sumqx/sumq2;
48834996
float best = d*sumqx;
4997+
float best_sumqx = sumqx, best_sumq2 = sumq2;
48844998
for (int itry = -ntry; itry <= ntry; ++itry) {
48854999
id = (itry + values[0])/max;
48865000
sumqx = sumq2 = 0;
@@ -4894,8 +5008,68 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
48945008
}
48955009
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
48965010
d = sumqx/sumq2; best = d * sumqx;
5011+
best_sumqx = sumqx; best_sumq2 = sumq2;
5012+
for (int j = 0; j < block_size; ++j) {
5013+
float al = id*xb[j];
5014+
Lb[j] = best_index_iq4nl(values, al);
5015+
}
5016+
}
5017+
id = (itry + values[15])/max;
5018+
sumqx = sumq2 = 0;
5019+
for (int j = 0; j < block_size; ++j) {
5020+
float al = id*xb[j];
5021+
int l = best_index_iq4nl(values, al);
5022+
float q = values[l];
5023+
float w = weight[j];
5024+
sumqx += w*q*xb[j];
5025+
sumq2 += w*q*q;
5026+
}
5027+
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
5028+
d = sumqx/sumq2; best = d * sumqx;
5029+
best_sumqx = sumqx; best_sumq2 = sumq2;
5030+
for (int j = 0; j < block_size; ++j) {
5031+
float al = id*xb[j];
5032+
Lb[j] = best_index_iq4nl(values, al);
5033+
}
5034+
}
5035+
}
5036+
sumqx = best_sumqx; sumq2 = best_sumq2;
5037+
best_sumqx = sumqx; best_sumq2 = sumq2;
5038+
for (int iter = 0; iter < 32*block_size; ++iter) {
5039+
float min_step = INFINITY;
5040+
int best_j = -1; int dir = 0;
5041+
for (int j = 0; j < block_size; ++j) {
5042+
float w = weight[j];
5043+
float g = d * w * (xb[j] - d*values[Lb[j]]);
5044+
if (g > 0 && Lb[j] < 15) {
5045+
float step = (values[Lb[j]+1] - values[Lb[j]])/g;
5046+
if (step < min_step) {
5047+
min_step = step; best_j = j; dir = 1;
5048+
}
5049+
}
5050+
else if (g < 0 && Lb[j] > 0) {
5051+
float step = (values[Lb[j]-1] - values[Lb[j]])/g;
5052+
if (step < min_step) {
5053+
min_step = step; best_j = j; dir = -1;
5054+
}
5055+
}
5056+
}
5057+
if (best_j < 0) break;
5058+
5059+
float new_sumqx = sumqx, new_sumq2 = sumq2;
5060+
float w = weight[best_j];
5061+
new_sumqx += w*xb[best_j]*(values[Lb[best_j]+dir] - values[Lb[best_j]]);
5062+
new_sumq2 += w*(values[Lb[best_j]+dir]*values[Lb[best_j]+dir] - values[Lb[best_j]]*values[Lb[best_j]]);
5063+
if (new_sumq2 > 0 && new_sumqx*new_sumqx > best*new_sumq2) {
5064+
sumqx = new_sumqx; sumq2 = new_sumq2;
5065+
d = sumqx/sumq2; best = d*sumqx;
5066+
Lb[best_j] += dir;
5067+
}
5068+
else {
5069+
break;
48975070
}
48985071
}
5072+
48995073
scales[ib] = d;
49005074
float abs_d = fabsf(d);
49015075
if (abs_d > amax_scale) {

0 commit comments

Comments
 (0)