@@ -308,13 +308,72 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml

     if (Q->ne[1] == 1) {
         constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
+        const int total_blocks = (((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]);
+        const int nsm          = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+        const int seqlen_tiles = (K->ne[1] + D - 1) / D;
+
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        } else {
+
+            // Determine the number of active blocks per SM.
+            // The parallel_blocks template parameter has no effect on the number of active blocks, so a constant 4 is used for the occupancy query.
+            int numActiveBlocks = 1;
+            CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
+
+            // Keep at least `numActiveBlocks` blocks per SM to improve occupancy.
+            // This kernel operates on D-wide tiles of the sequence length, so consider how many D tiles can be processed in parallel.
+            // If there are not enough tiles to process, the number of parallel blocks can be reduced.
+            const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
+
+            if (parallel_blocks >= 24)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 16)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 16, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 12)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 12, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 8)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+        }
+        else
+        {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            int numActiveBlocks = 1;
+            CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numActiveBlocks, flash_attn_vec_ext_f32<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>, D, 0));
+
+            const int parallel_blocks = std::min((nsm * numActiveBlocks) / total_blocks, seqlen_tiles);
+
+            if (parallel_blocks >= 24)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 24, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 16)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 16, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 12)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 12, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else if (parallel_blocks >= 8)
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 8, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
+            else
+            {
+                ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, 4, type_K, type_V, use_logit_softcap>(ctx, dst);
+            }
         }
         return;
     }
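
For reviewers unfamiliar with the occupancy API, here is a minimal, self-contained sketch of the same heuristic outside of ggml. The kernel `attn_vec_stub`, the `main` driver, and the workload numbers are hypothetical stand-ins for illustration; only the two CUDA runtime calls and the final `std::min` clamp mirror the patch.

```cpp
#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative stand-in for flash_attn_vec_ext_f32 (name and signature are
// hypothetical): D is the head size and parallel_blocks is a compile-time
// template parameter, as in the patch above.
template <int D, int parallel_blocks>
__global__ void attn_vec_stub(float * dst) {
    if (dst != nullptr && threadIdx.x == 0 && blockIdx.x == 0) {
        dst[0] = float(D * parallel_blocks); // placeholder work
    }
}

int main() {
    constexpr int D = 128; // head size; also the block size used by the kernel

    // Step 1: ask the runtime how many blocks of this kernel fit on one SM.
    // The template value 4 is arbitrary here: parallel_blocks only changes the
    // grid mapping, not the per-block resource usage that decides occupancy.
    int num_active_blocks = 1;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &num_active_blocks, attn_vec_stub<D, 4>, /*blockSize=*/D, /*dynamicSMemSize=*/0);

    int nsm = 0;
    cudaDeviceGetAttribute(&nsm, cudaDevAttrMultiProcessorCount, /*device=*/0);

    // Step 2: clamp by the work that is actually available. total_blocks is one
    // block per (query column, head, batch) triple; seqlen_tiles is how many
    // D-wide slices of the KV sequence could run in parallel. Example numbers:
    const int total_blocks = 32;                    // e.g. 32 heads, batch 1, 1 query column
    const int seqlen       = 4096;                  // KV cache length
    const int seqlen_tiles = (seqlen + D - 1) / D;  // = 32 tiles

    const int parallel_blocks = std::min((nsm * num_active_blocks) / total_blocks, seqlen_tiles);
    std::printf("parallel_blocks = %d\n", parallel_blocks);
    return 0;
}
```

The runtime value is then rounded down to the nearest pre-instantiated variant (24, 16, 12, 8, 4), because `parallel_blocks` is a template parameter and must be known at compile time; that is what the if/else ladder in the patch implements.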