
Commit cedc9a6

start to implement IK quants
CUDA: faster float -> iq4_nl conversion (#73)

* iqk_mul_mat: better iq4_nl implementation on Zen4/AVX2. PP-512 performance for LLaMA-3.1-8B goes to 162.6 t/s, up from 133.2 t/s.
* Speed up float -> iq4_nl conversion on CUDA.

iq4_nl: faster quantization (#76)

* Enable IQ4_NL for the V-cache in token generation.
* Add IQ4_NL + IQ4_NL to flash attention. This is a better alternative than Q4_0 + Q4_0 for the VRAM-poor.
* IQ4_NL KVQ for KCPP/Croco:
  * add missing template instances for KVQ IQ4_NL
  * update fattn.cu for KVQ IQ4_NL
  * update fattn-vec-f16.cuh for KVQ IQ4_NL
  * update fattn-vec-f32.cuh for KVQ IQ4_NL
  * CMakeLists and Makefile entries for IQ4_NL
  * KV_IQ4_NL: uncomment VEC16 cases
  * KV_IQ4_NL: uncomment VEC32 cases

Adding Q6_0 (#77)

* Adding q6_0: basics + AVX2/Zen4 working.
* Adding q6_0: CUDA dequantize works, but not mmvq.
* Adding q6_0: CUDA mmvq works.
* Adding q6_0: CUDA cpy, so Q6_0 can be used for the KV-cache.
* Add q6_0 to CPU flash attention. Disappointing result: for LLaMA-3.2-1B, q6_0 K- and V-cache gives about the same PPL as q8_0 K-cache plus q4_0 V-cache, while needing exactly the same RAM. I.e., what was the point?
* q6_0: slightly better KV-cache result. Better than q8_0+q4_0, but not as good as q8_0+iq4_nl.
* q6_0: works on ARM_NEON.
* q6_0: dequantize works on Metal, but not the vector dot product.
* q6_0: it now works on Metal, and outperforms q5_0 by a significant margin. E.g.:

| model         | size     | params | backend | ngl | threads | test  | t/s           |
| ------------- | -------: | -----: | ------- | --: | ------: | ----: | ------------: |
| llama 8B Q6_0 | 6.08 GiB | 8.03 B | Metal   | 100 |       4 | tg128 |  44.02 ± 0.08 |
| llama 8B Q5_0 | 5.21 GiB | 8.03 B | Metal   | 100 |       4 | tg128 |  40.13 ± 0.12 |
| llama 8B Q6_0 | 6.08 GiB | 8.03 B | Metal   | 100 |       4 | pp512 | 500.55 ± 0.32 |
| llama 8B Q5_0 | 5.21 GiB | 8.03 B | Metal   | 100 |       4 | pp512 | 448.02 ± 0.27 |

* q6_0: can now be used for the KV-cache on Metal.

Enable q6_0 for flash attention

As with IQ4_NL, only for a head size of 128 for now. Without GGML_CUDA_FA_ALL_QUANTS set, only the Q6_0 + Q5_0 and Q8_0 + Q6_0 combinations are included. With this, the VRAM-poor have better options for selecting the best possible quantized KV-cache, as allowed by VRAM, model size, and context length.

PR by Ikawrakow on ik_llama.cpp
1 parent 0cd1711 commit cedc9a6

File tree: 52 files changed, +9283 −40 lines


examples/quantize/quantize.cpp
Lines changed: 1 addition & 0 deletions

```diff
@@ -21,6 +21,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q4_1",    LLAMA_FTYPE_MOSTLY_Q4_1,    " 4.78G, +0.4511 ppl @ Llama-3-8B", },
     { "Q5_0",    LLAMA_FTYPE_MOSTLY_Q5_0,    " 5.21G, +0.1316 ppl @ Llama-3-8B", },
     { "Q5_1",    LLAMA_FTYPE_MOSTLY_Q5_1,    " 5.65G, +0.1062 ppl @ Llama-3-8B", },
+    { "Q6_0",    LLAMA_FTYPE_MOSTLY_Q6_0,    " 6.5 bpw quantization", },
     { "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS, " 2.06 bpw quantization", },
     { "IQ2_XS",  LLAMA_FTYPE_MOSTLY_IQ2_XS,  " 2.31 bpw quantization", },
     { "IQ2_S",   LLAMA_FTYPE_MOSTLY_IQ2_S,   " 2.5  bpw quantization", },
```

ggml/include/ggml.h
Lines changed: 13 additions & 1 deletion

```diff
@@ -385,10 +385,15 @@ extern "C" {
         // GGML_TYPE_Q4_0_8_8 = 33,
         GGML_TYPE_TQ1_0 = 34,
         GGML_TYPE_TQ2_0 = 35,
+
         // GGML_TYPE_IQ4_NL_4_4 = 36,
         // GGML_TYPE_IQ4_NL_4_8 = 37,
         // GGML_TYPE_IQ4_NL_8_8 = 38,
-        GGML_TYPE_COUNT = 39,
+        // GGML_TYPE_COUNT = 39,
+
+        //
+        GGML_TYPE_Q6_0 = 133,
+        GGML_TYPE_COUNT,
     };

     // precision
@@ -423,6 +428,13 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_M  = 23, // except 1d tensors
         GGML_FTYPE_MOSTLY_BF16   = 24, // except 1d tensors
+
+        // GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
+        // GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
+        // GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
+        //
+        GGML_FTYPE_MOSTLY_Q6_0 = 127, // except 1d tensors
+
     };

     // available tensor operations:
```
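Note that the new enumerators are parked far above the commented-out upstream IDs (which stop at 38), presumably so that future upstream additions cannot collide with the IK quant range, and GGML_TYPE_COUNT is left unnumbered so it trails the last explicit value. A minimal compile-time check of that arithmetic, illustrative only and not part of the commit:

```cpp
#include "ggml.h"

// Illustrative sketch: GGML_TYPE_Q6_0 is pinned at 133, well above the last
// upstream type ID, and an unnumbered enumerator takes the previous value
// plus one, so GGML_TYPE_COUNT implicitly becomes 134.
static_assert(GGML_TYPE_Q6_0 == 133, "IK quant IDs live above the upstream range");
static_assert(GGML_TYPE_COUNT == GGML_TYPE_Q6_0 + 1, "COUNT trails the last explicit enumerator");
```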

ggml/src/ggml-common.h
Lines changed: 11 additions & 0 deletions

```diff
@@ -105,6 +105,9 @@ typedef sycl::half2 ggml_half2;
 #define QI5_1 (QK5_1 / (4 * QR5_1))
 #define QR5_1 2

+#define QI6_0 (QK6_0 / (4 * QR6_0))
+#define QR6_0 2
+
 #define QI8_0 (QK8_0 / (4 * QR8_0))
 #define QR8_0 1

@@ -200,6 +203,14 @@ typedef struct {
 } block_q5_1;
 static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_half) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");

+#define QK6_0 32
+typedef struct {
+    ggml_half d;           // delta
+    uint8_t   qh[QK6_0/4]; // 5th and 6th bits of the quants
+    uint8_t   qs[QK6_0/2]; // nibbles / quants
+} block_q6_0;
+static_assert(sizeof(block_q6_0) == sizeof(ggml_half) + QK6_0/2 + QK6_0/4, "wrong q6_0 block size/padding");
+
 #define QK8_0 32
 typedef struct {
     ggml_half d; // delta
```
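This layout is where the "6.5 bpw" label advertised in quantize.cpp comes from: per 32 weights, the block stores a 2-byte fp16 scale, 8 bytes of high bits (2 bits per weight), and 16 bytes of nibbles (4 bits per weight). A standalone sketch of that arithmetic (not code from the commit):

```cpp
#include <cstdio>

int main() {
    const int qk          = 32;              // weights per block (QK6_0)
    const int block_bytes = 2 + qk/4 + qk/2; // fp16 d + qh + qs = 26 bytes
    printf("block_q6_0: %d bytes, %.2f bpw\n", block_bytes, 8.0 * block_bytes / qk);
    // prints: block_q6_0: 26 bytes, 6.50 bpw
    return 0;
}
```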

ggml/src/ggml-cuda/common.cuh
Lines changed: 7 additions & 0 deletions

```diff
@@ -487,6 +487,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
     static constexpr int qi = QI5_1;
 };

+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q6_0> {
+    static constexpr int qk = QK6_0;
+    static constexpr int qr = QR6_0;
+    static constexpr int qi = QI6_0;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
     static constexpr int qk = QK8_0;
```
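These traits feed the templated CUDA kernels (mmvq, dequantize). For Q6_0 the numbers work out as below; the sketch restates them with local names (the `_sketch` constants are stand-ins for the real macros, not part of the commit):

```cpp
// qk: values per block; qr: values packed into each byte of qs;
// qi: 32-bit words of packed nibbles per block, i.e. qk / (4 * qr).
constexpr int QK6_0_sketch = 32;
constexpr int QR6_0_sketch = 2;
constexpr int QI6_0_sketch = QK6_0_sketch / (4 * QR6_0_sketch);
static_assert(QI6_0_sketch == 4, "16 bytes of nibbles = four 32-bit words per block");
```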

ggml/src/ggml-cuda/convert.cu
Lines changed: 41 additions & 0 deletions

```diff
@@ -122,6 +122,36 @@ static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t
     }
 }

+template<typename dst_t>
+static __global__ void dequantize_block_q6_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+    const int64_t i = blockIdx.x;
+
+    // assume 32 threads
+    const int64_t tid = threadIdx.x;
+    const int64_t il  = tid/8;
+    const int64_t ir  = tid%8;
+    const int64_t ib  = 8*i + ir;
+    if (ib >= nb32) {
+        return;
+    }
+
+    dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+    const block_q6_0 * x = (const block_q6_0 *)vx + ib;
+    const float d  = __half2float(x->d);
+    const float dm = -32*d;
+
+    const uint8_t * qs = x->qs + 4*il;
+    const uint8_t * qh = x->qh + 4*(il%2);
+
+    for (int l = 0; l < 4; ++l) {
+        const uint8_t h = qh[l] >> 4*(il/2);
+        y[l+ 0] = d * ((qs[l] & 0xF) | ((h << 4) & 0x30)) + dm;
+        y[l+16] = d * ((qs[l] >>  4) | ((h << 2) & 0x30)) + dm;
+    }
+}
+
 //================================== k-quants

 template<typename dst_t>
@@ -497,6 +527,13 @@ static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k
     dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
 }

+template<typename dst_t>
+static void dequantize_row_q6_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb32 = k / 32;
+    const int nb = (k + 255) / 256;
+    dequantize_block_q6_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
 template<typename dst_t>
 static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     const int nb = k / QK_K;
@@ -598,6 +635,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q6_0:
+            return dequantize_row_q6_0_cuda;
         case GGML_TYPE_Q8_0:
             if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
                 return dequantize_block_q8_0_f16_cuda;
@@ -648,6 +687,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
         case GGML_TYPE_Q5_1:
             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q6_0:
+            return dequantize_row_q6_0_cuda;
         case GGML_TYPE_Q8_0:
             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
         case GGML_TYPE_Q2_K:
```
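In the kernel above, each 32-thread CUDA block covers eight q6_0 blocks (256 output values): `ir = tid%8` picks the q6_0 block and `il = tid/8` picks a 4-value slice of each 16-value half. The bit manipulation is easier to read in a scalar restatement; the following CPU sketch (a hypothetical helper named `dequantize_q6_0_ref`, not code from the commit) mirrors the kernel's math for cross-checking:

```cpp
#include <cstdint>

// Decode one block_q6_0 from its raw fields (d already converted to float).
// Each 6-bit quant q is stored biased by +32: value = d*q - 32*d = d*(q - 32).
static void dequantize_q6_0_ref(float d, const uint8_t qh[8], const uint8_t qs[16], float y[32]) {
    const float dm = -32.0f * d;
    for (int j = 0; j < 16; ++j) {
        const uint8_t h = qh[j % 8] >> (4 * (j / 8));             // 2+2 high bits for values (j, j+16)
        y[j]      = d * ((qs[j] & 0xF) | ((h << 4) & 0x30)) + dm; // low nibble + bits 0-1 of h
        y[j + 16] = d * ((qs[j] >>  4) | ((h << 2) & 0x30)) + dm; // high nibble + bits 2-3 of h
    }
}
```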

ggml/src/ggml-cuda/cpy.cu
Lines changed: 72 additions & 2 deletions

```diff
@@ -251,6 +251,59 @@ static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val,
     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }

+static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q6_0 * dsti = (block_q6_0 *) cdsti;
+
+    float amax = 0.0f;
+    float vmax = 0.0f;
+
+    for (int j = 0; j < QK6_0; ++j) {
+        const float v  = xi[j];
+        const float av = fabsf(xi[j]);
+        if (amax < av) {
+            amax = av;
+            vmax = v;
+        }
+    }
+
+    const float d  = vmax / -32;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dsti->d = d;
+    memset(dsti->qh, 0, QK6_0/4);
+
+    for (int j = 0; j < QK6_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f));
+        const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f));
+
+        dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
+        dsti->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4)));
+    }
+}
+
+static __device__ const int8_t iq4nl_index[241] = {
+     0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0, 16, 16,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
+     1, 17, 17,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 18,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
+     3,  3,  3,  3,  3,  3, 19,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4, 20,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
+     5,  5, 21, 21,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6, 22,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7, 23, 23,  8,  8,  8,  8,
+     8,  8,  8,  8,  8,  8, 24,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 25, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 26, 26,
+    11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 27, 27, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 28, 13, 13, 13,
+    13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 29, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,
+    14, 14, 14, 14, 30, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15
+};
+
+static __device__ __forceinline__ int best_index_iq4nl(const int8_t * values, float x) {
+    int ix = (int)x - values[0];
+    if (ix < 0 || ix >= 241) return ix < 0 ? 0 : 15;
+    ix = iq4nl_index[ix];
+    return ix < 16 ? ix : x - values[ix-16] < values[ix-15] - x ? ix-16 : ix-15;
+}
+
 static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
```
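best_index_iq4nl replaces the binary search of best_index_int8 with a direct table lookup: the 16 IQ4_NL grid values span [-127, 113], so the clamped integer offset `x - values[0]` falls in [0, 240] and indexes iq4nl_index. Entries below 16 decide the nearest grid point outright; entries of the form 16 + i flag cells that straddle the midpoint between values[i] and values[i+1], where the exact float x breaks the tie. A host-side verification sketch, under two stated assumptions: the two tables above are compiled without `__device__` for the test, and kvalues_iq4nl holds ggml's usual IQ4_NL grid:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Brute-force nearest search over the 16 grid values: the answer the
// table-based lookup must reproduce (up to ties at exact midpoints).
static int best_index_bruteforce(const int8_t * values, float x) {
    int best = 0;
    for (int i = 1; i < 16; ++i) {
        if (fabsf(x - values[i]) < fabsf(x - values[best])) best = i;
    }
    return best;
}

static void check_iq4nl_table(const int8_t * values) { // pass kvalues_iq4nl
    for (float x = -130.0f; x <= 116.0f; x += 0.25f) {
        const int a = best_index_iq4nl(values, x);     // table-based (host build)
        const int b = best_index_bruteforce(values, x);
        // exact midpoints may legitimately resolve to either neighbour,
        // so compare distances rather than indices
        assert(fabsf(x - values[a]) == fabsf(x - values[b]));
    }
}
```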
The remaining hunks in cpy.cu switch the IQ4_NL path over to the table lookup and route the new Q6_0 block conversion through the copy dispatchers:

```diff
@@ -269,12 +322,14 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
     float d = vmax / kvalues_iq4nl[0];
     const float id = d ? 1.0f/d : 0.0f;

+    //dsti->d = d;
+
     float sumqx = 0, sumq2 = 0;
     for (int j = 0; j < QK4_NL/2; ++j) {
         const float x0 = xi[0        + j]*id;
         const float x1 = xi[QK4_NL/2 + j]*id;
-        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
-        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
+        const uint8_t xi0 = best_index_iq4nl(kvalues_iq4nl, x0);
+        const uint8_t xi1 = best_index_iq4nl(kvalues_iq4nl, x1);
         dsti->qs[j] = xi0 | (xi1 << 4);
         const float v0 = kvalues_iq4nl[xi0];
         const float v1 = kvalues_iq4nl[xi1];
@@ -486,6 +541,17 @@ static void ggml_cpy_q5_1_f32_cuda(
         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }

+static void ggml_cpy_f32_q6_0_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    GGML_ASSERT(ne % QK6_0 == 0);
+    const int num_blocks = ne / QK6_0;
+    cpy_f32_q<cpy_blck_f32_q6_0, QK6_0><<<num_blocks, 1, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -573,6 +639,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) {
+        ggml_cpy_f32_q6_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -617,6 +685,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q6_0) {
+        return (void*) cpy_f32_q<cpy_blck_f32_q6_0, QK6_0>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
```
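cpy_blck_f32_q6_0 is the quantizer that makes the F32 -> Q6_0 KV-cache copy possible: the signed absolute maximum vmax is mapped to quant 0 via d = vmax/-32, every scaled value is biased by +32 into [0, 63], and each 6-bit quant is split into a nibble in qs plus a 2-bit remainder in qh. (The `xi[QK4_0/2 + j]` index works only because QK4_0 and QK6_0 are both 32.) A scalar sketch of the same scheme, suitable for round-tripping against the hypothetical dequantize_q6_0_ref above; it is an illustration, not committed code:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

// Quantize 32 floats into (qh, qs); returns the block scale d,
// which the device code stores as fp16.
static float quantize_q6_0_ref(const float x[32], uint8_t qh[8], uint8_t qs[16]) {
    float amax = 0.0f, vmax = 0.0f;
    for (int j = 0; j < 32; ++j) {
        if (fabsf(x[j]) > amax) { amax = fabsf(x[j]); vmax = x[j]; }
    }
    const float d  = vmax / -32;      // signed max lands on quant 0 (value -32)
    const float id = d ? 1.0f / d : 0.0f;

    memset(qh, 0, 8);
    for (int j = 0; j < 16; ++j) {
        // bias into [0, 63]; x[j]*id stays within [-32, 32] by construction
        const uint8_t q0 = (uint8_t)std::min(63, (int)(x[j]      * id + 32.5f));
        const uint8_t q1 = (uint8_t)std::min(63, (int)(x[j + 16] * id + 32.5f));
        qs[j] = (q0 & 0xF) | ((q1 & 0xF) << 4);                    // low nibbles
        const uint8_t h = (uint8_t)((q0 >> 4) | ((q1 >> 4) << 2)); // two 2-bit tops
        qh[j % 8] |= (uint8_t)(h << (4 * (j / 8)));
    }
    return d;
}
```

Round-tripping random blocks through quantize_q6_0_ref and dequantize_q6_0_ref should reproduce each input to within about one quantization step |d|, which is a quick way to validate any new SIMD or CUDA port of the format.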
