Skip to content

Commit c29bd84

Browse files
Nexesenex and ikawrakow committed
Quantization improvements #295 and #302, GGML part only
Co-Authored-By: Kawrakow <iwankawrakow@gmail.com>
1 parent 4c9649a commit c29bd84

File tree

1 file changed

+220
-46
lines changed

1 file changed

+220
-46
lines changed

ggml/src/ggml-quants.c

Lines changed: 220 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -564,10 +564,8 @@ static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8
564564
float scale = suml2 ? sumlx/suml2 : 0.0f;
565565
if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale;
566566
float best = scale * sumlx;
567+
float best_sumlx = sumlx, best_suml2 = suml2;
567568
for (int is = -9; is <= 9; ++is) {
568-
if (is == 0) {
569-
continue;
570-
}
571569
iscale = -(nmax + 0.1f*is) / max;
572570
sumlx = suml2 = 0;
573571
for (int i = 0; i < n; ++i) {
@@ -583,7 +581,66 @@ static float make_qx_quants(int n, int nmax, const float * GGML_RESTRICT x, int8
583581
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
584582
}
585583
scale = sumlx/suml2; best = scale*sumlx;
584+
best_sumlx = sumlx; best_suml2 = suml2;
585+
}
586+
iscale = (nmax-1 + 0.1f*is) / max;
587+
sumlx = suml2 = 0;
588+
for (int i = 0; i < n; ++i) {
589+
int l = nearest_int(iscale * x[i]);
590+
l = MAX(-nmax, MIN(nmax-1, l));
591+
float w = qw ? qw[i] : rmse_type == 1 ? x[i] * x[i] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[i]) : sqrtf(fabsf(x[i]));
592+
sumlx += w*x[i]*l;
593+
suml2 += w*l*l;
586594
}
595+
if (suml2 > 0 && sumlx*sumlx > best*suml2) {
596+
for (int i = 0; i < n; ++i) {
597+
int l = nearest_int(iscale * x[i]);
598+
L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
599+
}
600+
scale = sumlx/suml2; best = scale*sumlx;
601+
best_sumlx = sumlx; best_suml2 = suml2;
602+
}
603+
}
604+
605+
sumlx = best_sumlx; suml2 = best_suml2;
606+
for (int iter = 0; iter < n*(2*nmax-1); ++iter) {
607+
float abs_gmax = 0, gmax = 0;
608+
int best_j = -1;
609+
for (int j = 0; j < n; ++j) {
610+
float w = qw ? qw[j] : rmse_type == 1 ? x[j] * x[j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[j]) : sqrtf(fabsf(x[j]));
611+
int l = L[j] - nmax;
612+
float g = scale * w * (x[j] - scale*l);
613+
if ((g > 0 && l < nmax-1) || (g < 0 && l > -nmax)) {
614+
float ag = fabsf(g);
615+
if (ag > abs_gmax) {
616+
abs_gmax = ag; gmax = g; best_j = j;
617+
}
618+
}
619+
}
620+
if (best_j < 0) break;
621+
622+
float new_sumlx = sumlx, new_suml2 = suml2;
623+
float w = qw ? qw[best_j] : rmse_type == 1 ? x[best_j] * x[best_j] : rmse_type == 2 ? 1 : rmse_type == 3 ? fabsf(x[best_j]) : sqrtf(fabsf(x[best_j]));
624+
int l = L[best_j] - nmax;
625+
if (gmax > 0) {
626+
new_sumlx += w*x[best_j];
627+
new_suml2 += w*(2*l + 1);
628+
l += 1;
629+
} else {
630+
new_sumlx -= w*x[best_j];
631+
new_suml2 -= w*(2*l - 1);
632+
l -= 1;
633+
}
634+
if (new_suml2 > 0 && new_sumlx*new_sumlx > best*new_suml2) {
635+
sumlx = new_sumlx; suml2 = new_suml2;
636+
scale = sumlx/suml2; best = scale*sumlx;
637+
L[best_j] = l + nmax;
638+
GGML_ASSERT(L[best_j] >= 0 && L[best_j] <= 2*nmax-1);
639+
}
640+
else {
641+
break;
642+
}
643+
587644
}
588645
return scale;
589646
}
@@ -889,8 +946,9 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
889946
float rmin, float rdelta, int nstep, bool use_mad) {
890947
float min = x[0];
891948
float max = x[0];
892-
float sum_w = weights ? weights[0] : x[0]*x[0];
893-
float sum_x = sum_w * x[0];
949+
double sum_w = weights ? (double)weights[0] : (double)(x[0]*x[0]);
950+
double sum_x = sum_w * (double)x[0];
951+
double sum_x2 = sum_w * (double)x[0] * (double)x[0];
894952
#ifdef HAVE_BUGGY_APPLE_LINKER
895953
// use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
896954
for (volatile int i = 1; i < n; ++i) {
@@ -900,8 +958,9 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
900958
if (x[i] < min) min = x[i];
901959
if (x[i] > max) max = x[i];
902960
float w = weights ? weights[i] : x[i]*x[i];
903-
sum_w += w;
904-
sum_x += w * x[i];
961+
sum_w += (double)w;
962+
sum_x += (double)w * (double)x[i];
963+
sum_x2 += (double)w * (double)x[i] * (double)x[i];
905964
}
906965
if (min > 0) {
907966
min = 0;
@@ -913,13 +972,13 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
913972
}
914973
float iscale = nmax/(max - min);
915974
float scale = 1/iscale;
916-
float best_mad = 0;
975+
double best_mad = 0;
917976
for (int i = 0; i < n; ++i) {
918977
int l = nearest_int(iscale*(x[i] - min));
919978
L[i] = MAX(0, MIN(nmax, l));
920-
float diff = scale * L[i] + min - x[i];
921-
diff = use_mad ? fabsf(diff) : diff*diff;
922-
float w = weights ? weights[i] : x[i]*x[i];
979+
double diff = (double)scale * L[i] + (double)min - (double)x[i];
980+
diff = use_mad ? fabs(diff) : diff*diff;
981+
double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
923982
best_mad += w * diff;
924983
}
925984
if (nstep < 1) {
@@ -928,30 +987,35 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
928987
}
929988
for (int is = 0; is <= nstep; ++is) {
930989
iscale = (rmin + rdelta*is + nmax)/(max - min);
931-
float sum_l = 0, sum_l2 = 0, sum_xl = 0;
990+
double sum_l = 0, sum_l2 = 0, sum_xl = 0;
932991
for (int i = 0; i < n; ++i) {
933992
int l = nearest_int(iscale*(x[i] - min));
934993
l = MAX(0, MIN(nmax, l));
935994
Laux[i] = l;
936995
float w = weights ? weights[i] : x[i]*x[i];
937-
sum_l += w*l;
938-
sum_l2 += w*l*l;
939-
sum_xl += w*l*x[i];
996+
sum_l += (double)w*l;
997+
sum_l2 += (double)w*l*l;
998+
sum_xl += (double)w*l*(double)x[i];
940999
}
941-
float D = sum_w * sum_l2 - sum_l * sum_l;
1000+
double D = sum_w * sum_l2 - sum_l * sum_l;
9421001
if (D > 0) {
943-
float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
944-
float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
1002+
double this_scale = (sum_w * sum_xl - sum_x * sum_l)/D;
1003+
double this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D;
9451004
if (this_min > 0) {
9461005
this_min = 0;
9471006
this_scale = sum_xl / sum_l2;
9481007
}
949-
float mad = 0;
950-
for (int i = 0; i < n; ++i) {
951-
float diff = this_scale * Laux[i] + this_min - x[i];
952-
diff = use_mad ? fabsf(diff) : diff*diff;
953-
float w = weights ? weights[i] : x[i]*x[i];
954-
mad += w * diff;
1008+
double mad = 0;
1009+
if (use_mad) {
1010+
for (int i = 0; i < n; ++i) {
1011+
double diff = (double)this_scale * Laux[i] + (double)this_min - (double)x[i];
1012+
diff = fabs(diff);
1013+
double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
1014+
mad += w * diff;
1015+
}
1016+
} else {
1017+
mad = sum_x2 - 2*this_scale*sum_xl - 2*this_min*sum_x + 2*this_scale*this_min*sum_l
1018+
+ this_scale*this_scale*sum_l2 + this_min*this_min*sum_w;
9551019
}
9561020
if (mad < best_mad) {
9571021
for (int i = 0; i < n; ++i) {
@@ -963,6 +1027,57 @@ static float make_qkx3_quants(int n, int nmax, const float * GGML_RESTRICT x, co
9631027
}
9641028
}
9651029
}
1030+
if (use_mad) {
1031+
*the_min = -min;
1032+
return scale;
1033+
}
1034+
1035+
double sum_l = 0, sum_l2 = 0, sum_xl = 0;
1036+
for (int i = 0; i < n; ++i) {
1037+
int l = L[i];
1038+
double w = weights ? (double)weights[i] : (double)(x[i]*x[i]);
1039+
sum_l += w*l;
1040+
sum_l2 += w*l*l;
1041+
sum_xl += w*l*(double)x[i];
1042+
}
1043+
double best = 2*(double)scale*sum_xl + 2*(double)min*sum_x - 2*(double)scale*(double)min*sum_l
1044+
- (double)scale*(double)scale*sum_l2 - (double)min*(double)min*sum_w;
1045+
int last_j = -1, last_dir = 0;
1046+
for (int itry = 0; itry < nmax*n; ++itry) {
1047+
float gmax = 0;
1048+
int best_j = -1, dir = 0;
1049+
for (int j = 0; j < n; ++j) {
1050+
float g = x[j] - scale*L[j] - min;
1051+
if (g > 0 && L[j] < nmax && g > gmax) {
1052+
gmax = g; best_j = j; dir = 1;
1053+
}
1054+
else if (g < 0 && L[j] > 0 && -g > gmax) {
1055+
gmax = -g; best_j = j; dir = -1;
1056+
}
1057+
}
1058+
if (best_j < 0 || (best_j == last_j && dir == -last_dir)) break;
1059+
double w = weights ? (double)weights[best_j] : (double)(x[best_j]*x[best_j]);
1060+
sum_l += w*dir;
1061+
sum_l2 += w*(2*L[best_j]*dir + 1);
1062+
sum_xl += w*(double)x[best_j]*dir;
1063+
double D = (double)sum_w * sum_l2 - sum_l * sum_l;
1064+
if (D <= 0) break;
1065+
double this_scale = ((double)sum_w * sum_xl - (double)sum_x * sum_l)/D;
1066+
double this_min = (sum_l2 * (double)sum_x - sum_l * sum_xl)/D;
1067+
if (this_min > 0) {
1068+
this_min = 0;
1069+
this_scale = sum_xl / sum_l2;
1070+
}
1071+
if (this_scale < 0) break;
1072+
double score = 2*this_scale*sum_xl + 2*this_min*(double)sum_x - 2*this_scale*this_min*sum_l
1073+
- this_scale*this_scale*sum_l2 - this_min*this_min*(double)sum_w;
1074+
if (score <= best) break;
1075+
best = score;
1076+
scale = this_scale;
1077+
min = this_min;
1078+
L[best_j] += dir;
1079+
last_j = best_j; last_dir = dir;
1080+
}
9661081
*the_min = -min;
9671082
return scale;
9681083
}
@@ -1044,7 +1159,7 @@ static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * G
10441159
GGML_ASSERT(quant_weights);
10451160
assert(k % QK_K == 0);
10461161
const int nb = k / QK_K;
1047-
const bool requantize = true;
1162+
// const bool requantize = true;
10481163

10491164
uint8_t L[QK_K];
10501165
uint8_t Laux[16];
@@ -1058,39 +1173,33 @@ static void quantize_row_q2_K_impl(const float * GGML_RESTRICT x, block_q2_K * G
10581173
memset(sw, 0, QK_K/16*sizeof(float));
10591174
float sumx2 = 0;
10601175
for (int j = 0; j < QK_K; ++j) sumx2 += x[j]*x[j];
1061-
float sigma2 = sumx2/QK_K;
1176+
float sigma2 = 0.75f*sumx2/QK_K;
10621177
for (int j = 0; j < QK_K/16; ++j) {
10631178
const float * GGML_RESTRICT qw = quant_weights + QK_K * i + 16*j;
10641179
for (int l = 0; l < 16; ++l) weight[l] = qw[l] * sqrtf(sigma2 + x[16*j + l]*x[16*j + l]);
10651180
for (int l = 0; l < QK_K/16; ++l) sw[j] += weight[l];
10661181
scales[j] = make_qkx3_quants(16, 3, x + 16*j, weight, L + 16*j, &mins[j], Laux, -0.9f, 0.05f, 36, false);
10671182
}
10681183

1069-
float dm, mm;
1070-
dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
1071-
mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
1184+
float dm = make_qp_quants(QK_K/16, 15, scales, Ls, sw);
1185+
float mm = make_qp_quants(QK_K/16, 15, mins, Lm, sw);
10721186

10731187
y[i].d = GGML_FP32_TO_FP16(dm);
10741188
y[i].dmin = GGML_FP32_TO_FP16(mm);
1075-
dm = GGML_FP16_TO_FP32(y[i].d);
1076-
mm = GGML_FP16_TO_FP32(y[i].dmin);
10771189

10781190
for (int j = 0; j < QK_K/16; ++j) {
1079-
y[i].scales[j] = Ls[j] | (Lm[j] << 4);
1080-
}
1081-
1082-
if (requantize) {
1083-
for (int j = 0; j < QK_K/16; ++j) {
1084-
const float d = dm * (y[i].scales[j] & 0xF);
1085-
if (!d) continue;
1086-
const float m = mm * (y[i].scales[j] >> 4);
1087-
for (int ii = 0; ii < 16; ++ii) {
1088-
int l = nearest_int((x[16*j + ii] + m)/d);
1089-
l = MAX(0, MIN(3, l));
1090-
L[16*j + ii] = l;
1091-
}
1191+
float d = dm*Ls[j];
1192+
float m = mm*Lm[j];
1193+
float id = d ? 1/d : 0.f;
1194+
for (int l = 0; l < QK_K/16; ++l) {
1195+
int q = nearest_int((x[16*j + l] + m)*id);
1196+
q = MAX(0, MIN(3, q));
1197+
L[16*j + l] = q;
10921198
}
10931199
}
1200+
for (int j = 0; j < QK_K/16; ++j) {
1201+
y[i].scales[j] = Ls[j] | (Lm[j] << 4);
1202+
}
10941203

10951204
for (int j = 0; j < QK_K; j += 128) {
10961205
for (int l = 0; l < 32; ++l) {
@@ -1979,8 +2088,12 @@ static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * G
19792088
const int64_t nb = n_per_row/QK4_0;
19802089
for (int ib = 0; ib < nb; ++ib) {
19812090
const float * xb = x + QK4_0 * ib;
1982-
const float * qw = quant_weights + QK4_0 * ib;
1983-
for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
2091+
if (quant_weights) {
2092+
const float * qw = quant_weights + QK4_0 * ib;
2093+
for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]);
2094+
} else {
2095+
for (int j = 0; j < QK4_0; ++j) weight[j] = xb[j]*xb[j];
2096+
}
19842097
float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight);
19852098
y[ib].d = GGML_FP32_TO_FP16(d);
19862099
for (int j = 0; j < 16; ++j) {
@@ -4878,6 +4991,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
48784991
}
48794992
d = sumqx/sumq2;
48804993
float best = d*sumqx;
4994+
float best_sumqx = sumqx, best_sumq2 = sumq2;
48814995
for (int itry = -ntry; itry <= ntry; ++itry) {
48824996
id = (itry + values[0])/max;
48834997
sumqx = sumq2 = 0;
@@ -4891,8 +5005,68 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block
48915005
}
48925006
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
48935007
d = sumqx/sumq2; best = d * sumqx;
5008+
best_sumqx = sumqx; best_sumq2 = sumq2;
5009+
for (int j = 0; j < block_size; ++j) {
5010+
float al = id*xb[j];
5011+
Lb[j] = best_index_iq4nl(values, al);
5012+
}
5013+
}
5014+
id = (itry + values[15])/max;
5015+
sumqx = sumq2 = 0;
5016+
for (int j = 0; j < block_size; ++j) {
5017+
float al = id*xb[j];
5018+
int l = best_index_iq4nl(values, al);
5019+
float q = values[l];
5020+
float w = weight[j];
5021+
sumqx += w*q*xb[j];
5022+
sumq2 += w*q*q;
5023+
}
5024+
if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
5025+
d = sumqx/sumq2; best = d * sumqx;
5026+
best_sumqx = sumqx; best_sumq2 = sumq2;
5027+
for (int j = 0; j < block_size; ++j) {
5028+
float al = id*xb[j];
5029+
Lb[j] = best_index_iq4nl(values, al);
5030+
}
5031+
}
5032+
}
5033+
sumqx = best_sumqx; sumq2 = best_sumq2;
5034+
best_sumqx = sumqx; best_sumq2 = sumq2;
5035+
for (int iter = 0; iter < 32*block_size; ++iter) {
5036+
float min_step = INFINITY;
5037+
int best_j = -1; int dir = 0;
5038+
for (int j = 0; j < block_size; ++j) {
5039+
float w = weight[j];
5040+
float g = d * w * (xb[j] - d*values[Lb[j]]);
5041+
if (g > 0 && Lb[j] < 15) {
5042+
float step = (values[Lb[j]+1] - values[Lb[j]])/g;
5043+
if (step < min_step) {
5044+
min_step = step; best_j = j; dir = 1;
5045+
}
5046+
}
5047+
else if (g < 0 && Lb[j] > 0) {
5048+
float step = (values[Lb[j]-1] - values[Lb[j]])/g;
5049+
if (step < min_step) {
5050+
min_step = step; best_j = j; dir = -1;
5051+
}
5052+
}
5053+
}
5054+
if (best_j < 0) break;
5055+
5056+
float new_sumqx = sumqx, new_sumq2 = sumq2;
5057+
float w = weight[best_j];
5058+
new_sumqx += w*xb[best_j]*(values[Lb[best_j]+dir] - values[Lb[best_j]]);
5059+
new_sumq2 += w*(values[Lb[best_j]+dir]*values[Lb[best_j]+dir] - values[Lb[best_j]]*values[Lb[best_j]]);
5060+
if (new_sumq2 > 0 && new_sumqx*new_sumqx > best*new_sumq2) {
5061+
sumqx = new_sumqx; sumq2 = new_sumq2;
5062+
d = sumqx/sumq2; best = d*sumqx;
5063+
Lb[best_j] += dir;
5064+
}
5065+
else {
5066+
break;
48945067
}
48955068
}
5069+
48965070
scales[ib] = d;
48975071
float abs_d = fabsf(d);
48985072
if (abs_d > amax_scale) {

0 commit comments

Comments
 (0)