@@ -1530,9 +1530,14 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation(
15301530 (num_blocks_per_sm * num_sm) /
15311531 (num_kv_heads *
15321532 ceil_div (qo_len * group_size, num_rows_per_cta));
1533- uint32_t chunk_size =
1534- max (ceil_div (kv_len, max_num_kv_chunks), 256 );
1535- uint32_t num_chunks = ceil_div (kv_len, chunk_size);
1533+ uint32_t num_chunks;
1534+ if (max_num_kv_chunks > 0 ) {
1535+ uint32_t chunk_size =
1536+ max (ceil_div (kv_len, max_num_kv_chunks), 256 );
1537+ num_chunks = ceil_div (kv_len, chunk_size);
1538+ } else {
1539+ num_chunks = 0 ;
1540+ }
15361541
15371542 max_grid_size = num_blocks_per_sm * num_sm;
15381543 if (num_chunks > 1 ) {
@@ -1626,8 +1631,13 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
16261631 uint32_t max_num_kv_chunks =
16271632 (num_blocks_per_sm * num_sm) /
16281633 (num_kv_heads * ceil_div (qo_len * GROUP_SIZE, num_rows_per_cta));
1629- uint32_t chunk_size = max (ceil_div (kv_len, max_num_kv_chunks), 256 );
1630- uint32_t num_chunks = ceil_div (kv_len, chunk_size);
1634+ uint32_t num_chunks;
1635+ if (max_num_kv_chunks > 0 ) {
1636+ uint32_t chunk_size = max (ceil_div (kv_len, max_num_kv_chunks), 256 );
1637+ num_chunks = ceil_div (kv_len, chunk_size);
1638+ } else {
1639+ num_chunks = 0 ;
1640+ }
16311641
16321642 if (num_chunks <= 1 || tmp == nullptr ) {
16331643 // Enough parallelism, do not split-kv
0 commit comments