Skip to content

Commit 76b97c8

Browse files
ikawrakow (Iwan Kawrakow) and co-authors authored
Adding IQ4_KSS: 4.0 bpw quants (#89)
* iq4_kss: WIP * iq4_kss: CUDA dequantize works So we can run perplexity. Sadly, the result does not look good on the bpw vs quantization error plot. * iq4_kss: slightly better quantization * iq4_kss: another small quantization improvement * iq4_kss: CUDA works TG-128 performance is very decent with 131 t/s for LLaMA-3.1-8B. In comparison, we have 123 t/s for q4_0 and 128 t/s for iq4_ks. I.e., the reduced model size more than offsets the additional bit fiddling required for iq4_kss. * iq4_kss: new bit arrangement - CUDA and Zen4 work Did not lose performance on CUDA. Zen4 is decent, but not great: PP-512(LLaMA-3.1-8B) = 163 t/s. TG-128 is of course better than other 4-bit quants due to smaller model size. We get 14.5 t/s @ 8 threads. * iq4_kss: ARM_NEON. Predictably very slow * iq4_kss: Metal PP is not too bad - just 10% slower than q4_0. But TG is 30% slower, i.e., predictably bad. * iq4_kss: somewhat faster Metal dot product 45.75 t/s -> 48.75 t/s. Still 22% slower than q4_0 * iq4_kss: AVX2 Bad, but better than I expected. PP-512(LLaMA-3.1-8B) = 167 t/s on the Ryzen-5950X. I.e., with 32 AVX2 threads we get the performance of 16 Zen4 threads. * iq4_kss: very slightly faster Metal dot product 48.7 t/s -> 49.3 t/s --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 993ca95 commit 76b97c8

File tree

19 files changed

+997
-25
lines changed

19 files changed

+997
-25
lines changed

examples/quantize-stats/quantize-stats.cpp

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo
256256
float mse0 = 0, mse = 0;
257257
auto compute = [&mutex, &counter, &mse0, &mse, values, row_size, nblock, nrows, n_per_row, chunk] () {
258258
std::vector<char> Q(row_size);
259+
float diff[4];
260+
float xv[4];
259261
float lmse0 = 0, lmse = 0;
260262
while (true) {
261263
std::unique_lock<std::mutex> lock(mutex);
@@ -282,25 +284,41 @@ static void analyze_iq4ks(const char * name, int nrows, int n_per_row, const flo
282284
for (int j = 0; j < 16; j += 2) {
283285
uint16_t v0 = *(const uint16_t *)(qs + j);
284286
int non = popcount(v0);
285-
float diff1 = xb[j+ 0] - dl*values[qs[j+0] & 0xf];
286-
float diff2 = xb[j+16] - dl*values[qs[j+0] >> 4];
287-
float diff3 = xb[j+ 1] - dl*values[qs[j+1] & 0xf];
288-
float diff4 = xb[j+17] - dl*values[qs[j+1] >> 4];
289-
lmse0 += diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
287+
xv[0] = xb[j+ 0]; xv[1] = xb[j+16]; xv[2] = xb[j+ 1]; xv[3] = xb[j+17];
288+
diff[0] = xv[0] - dl*values[qs[j+0] & 0xf];
289+
diff[1] = xv[1] - dl*values[qs[j+0] >> 4];
290+
diff[2] = xv[2] - dl*values[qs[j+1] & 0xf];
291+
diff[3] = xv[3] - dl*values[qs[j+1] >> 4];
292+
float diff4 = diff[0]*diff[0] + diff[1]*diff[1] + diff[2]*diff[2] + diff[3]*diff[3];
293+
lmse0 += diff4;
290294
if (non%2 == 0) {
291-
lmse += diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
295+
lmse += diff4;
292296
} else {
293297
float best = std::numeric_limits<float>::max();
294-
for (int k = 0; k < 16; k += 4) {
295-
uint16_t v = v0 ^ (1 << k);
296-
uint8_t v1 = v;
297-
uint8_t v2 = v >> 8;
298-
diff1 = xb[j+ 0] - dl*values[v1 & 0xf];
299-
diff2 = xb[j+16] - dl*values[v1 >> 4];
300-
diff3 = xb[j+ 1] - dl*values[v2 & 0xf];
301-
diff4 = xb[j+17] - dl*values[v2 >> 4];
302-
float score = diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
303-
if (score < best) best = score;
298+
//for (int k = 0; k < 16; k += 4) {
299+
// uint16_t v = v0 ^ (1 << k);
300+
// uint8_t v1 = v;
301+
// uint8_t v2 = v >> 8;
302+
// diff1 = xb[j+ 0] - dl*values[v1 & 0xf];
303+
// diff2 = xb[j+16] - dl*values[v1 >> 4];
304+
// diff3 = xb[j+ 1] - dl*values[v2 & 0xf];
305+
// diff4 = xb[j+17] - dl*values[v2 >> 4];
306+
// float score = diff1*diff1 + diff2*diff2 + diff3*diff3 + diff4*diff4;
307+
// if (score < best) best = score;
308+
//}
309+
for (int k = 0; k < 4; ++k) {
310+
uint16_t v = (v0 >> 4*k) & 0xf;
311+
auto pc = popcount(v);
312+
if (v > 0 && popcount(v-1u) != pc) {
313+
float this_diff = xv[k] - dl*values[v-1u];
314+
float score = diff4 - diff[k]*diff[k] + this_diff*this_diff;
315+
if (score < best) best = score;
316+
}
317+
if (v < 15 && popcount(v + 1u) != pc) {
318+
float this_diff = xv[k] - dl*values[v+1u];
319+
float score = diff4 - diff[k]*diff[k] + this_diff*this_diff;
320+
if (score < best) best = score;
321+
}
304322
}
305323
lmse += best;
306324
}

examples/quantize/quantize.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
4444
{ "IQ4_NL", LLAMA_FTYPE_MOSTLY_IQ4_NL, " 4.50 bpw non-linear quantization", },
4545
{ "IQ4_XS", LLAMA_FTYPE_MOSTLY_IQ4_XS, " 4.25 bpw non-linear quantization", },
4646
{ "IQ4_KS", LLAMA_FTYPE_MOSTLY_IQ4_KS, " 4.25 bpw non-linear quantization", },
47+
{ "IQ4_KSS", LLAMA_FTYPE_MOSTLY_IQ4_KSS, " 4.0 bpw non-linear quantization", },
4748
{ "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",},
4849
{ "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",},
4950
{ "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", },

ggml/include/ggml.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ extern "C" {
405405
GGML_TYPE_IQ1_TN = 143,
406406
GGML_TYPE_IQ4_KS = 144,
407407
GGML_TYPE_IQ2_KS = 145,
408+
GGML_TYPE_IQ4_KSS = 146,
408409
GGML_TYPE_COUNT,
409410
};
410411

@@ -462,6 +463,7 @@ extern "C" {
462463
GGML_FTYPE_MOSTLY_IQ1_TN = 136, // except 1d tensors
463464
GGML_FTYPE_MOSTLY_IQ4_KS = 137, // except 1d tensors
464465
GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors
466+
GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors
465467
};
466468

467469
// available tensor operations:

ggml/src/ggml-common.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,11 @@ typedef struct {
447447
} block_iq4_ks;
448448
static_assert(sizeof(block_iq4_ks) == QK_K/32 + QK_K/2, "wrong iq4_ks block size/padding");
449449

450+
typedef struct {
451+
uint32_t qs[QK_K/8];
452+
} block_iq4_kss;
453+
static_assert(sizeof(block_iq4_kss) == QK_K/8*sizeof(uint32_t), "wrong iq4_kss block size/padding");
454+
450455
typedef struct {
451456
ggml_half d;
452457
uint16_t extra;

ggml/src/ggml-cuda.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2829,6 +2829,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
28292829
case GGML_TYPE_IQ4_NL:
28302830
case GGML_TYPE_IQ4_XS:
28312831
case GGML_TYPE_IQ4_KS:
2832+
case GGML_TYPE_IQ4_KSS:
28322833
case GGML_TYPE_IQ2_K:
28332834
case GGML_TYPE_IQ2_KS:
28342835
case GGML_TYPE_IQ3_K:

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KS> {
543543
static constexpr int qi = QI4_XS;
544544
};
545545

546+
template<>
547+
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_KSS> {
548+
static constexpr int qk = QK_K;
549+
static constexpr int qr = QR4_XS;
550+
static constexpr int qi = QI4_XS;
551+
};
552+
546553
template<>
547554
struct ggml_cuda_type_traits<GGML_TYPE_IQ5_K> {
548555
static constexpr int qk = QK_K;

ggml/src/ggml-cuda/convert.cu

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,37 @@ static __global__ void dequantize_block_iq4_ks(const void * __restrict__ vx, dst
638638
}
639639
}
640640

641+
template<typename dst_t>
642+
static __global__ void dequantize_block_iq4_kss(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
643+
644+
int64_t ii = blockIdx.x;
645+
int64_t row = (QK_K * ii) / n_per_row;
646+
const char * cx = (const char *)vx + row * row_size;
647+
float scale = *(const float *)cx;
648+
const block_iq4_kss * x = (const block_iq4_kss *)(cx + sizeof(float));
649+
const int64_t i = ii - (row*n_per_row)/QK_K;
650+
651+
const int64_t tid = threadIdx.x;
652+
const int64_t il = tid/8; // 0...3
653+
const int64_t ib = tid%8; // 0...7
654+
dst_t * y = yy + ii*QK_K + 32*ib + 4*il;
655+
const uint32_t * q4 = x[i].qs + 4*ib;
656+
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
657+
uint8_t ls = (s32 | (s32 >> 15)) & 0xff;
658+
const float d = scale * ((ls & 254) - 127);
659+
const int8_t * values = iq4k_values + ((ls & 1) << 4);
660+
uint32_t aux32[2];
661+
aux32[0] = q4[il] & 0xfffefffe;
662+
aux32[0] ^= (aux32[0] >> 1);
663+
aux32[1] = ((aux32[0] >> 4) & 0x0f0f0f0f);
664+
aux32[0] &= 0x0f0f0f0f;
665+
const uint8_t * aux8 = (const uint8_t *)aux32;
666+
for (int j = 0; j < 4; ++j) {
667+
y[j+ 0] = d * values[aux8[j+0]];
668+
y[j+16] = d * values[aux8[j+4]];
669+
}
670+
}
671+
641672
template<typename dst_t>
642673
static __global__ void dequantize_block_iq4_k(const void * __restrict__ vx, dst_t * __restrict__ yy) {
643674
const int64_t i = blockIdx.x;
@@ -980,6 +1011,14 @@ static void dequantize_row_iq4_ks_cuda(const void * vx, dst_t * y, const int64_t
9801011
dequantize_block_iq4_ks<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
9811012
}
9821013

1014+
template<typename dst_t>
1015+
static void dequantize_row_iq4_kss_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
1016+
const int64_t k = nrows * n_per_row;
1017+
const int64_t row_size = ggml_row_size(GGML_TYPE_IQ4_KSS, n_per_row);
1018+
const int nb = (k + QK_K - 1) / QK_K;
1019+
dequantize_block_iq4_kss<<<nb, 32, 0, stream>>>(vx, y, n_per_row, row_size);
1020+
}
1021+
9831022
template<typename dst_t>
9841023
static void dequantize_row_iq2_ks_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
9851024
const int64_t k = nrows * n_per_row;
@@ -1152,6 +1191,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
11521191
return dequantize_row_iq4_xs_cuda;
11531192
case GGML_TYPE_IQ4_KS:
11541193
return dequantize_row_iq4_ks_cuda;
1194+
case GGML_TYPE_IQ4_KSS:
1195+
return dequantize_row_iq4_kss_cuda;
11551196
case GGML_TYPE_IQ2_KS:
11561197
return dequantize_row_iq2_ks_cuda;
11571198
case GGML_TYPE_IQ2_K:
@@ -1225,6 +1266,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
12251266
return dequantize_row_iq4_xs_cuda;
12261267
case GGML_TYPE_IQ4_KS:
12271268
return dequantize_row_iq4_ks_cuda;
1269+
case GGML_TYPE_IQ4_KSS:
1270+
return dequantize_row_iq4_kss_cuda;
12281271
case GGML_TYPE_IQ2_KS:
12291272
return dequantize_row_iq2_ks_cuda;
12301273
case GGML_TYPE_IQ2_K:

ggml/src/ggml-cuda/iqk_mmvq.cu

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,35 @@ __device__ __forceinline__ float vec_dot_iq4_ks_q8_1(
239239
return dl * __low2float(bq8_1[ib32].ds) * sumi;
240240
}
241241

242+
#define VDR_IQ4_KSS_Q8_1_MMVQ 4
243+
#define VDR_IQ4_KSS_Q8_1_MMQ 4
244+
245+
__device__ __forceinline__ float vec_dot_iq4_kss_q8_1(
246+
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
247+
248+
float scale = *(const float *)vbq;
249+
const block_iq4_kss * bq4 = (const block_iq4_kss *)((const char *)vbq + sizeof(float)) + kbx;
250+
const uint8_t * all_values = (const uint8_t *)iq4k_values;
251+
252+
// iqs is 0...28
253+
const int ib32 = iqs/4; // Why iqs/4 ?
254+
const int32_t * q8 = (const int *)bq8_1[ib32].qs;
255+
const uint32_t * q4 = (const uint32_t *)bq4->qs + 4*ib32;
256+
uint32_t s32 = (q4[0] & 0x00010001) | ((q4[1] & 0x00010001) << 2) | ((q4[2] & 0x00010001) << 4) | ((q4[3] & 0x00010001) << 6);
257+
uint8_t ls = (s32 | (s32 >> 15)) & 0xff;
258+
const float dl = scale * ((ls & 254) - 127);
259+
int v1, v2;
260+
int sumi = 0;
261+
for (int j = 0; j < 4; ++j) {
262+
uint32_t aux32 = q4[j] & 0xfffefffe;
263+
aux32 ^= (aux32 >> 1);
264+
get_int_from_table_16_shift(aux32, ls & 1, all_values, v1, v2);
265+
sumi = ggml_cuda_dp4a(v1, q8[j+0], sumi);
266+
sumi = ggml_cuda_dp4a(v2, q8[j+4], sumi);
267+
}
268+
return dl * __low2float(bq8_1[ib32].ds) * sumi;
269+
}
270+
242271
#define VDR_IQ5_K_Q8_1_MMVQ 4
243272
#define VDR_IQ5_K_Q8_1_MMQ 4
244273

@@ -703,6 +732,13 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
703732
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KS, VDR_IQ4_KS_Q8_1_MMVQ, vec_dot_iq4_ks_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
704733
}
705734

735+
void mul_mat_vec_iq4_kss_q8_1_cuda(
736+
const void * vx, const void * vy, float * dst,
737+
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
738+
739+
iqk_mul_mat_vec_q_cuda<GGML_TYPE_IQ4_KSS, VDR_IQ4_KSS_Q8_1_MMVQ, vec_dot_iq4_kss_q8_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
740+
}
741+
706742
void mul_mat_vec_iq2_ks_q8_1_cuda(
707743
const void * vx, const void * vy, float * dst,
708744
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

ggml/src/ggml-cuda/iqk_mmvq.cuh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ void mul_mat_vec_iq4_ks_q8_1_cuda(
3232
const void * vx, const void * vy, float * dst,
3333
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
3434

35+
void mul_mat_vec_iq4_kss_q8_1_cuda(
36+
const void * vx, const void * vy, float * dst,
37+
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);
38+
3539
void mul_mat_vec_iq2_ks_q8_1_cuda(
3640
const void * vx, const void * vy, float * dst,
3741
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream);

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,9 @@ void ggml_cuda_op_mul_mat_vec_q(
462462
case GGML_TYPE_IQ4_KS:
463463
mul_mat_vec_iq4_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
464464
break;
465+
case GGML_TYPE_IQ4_KSS:
466+
mul_mat_vec_iq4_kss_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
467+
break;
465468
case GGML_TYPE_IQ2_KS:
466469
mul_mat_vec_iq2_ks_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
467470
break;

0 commit comments

Comments (0)