Skip to content

Commit bd9d356

Browse files
committed
IQ4_NL KVQ for KCPP/Croco
missing templates instances for KVQ IQ4_NL Update fattn.cu for KVQ IQ4_NL Update fattn-vec-f16.cuh for KVQ IQ4_NL Update fattn-vec-f32.cuh for KVQ IQ4_NL CML and Makefile FOR IQ4_NL
1 parent 4219ed3 commit bd9d356

15 files changed

+139
-12
lines changed

ggml/src/ggml-cuda/fattn-vec-f16.cuh

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,54 +396,65 @@ extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
396396
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
397397
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
398398
extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
399+
//extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
400+
//extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
399401

400402
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
401403
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
402404
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
403405
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
404406
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
405407
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
408+
//extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_0);
406409

407-
//extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
408-
//extern DECL_FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
409-
//extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
410+
//extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL);
411+
//extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL);
412+
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL);
413+
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_IQ4_NL);
410414
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
411-
//extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
412-
//extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
415+
//extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
416+
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL);
413417

414418
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
415419
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
416420
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
417421
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
418422
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
419423
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
424+
//extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_1);
420425

421426
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
422427
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
423428
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
424429
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
425430
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
426431
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
432+
//extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0);
427433

428434
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
429435
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
430436
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
431437
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
432438
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
433439
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
440+
//extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_1);
434441

435442
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
436443
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
437444
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
438445
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
439446
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
440447
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
448+
//extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q8_0);
441449

442450
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
443451
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
444452
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
445453
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
446454
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
447455
extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
456+
//extern DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16);
448457

449458
extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
459+
//extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
460+
//extern DECL_FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);

ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,47 +374,65 @@ extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
374374
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
375375
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
376376
extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
377+
//extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
378+
//extern DECL_FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
377379

378380
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
379381
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
380382
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
381383
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
382384
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
383385
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
386+
//extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_0);
387+
388+
//extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL);
389+
//extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL);
390+
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL);
391+
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_IQ4_NL);
392+
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
393+
//extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);
394+
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL);
384395

385396
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
386397
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
387398
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
388399
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
389400
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
390401
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
402+
//extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_1);
391403

392404
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
393405
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
394406
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
395407
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
396408
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
397409
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
410+
//extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0);
398411

399412
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
400413
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
401414
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
402415
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
403416
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
404417
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
418+
//extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_1);
405419

406420
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
407421
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
408422
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
409423
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
410424
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
411425
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
426+
//extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q8_0);
412427

413428
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
414429
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
415430
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
416431
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
417432
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
418433
extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
434+
//extern DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16);
419435

420436
extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
437+
//extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL);
438+
//extern DECL_FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL);

ggml/src/ggml-cuda/fattn.cu

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
152152
} \
153153

154154
static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
155-
ggml_tensor * Q = dst->src[0];
155+
ggml_tensor * Q = dst->src[1];
156156
ggml_tensor * K = dst->src[1];
157157
ggml_tensor * V = dst->src[2];
158158

@@ -162,57 +162,69 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
162162
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
163163
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
164164
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
165-
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 )
165+
FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
166+
//FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
167+
//FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
166168

167169
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
168170
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
169171
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
170172
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
171173
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
172174
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
175+
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
176+
177+
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL)
178+
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL)
179+
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
180+
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_IQ4_NL)
181+
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
182+
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
183+
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
173184

174185
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
175186
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
176187
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
177188
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
178189
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
179190
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
191+
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_1)
180192

181193
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
182194
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
183195
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
184196
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
185197
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
186198
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
199+
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0)
187200

188201
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
189202
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
190203
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
191204
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
192205
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
193206
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
207+
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_1)
194208

195209
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
196210
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
197211
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
198212
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
199213
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
200214
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
215+
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q8_0)
201216

202217
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
203218
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
204219
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
205220
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
206221
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
207222
FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
223+
FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
208224

209225
FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
210226

211-
//FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
212-
//FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
213227
//FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
214-
//FATTN_VEC_F16_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
215-
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
216228
//FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
217229
#else
218230
FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
@@ -241,7 +253,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
241253
} \
242254

243255
static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
244-
ggml_tensor * Q = dst->src[0];
256+
ggml_tensor * Q = dst->src[1];
245257
ggml_tensor * K = dst->src[1];
246258
ggml_tensor * V = dst->src[2];
247259

@@ -252,50 +264,69 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
252264
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
253265
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
254266
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
267+
//FATTN_VEC_F32_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
268+
//FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
255269

256270
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
257271
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
258272
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
259273
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
260274
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
261275
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0)
276+
FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_0)
277+
278+
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_IQ4_NL)
279+
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_IQ4_NL)
280+
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_IQ4_NL)
281+
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_IQ4_NL)
282+
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
283+
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
284+
FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL)
262285

263286
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
264287
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
265288
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
266289
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
267290
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
268291
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1)
292+
FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_1)
269293

270294
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
271295
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
272296
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
273297
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
274298
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
275299
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0)
300+
FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0)
276301

277302
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
278303
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
279304
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
280305
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
281306
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
282307
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1)
308+
FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_1)
283309

284310
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
285311
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
286312
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
287313
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
288314
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
289315
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0)
316+
FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q8_0)
290317

291318
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
292319
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
293320
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
294321
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
295322
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
296323
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
324+
FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16)
297325

298326
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
327+
328+
//FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
329+
//FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
299330
#else
300331
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
301332

@@ -304,6 +335,13 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
304335
FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
305336
FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
306337
FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
338+
339+
//FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
340+
//FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
341+
//FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_IQ4_NL)
342+
//FATTN_VEC_F32_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
343+
FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
344+
//FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL)
307345
#endif // GGML_CUDA_FA_ALL_QUANTS
308346

309347
on_no_fattn_vec_case(Q->ne[0]);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../fattn-vec-f16.cuh"
4+
5+
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../fattn-vec-f16.cuh"
4+
5+
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_0);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../fattn-vec-f16.cuh"
4+
5+
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q4_1);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../fattn-vec-f16.cuh"
4+
5+
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../fattn-vec-f16.cuh"
4+
5+
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_1);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../fattn-vec-f16.cuh"
4+
5+
DECL_FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_Q8_0);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../fattn-vec-f32.cuh"
4+
5+
DECL_FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_F16);

0 commit comments

Comments
 (0)