
Commit 8f46ee1

Nexesenex and ikawrakow committed
Allow q8_0 KV cache for head size 256 #330
Co-Authored-By: Kawrakow <iwankawrakow@gmail.com>
1 parent 092f625 commit 8f46ee1

8 files changed: +91 −18 lines

ggml/src/ggml-cuda/fattn.cu

Lines changed: 18 additions & 18 deletions
@@ -211,9 +211,9 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
     //FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)

-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
     // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)

     //FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
@@ -224,14 +224,14 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     // FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
     FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

-    // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
-    // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
     // FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)

     //FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
     //FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
@@ -346,9 +346,9 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
     //FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)

-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
     // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
     FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)

     //FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
@@ -358,14 +358,14 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     // FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
     FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

-    // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
-    // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
     // FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16)
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
-    // FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0)
+    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)

     //FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
     //FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
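For context: each FATTN_VEC_F16_CASE / FATTN_VEC_F32_CASE line above is one compile-time dispatch case inside ggml_cuda_flash_attn_ext_vec_f16/_f32, so un-commenting a case lets the runtime dispatcher route head size 256 with a q8_0 K and/or V cache to the vector FlashAttention kernel instead of rejecting the combination. The first and second hunk of each function cover the two build configurations (upstream guards the full matrix of quantized combinations behind GGML_CUDA_FA_ALL_QUANTS). A minimal sketch of the macro's shape, assuming it matches upstream llama.cpp's fattn.cu (the fork may differ in detail):

// Sketch of the dispatch pattern (assumption: mirrors upstream llama.cpp).
// Q, K, V are the op's query, key and value tensors; ne[0] is the head size.
#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
        return;                                                             \
    }

Each un-commented case therefore also needs the corresponding explicit template instantiation to be compiled somewhere, which is what the six new autogenerated files at the end of this commit provide.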

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 43 additions & 0 deletions
@@ -4048,6 +4048,49 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op)
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
             return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
+        /* case GGML_OP_FLASH_ATTN_EXT: {
+ #ifndef FLASH_ATTN_AVAILABLE
+            return false;
+ #endif // FLASH_ATTN_AVAILABLE
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+                if (!turing_mma_available(cc)) {
+                    return false;
+                }
+                const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
+                return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
+            }
+            // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
+            if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
+                && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
+                return false;
+            }
+            if (op->src[1]->ne[0] == 256 && op->src[2]->ne[0] == 256 &&
+                (op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_Q8_0) &&
+                (op->src[2]->type == GGML_TYPE_F16 || op->src[2]->type == GGML_TYPE_Q8_0)) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
+            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            if (op->src[3] && op->src[3]->ne[2] != 1) {
+                return false;
+            }
+            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        } */
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
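Net effect of this hunk: the support decision for GGML_OP_FLASH_ATTN_EXT is delegated to ggml_cuda_flash_attn_ext_supported(), with the previous inline check kept as a commented-out reference. Within that reference, the branch relevant to this commit is the head-size-256 one, which accepts any mix of GGML_TYPE_F16 and GGML_TYPE_Q8_0 for the K and V caches; the 576/512 branch above it covers the MLA (DeepSeek-style) head-size pair and is unrelated to this change. In practice, and assuming the fork keeps upstream llama.cpp's CLI flags, this is what lets a head-size-256 model run FlashAttention with a quantized KV cache, e.g.:

# Hypothetical invocation (flags as in upstream llama.cpp: -fa enables
# FlashAttention, -ctk/-ctv set the K/V cache types; model name is a placeholder)
./llama-cli -m model.gguf -fa -ctk q8_0 -ctv q8_0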
ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-f16-q8_0.cu (file names inferred from the DECL lines, following the naming convention of generate_cu_files.py)

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-f16.cu

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
ggml/src/ggml-cuda/template-instances/fattn-vec-f16-instance-hs256-q8_0-q8_0.cu

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f16.cuh"
+
+DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-f16-q8_0.cu

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-f16.cu

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
ggml/src/ggml-cuda/template-instances/fattn-vec-f32-instance-hs256-q8_0-q8_0.cu

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec-f32.cuh"
+
+DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
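Each of these six new files compiles exactly one (head size, K type, V type) kernel combination into its own translation unit, which is how ggml keeps per-file CUDA compile times and memory bounded. The DECL_* macro is an explicit template instantiation of the kernel launcher; a sketch of its shape, assuming it matches the upstream llama.cpp definition in fattn-vec-f16.cuh / fattn-vec-f32.cuh (the fork may differ in detail):

// Sketch (assumption: mirrors upstream llama.cpp). Instantiates the launcher
// for one head size D and one pair of K/V cache types.
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V)                          \
    template void ggml_cuda_flash_attn_ext_vec_f16_case                     \
    <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst)

Without these instantiations, the dispatch cases un-commented in fattn.cu above would compile but fail at link time.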

0 commit comments