@@ -198,7 +198,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     return BEST_FATTN_KERNEL_NONE;
 #endif // FLASH_ATTN_AVAILABLE
 
-    const ggml_tensor * KQV   = dst;
     const ggml_tensor * Q     = dst->src[0];
     const ggml_tensor * K     = dst->src[1];
     const ggml_tensor * V     = dst->src[2];
@@ -208,8 +207,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
 
     const int cc = ggml_cuda_info().devices[device].cc;
-    const int warp_size = ggml_cuda_info().devices[device].warp_size;
-    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     switch (K->ne[0]) {
         case  64:
@@ -267,29 +264,31 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         return BEST_FATTN_KERNEL_NONE;
     }
 
-    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
+    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
 
     // If Turing tensor cores available, use them except for some cases with batch size 1:
     if (turing_mma_available(cc)) {
         best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
 
-        if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
-            if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
-                best = BEST_FATTN_KERNEL_VEC;
-            }
-        } else {
-            if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
-                if (Q->ne[1] <= 2) {
+        if (can_use_vector_kernel) {
+            if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                     best = BEST_FATTN_KERNEL_VEC;
                 }
             } else {
-                if (Q->ne[1] == 1) {
-                    best = BEST_FATTN_KERNEL_VEC;
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    if (Q->ne[1] <= 2) {
+                        best = BEST_FATTN_KERNEL_VEC;
+                    }
+                } else {
+                    if (Q->ne[1] == 1) {
+                        best = BEST_FATTN_KERNEL_VEC;
+                    }
                 }
             }
-        }
-        if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
-            best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
+            if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
+                best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
+            }
         }
 
         return best;
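
Note on the predicate change: on devices with a 32-wide warp, the old `Q->ne[0] % (2*warp_size)` check and the new `Q->ne[0] % 64` check are equivalent, so the rewrite drops the dependency on `warp_size` (along with the `KQV` and `prec` locals, which this hunk no longer references). Below is a minimal standalone sketch, not part of the patch, of the updated eligibility check on the head size `Q->ne[0]`; the helper name and the head-size values are illustrative assumptions only:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the updated can_use_vector_kernel check.
// With a warp size of 32, the old expression 2*warp_size equals 64, so
// this matches the previous behavior on such devices.
static bool can_use_vector_kernel(int64_t head_size) {
    return head_size <= 256 && head_size % 64 == 0;
}

int main() {
    // Illustrative head sizes: 64/128/192/256 pass, the others do not
    // (either not a multiple of 64, or larger than 256).
    const int64_t sizes[] = {40, 64, 80, 96, 112, 128, 192, 256, 576};
    for (int64_t hs : sizes) {
        std::printf("head size %3lld: %s\n", (long long) hs,
                    can_use_vector_kernel(hs) ? "vector kernel eligible"
                                              : "not eligible");
    }
    return 0;
}
```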