Skip to content

Commit 7b04191

Browse files
committed
Try to fix fattn (flash attention) again by porting some older code. The compute-capability (cc) detection is not working well, so the fix is hacky.
1 parent 9423de5 commit 7b04191

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

ggml/src/ggml-cuda/fattn.cu

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -298,6 +298,12 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
298298
const int warp_size = ggml_cuda_info().devices[device].warp_size;
299299
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
300300

301+
#if defined(GGML_HIP_ROCWMMA_FATTN)
302+
if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) { //kcpp: fix for rocwmma
303+
return BEST_FATTN_KERNEL_WMMA_F16;
304+
}
305+
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
306+
301307
switch (K->ne[0]) {
302308
case 64:
303309
case 128:
@@ -415,15 +421,21 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
415421
return BEST_FATTN_KERNEL_WMMA_F16;
416422
}
417423

418-
//kcpp: always force WMMA for older gpus, fix issues like "FlashAttention without tensor cores only supports head sizes 64 and 128."
419-
if (ggml_cuda_highest_compiled_arch(cc) <= GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_TURING) {
424+
//kcpp: always force WMMA for Turing and Volta if above check fails, fix "FlashAttention without tensor cores only supports head sizes 64 and 128."
425+
if (cc == GGML_CUDA_CC_TURING || cc == GGML_CUDA_CC_VOLTA) {
420426
return BEST_FATTN_KERNEL_WMMA_F16;
421427
}
422428

423429
// If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
424430
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
431+
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
432+
return BEST_FATTN_KERNEL_VEC_F16; //kcpp: patch from previous version for my sanity. it worked before, idk it should work now.
433+
}
425434
return BEST_FATTN_KERNEL_TILE_F16;
426435
}
436+
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
437+
return BEST_FATTN_KERNEL_VEC_F32; //kcpp: patch from previous version for my sanity. it worked before, idk it should work now.
438+
}
427439
return BEST_FATTN_KERNEL_TILE_F32;
428440
}
429441

0 commit comments

Comments
 (0)