Skip to content

Commit

Permalink
Deduplicate q4 quantization functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sw committed Mar 22, 2023
1 parent ae44e23 commit b4dfdf7
Showing 1 changed file with 63 additions and 104 deletions.
167 changes: 63 additions & 104 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -400,16 +400,63 @@ static inline __m128i packNibbles( __m256i bytes )
// method 5
// blocks of QK elements
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)

// reference implementation for deterministic creation of model files
static void quantize_row_q4_0_reference(const float * restrict x, void * restrict y, int k) {
assert(k % QK == 0);
const int nb = k / QK;

const size_t bs = sizeof(float) + QK/2;

uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
uint8_t * restrict pb = ((uint8_t *)y + 0*bs + sizeof(float));

uint8_t pp[QK/2];

for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max

for (int l = 0; l < QK; l++) {
const float v = x[i*QK + l];
amax = MAX(amax, fabsf(v));
}

const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0f/d : 0.0f;

*(float *)pd = d;
pd += bs;

for (int l = 0; l < QK; l += 2) {
const float v0 = x[i*QK + l + 0]*id;
const float v1 = x[i*QK + l + 1]*id;

const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
const uint8_t vi1 = ((int8_t) (round(v1))) + 8;

assert(vi0 >= 0 && vi0 < 16);
assert(vi1 >= 0 && vi1 < 16);

pp[l/2] = vi0 | (vi1 << 4);
}

memcpy(pb, pp, sizeof(pp));
pb += bs;
}
}

void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
assert(k % QK == 0);

#if __ARM_NEON || defined(__AVX2__) || defined(__wasm_simd128__)
const int nb = k / QK;
const size_t bs = sizeof(float) + QK/2;

uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
uint8_t * restrict pb = ((uint8_t *)y + 0*bs + sizeof(float));

uint8_t pp[QK/2];
#endif

#if __ARM_NEON
#if QK == 32
Expand Down Expand Up @@ -566,36 +613,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
#endif
#else
// scalar
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max

for (int l = 0; l < QK; l++) {
const float v = x[i*QK + l];
amax = MAX(amax, fabsf(v));
}

const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0f/d : 0.0f;

*(float *)pd = d;
pd += bs;

for (int l = 0; l < QK; l += 2) {
const float v0 = x[i*QK + l + 0]*id;
const float v1 = x[i*QK + l + 1]*id;

const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
const uint8_t vi1 = ((int8_t) (round(v1))) + 8;

assert(vi0 >= 0 && vi0 < 16);
assert(vi1 >= 0 && vi1 < 16);

pp[l/2] = vi0 | (vi1 << 4);
}

memcpy(pb, pp, sizeof(pp));
pb += bs;
}
quantize_row_q4_0_reference(x, y, k);
#endif
}

Expand Down Expand Up @@ -10709,49 +10727,23 @@ size_t ggml_quantize_q4_0(float * src, void * dst, int n, int k, int qk, int64_t

assert(k % qk == 0);

const size_t pp_size = qk / 2;
uint8_t * pp = (uint8_t *) alloca(pp_size);

char * pdst = (char *) dst;

for (int j = 0; j < n; j += k) {
uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));

for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max

{
for (int l = 0; l < qk; l++) {
const float v = src[j + i*qk + l];
amax = MAX(amax, fabsf(v));
}

const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0f/d : 0.0f;

*(float *) pd = d;
pd += bs;
quantize_row_q4_0_reference(src + j, pd, k);

for (int l = 0; l < qk; l += 2) {
const float v0 = (src[j + i*qk + l + 0])*id;
const float v1 = (src[j + i*qk + l + 1])*id;

const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
const uint8_t vi1 = ((int8_t) (round(v1))) + 8;

assert(vi0 >= 0 && vi0 < 16);
assert(vi1 >= 0 && vi1 < 16);

hist[vi0]++;
hist[vi1]++;

pp[l/2] = vi0 | (vi1 << 4);
}
for (int i = 0; i < nb; i++) {
for (int l = 0; l < qk; l += 2) {
const uint8_t vi0 = pb[l/2] & 0xF;
const uint8_t vi1 = pb[l/2] >> 4;

memcpy(pb, pp, pp_size);
pb += bs;
hist[vi0]++;
hist[vi1]++;
}
pb += bs;
}
}

Expand All @@ -10765,56 +10757,23 @@ size_t ggml_quantize_q4_1(float * src, void * dst, int n, int k, int qk, int64_t

assert(k % qk == 0);

const size_t pp_size = qk / 2;
uint8_t * pp = (uint8_t *) alloca(pp_size);

char * pdst = (char *) dst;

for (int j = 0; j < n; j += k) {
uint8_t * pd = (uint8_t *) (pdst + (j/k)*row_size + 0*bs);
uint8_t * pm = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + sizeof(float));
uint8_t * pb = (uint8_t *) (pdst + (j/k)*row_size + 0*bs + 2*sizeof(float));

//printf("n = %d, k = %d, nb = %d, row_size = %d, j = %d, pm = %p, pd = %p, pb = %p\n", n, k, nb, row_size, j, pm, pd, pb);
quantize_row_q4_1(src + j, pd, k);

for (int i = 0; i < nb; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;

{
for (int l = 0; l < qk; l++) {
const float v = src[j + i*qk + l];
if (v < min) min = v;
if (v > max) max = v;
}

const float d = (max - min) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;

*(float *) pd = d;
*(float *) pm = min;
pd += bs;
pm += bs;

for (int l = 0; l < qk; l += 2) {
const float v0 = (src[j + i*qk + l + 0] - min)*id;
const float v1 = (src[j + i*qk + l + 1] - min)*id;

const uint8_t vi0 = round(v0);
const uint8_t vi1 = round(v1);

assert(vi0 >= 0 && vi0 < 16);
assert(vi1 >= 0 && vi1 < 16);

hist[vi0]++;
hist[vi1]++;

pp[l/2] = vi0 | (vi1 << 4);
}
for (int l = 0; l < qk; l += 2) {
const uint8_t vi0 = pb[l/2] & 0xF;
const uint8_t vi1 = pb[l/2] >> 4;

memcpy(pb, pp, pp_size);
pb += bs;
hist[vi0]++;
hist[vi1]++;
}
pb += bs;
}
}

Expand Down

0 comments on commit b4dfdf7

Please sign in to comment.