Skip to content

Commit 3d55c71

Browse files
authored
fix: bugfix to pr 135 (#136)
#135 didn't consider the case that `max_num_kv_chunks == 0`, this PR fixes the issue.
1 parent 9b7b0b9 commit 3d55c71

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

include/flashinfer/prefill.cuh

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,9 +1530,14 @@ cudaError_t SinglePrefillWithKVCacheWorkEstimation(
15301530
(num_blocks_per_sm * num_sm) /
15311531
(num_kv_heads *
15321532
ceil_div(qo_len * group_size, num_rows_per_cta));
1533-
uint32_t chunk_size =
1534-
max(ceil_div(kv_len, max_num_kv_chunks), 256);
1535-
uint32_t num_chunks = ceil_div(kv_len, chunk_size);
1533+
uint32_t num_chunks;
1534+
if (max_num_kv_chunks > 0) {
1535+
uint32_t chunk_size =
1536+
max(ceil_div(kv_len, max_num_kv_chunks), 256);
1537+
num_chunks = ceil_div(kv_len, chunk_size);
1538+
} else {
1539+
num_chunks = 0;
1540+
}
15361541

15371542
max_grid_size = num_blocks_per_sm * num_sm;
15381543
if (num_chunks > 1) {
@@ -1626,8 +1631,13 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
16261631
uint32_t max_num_kv_chunks =
16271632
(num_blocks_per_sm * num_sm) /
16281633
(num_kv_heads * ceil_div(qo_len * GROUP_SIZE, num_rows_per_cta));
1629-
uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
1630-
uint32_t num_chunks = ceil_div(kv_len, chunk_size);
1634+
uint32_t num_chunks;
1635+
if (max_num_kv_chunks > 0) {
1636+
uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
1637+
num_chunks = ceil_div(kv_len, chunk_size);
1638+
} else {
1639+
num_chunks = 0;
1640+
}
16311641

16321642
if (num_chunks <= 1 || tmp == nullptr) {
16331643
// Enough parallelism, do not split-kv

0 commit comments

Comments
 (0)