Commit 93155ca

Fix the maximal grid dimension in prefill planning with CUDA graphs
Previously, differences in the contents of qo_indptr could cause block sizes to vary across CUDA graph invocations, resulting in illegal memory accesses. This PR changes the block-size calculation to use a reasonable maximum derived from the longest possible sequence. The maximum token count is fixed in `plan` on the Python side and passed along to `scheduler.cuh`, which derives the other parameters from it.
1 parent 5fe9f7d commit 93155ca
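
In short: the old planner sized CTA tiles from the average packed query length of the batch being planned, which can change between CUDA graph replays; with this change, when CUDA graphs are enabled, the tiles are sized from the fixed maximum token count supplied by `plan`. A rough Python paraphrase of the decision (the function and parameter names are illustrative, not the library API):

def determine_cta_tile_q(packed_qo_len: int, head_dim: int, sm_major: int) -> int:
    # Python paraphrase of DetermineCtaTileQ in scheduler.cuh (sketch only).
    if packed_qo_len > 64 and head_dim < 256:
        return 128
    if sm_major >= 8:  # Ampere or newer
        return 64 if packed_qo_len > 16 else 16
    # Turing lacks the shared memory for the 1x4 warp layout, so fall back to 64.
    return 64


def choose_tile_q(total_num_rows, gqa_group_size, head_dim, sm_major,
                  packed_qo_lens, enable_cuda_graph):
    if enable_cuda_graph:
        # Graph mode: derive the tile from the maximum token count fixed at
        # capture time, so the grid shape is identical on every replay.
        max_qo_len = total_num_rows * gqa_group_size
        return determine_cta_tile_q(max_qo_len, head_dim, sm_major)
    # Eager mode: keep sizing from the per-batch average, as before.
    avg = sum(packed_qo_lens) // len(packed_qo_lens)
    return determine_cta_tile_q(avg, head_dim, sm_major)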

File tree

12 files changed: +149 -96 lines

include/flashinfer/attention/scheduler.cuh

Lines changed: 63 additions & 43 deletions
@@ -419,8 +419,30 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in
   return cudaSuccess;
 }
 
+inline uint32_t DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
+  if (avg_packed_qo_len > 64 && head_dim < 256) {
+    return 128;
+  } else {
+    auto compute_capacity = GetCudaComputeCapability();
+    if (compute_capacity.first >= 8) {
+      // Ampere or newer
+      if (avg_packed_qo_len > 16) {
+        // avg_packed_qo_len <= 64
+        return 64;
+      } else {
+        // avg_packed_qo_len <= 16
+        return 16;
+      }
+    } else {
+      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
+      return 64;
+    }
+  }
+}
+
 template <typename IdType>
-inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
+inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
+                                   uint32_t total_num_rows, uint32_t batch_size,
                                    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
                                    uint32_t page_size, uint32_t max_batch_size_if_split,
                                    bool enable_cuda_graph) {
@@ -429,11 +451,9 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
   o_indptr.push_back(0);
 
   const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
-  uint32_t total_num_rows = qo_indptr_h[batch_size];
 
-  // step 1: compute qo_chunk_size
+  // step 1: determine packed_qo_len_arr and verify qo_indptr contents.
   std::vector<int64_t> packed_qo_len_arr(batch_size), kv_len_arr(batch_size);
-  int64_t sum_packed_qo_len = 0;
   for (uint32_t i = 0; i < batch_size; ++i) {
     packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size);
     if (packed_qo_len_arr[i] < 0) {
@@ -449,41 +469,39 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
              << kv_indptr_h[i] << " should be non-negative";
       FLASHINFER_ERROR(err_msg.str());
     }
-    sum_packed_qo_len += packed_qo_len_arr[i];
   }
-  int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
+
+  // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q
   uint32_t cta_tile_q;
-  if (avg_packed_qo_len > 64 && head_dim < 256) {
-    cta_tile_q = 128;
+  uint32_t total_num_tiles_q;
+  bool split_kv;
+  int64_t kv_chunk_size, new_batch_size;
+  if (enable_cuda_graph) {
+    // When CUDA graphs are enabled, the lengths of sequences determined by
+    // qo_indptr_h can vary. We assume that the dummy data based on which
+    // the CUDA graph is created fixes the maximum number of tokens.
+    uint64_t max_qo_len = uint64_t(total_num_rows) * gqa_group_size;
+
+    cta_tile_q = DetermineCtaTileQ(max_qo_len, head_dim);
+    total_num_tiles_q = ceil_div(max_qo_len, cta_tile_q) * batch_size;
+
+    split_kv = true;
+    kv_chunk_size = max_batch_size_if_split;
+    new_batch_size = max_batch_size_if_split;
   } else {
-    auto compute_capacity = GetCudaComputeCapability();
-    if (compute_capacity.first >= 8) {
-      // Ampere or newer
-      if (avg_packed_qo_len > 16) {
-        // avg_packed_qo_len <= 64
-        cta_tile_q = 64;
-      } else {
-        // avg_packed_qo_len <= 16
-        cta_tile_q = 16;
-      }
-    } else {
-      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
-      cta_tile_q = 64;
+    total_num_tiles_q = 0;
+    int64_t sum_packed_qo_len = 0;
+    for (uint32_t i = 0; i < batch_size; ++i) {
+      total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q);
+      sum_packed_qo_len += packed_qo_len_arr[i];
     }
-  }
-
-  uint32_t total_num_tiles_q = 0;
-  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
-    total_num_tiles_q += ceil_div(packed_qo_len_arr[request_idx], cta_tile_q);
-  }
 
-  // step 2: determine kv_chunk_size
-  auto [split_kv, kv_chunk_size, new_batch_size] = PrefillBinarySearchKVChunkSize(
-      max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q,
-      /*min_kv_chunk_size=*/std::max((128 / page_size), 1U));
+    const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
+    cta_tile_q = DetermineCtaTileQ(avg_packed_qo_len, head_dim);
 
-  if (enable_cuda_graph) {
-    split_kv = total_num_tiles_q < max_batch_size_if_split;
+    std::tie(split_kv, kv_chunk_size, new_batch_size) = PrefillBinarySearchKVChunkSize(
+        max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q, page_size,
+        /*min_kv_chunk_size=*/std::max((128 / page_size), 1U));
   }
 
   // step 3: split qo_indptr and kv_indptr
@@ -511,7 +529,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
   kv_chunk_size *= page_size;
 
   return std::make_tuple(split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size,
-                         total_num_rows, std::move(request_indices), std::move(qo_tile_indices),
+                         std::move(request_indices), std::move(qo_tile_indices),
                          std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr));
 }
 
@@ -597,10 +615,10 @@ template <typename IdType>
 inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes,
                                void* int_buffer, void* page_locked_int_buffer,
                                size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info,
-                               IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
-                               uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
-                               uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o,
-                               cudaStream_t stream) {
+                               IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows,
+                               uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
+                               uint32_t head_dim, uint32_t page_size, bool enable_cuda_graph,
+                               uint32_t sizeof_dtype_o, cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
     err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
@@ -618,17 +636,18 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
   uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads;
 
   // step 2: determine kv_chunk_size
-  auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, total_num_rows,
-        request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec,
-        o_indptr_vec] =
-      PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, num_kv_heads,
-                             head_dim, page_size, max_batch_size_if_split, enable_cuda_graph);
+  auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec,
+        qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] =
+      PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads,
+                             num_kv_heads, head_dim, page_size, max_batch_size_if_split,
+                             enable_cuda_graph);
   plan_info.cta_tile_q = cta_tile_q;
   plan_info.total_num_rows = total_num_rows;
 
   plan_info.enable_cuda_graph = enable_cuda_graph;
   size_t padded_batch_size =
       enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size;
+
   plan_info.padded_batch_size = padded_batch_size;
   plan_info.split_kv = split_kv;
 
@@ -679,6 +698,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
       sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "batch_prefill_merge_indptr");
   plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset(
       sizeof(bool) * padded_batch_size, 16, "batch_prefill_block_valid_mask");
+
   IdType* merge_indptr_h =
       GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.merge_indptr_offset);
   bool* block_valid_mask_h =
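
To see why this keeps the grid dimension stable, here is a small Python sketch of the CUDA-graph branch above: the tile count is computed once from the fixed maximum token count, and the planner then pads the launch to the larger of that count and max_batch_size_if_split. The numbers at the end are made up for illustration; only the formula comes from the diff.

import math

def padded_batch_size_with_cuda_graph(total_num_rows, gqa_group_size, cta_tile_q,
                                      batch_size, max_batch_size_if_split):
    # Sketch of the enable_cuda_graph path in PrefillSplitQOKVIndptr/PrefillPlan:
    # the tile count is an upper bound derived from the fixed maximum token
    # count, so it cannot grow between graph replays.
    max_qo_len = total_num_rows * gqa_group_size
    total_num_tiles_q = math.ceil(max_qo_len / cta_tile_q) * batch_size
    return max(max_batch_size_if_split, total_num_tiles_q)

# Hypothetical sizes: 2048 total rows, GQA group size 4, 128-wide tiles,
# a batch of 8 requests, and room for 132 KV splits.
print(padded_batch_size_with_cuda_graph(2048, 4, 128, 8, 132))  # -> 512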

python/csrc/batch_prefill.cu

Lines changed: 5 additions & 4 deletions
@@ -42,8 +42,9 @@ using namespace flashinfer;
 std::vector<int64_t> BatchPrefillWithKVCachePlan(
     unsigned int head_dim, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
-    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-    unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream) {
+    unsigned int total_num_rows, unsigned int batch_size, unsigned int num_qo_heads,
+    unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph,
+    int64_t cuda_stream) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
   size_t int_workspace_size_in_bytes =
@@ -58,8 +59,8 @@ std::vector<int64_t> BatchPrefillWithKVCachePlan(
         float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
         int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
         int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(),
-        kv_indptr.data_ptr<IdType>(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size,
-        enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
+        kv_indptr.data_ptr<IdType>(), total_num_rows, batch_size, num_qo_heads, num_kv_heads,
+        head_dim, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream);
 
     TORCH_CHECK(status == cudaSuccess,
                 "Failed to plan prefill with error: ", cudaGetErrorString(status));

python/csrc/flashinfer_ops.cu

Lines changed: 2 additions & 2 deletions
@@ -99,8 +99,8 @@ void single_prefill_with_kv_cache(unsigned int mask_mode_code, at::Tensor q, at:
 std::vector<int64_t> BatchPrefillWithKVCachePlan(
     unsigned int head_dim, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
     at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
-    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
-    unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream);
+    unsigned total_num_rows, unsigned int batch_size, unsigned int num_qo_heads,
+    unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph, int64_t cuda_stream);
 
 void BatchPrefillWithRaggedKVCacheRun(
     unsigned int mask_mode_code, at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,

python/flashinfer/decode.py

Lines changed: 1 addition & 0 deletions
@@ -777,6 +777,7 @@ def plan(
             qo_indptr_host,
             indptr_host,
             batch_size,
+            batch_size,
             num_qo_heads,
             num_kv_heads,
             page_size,
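
For the decode wrapper every request contributes exactly one query row, so the total row count equals the batch size; that is why `batch_size` is now passed twice, once in the position of the new `total_num_rows` argument and once as the batch size itself. A one-line sanity check of that identity (illustrative values only):

batch_size = 8
qo_indptr_host = list(range(batch_size + 1))  # decode: one query row per request
assert qo_indptr_host[-1] == batch_size       # total_num_rows == batch_size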

python/flashinfer/jit/batch_prefill_templ.py

Lines changed: 2 additions & 0 deletions
@@ -139,6 +139,7 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
     at::Tensor page_locked_int_workspace_buffer,
     at::Tensor qo_indptr,
     at::Tensor kv_indptr,
+    unsigned int total_num_rows,
     unsigned int batch_size,
     unsigned int num_qo_heads,
     unsigned int num_kv_heads,
@@ -457,6 +458,7 @@ def paged_prefill_inst_templ(mask_mode: str) -> str:
     at::Tensor page_locked_int_workspace_buffer,
     at::Tensor qo_indptr,
     at::Tensor kv_indptr,
+    unsigned int total_num_rows,
     unsigned int batch_size,
     unsigned int num_qo_heads,
     unsigned int num_kv_heads,

python/flashinfer/prefill.py

Lines changed: 34 additions & 8 deletions
@@ -833,6 +833,7 @@ def __init__(
         self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf
         self._custom_mask_buf = custom_mask_buf
         self._qk_indptr_buf = qk_indptr_buf
+        self._total_num_rows = None
 
     @property
     def is_cuda_graph_enabled(self) -> bool:
@@ -993,7 +994,21 @@ def plan(
                 bitorder="little",
             )
 
+        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
+        qo_indptr_host = qo_indptr.to("cpu")
+        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
+        total_num_rows = qo_indptr_host[-1]
+
         if self.is_cuda_graph_enabled:
+            if self._total_num_rows is None:
+                self._total_num_rows = total_num_rows
+            if total_num_rows > self._total_num_rows:
+                raise ValueError(
+                    "The total number of rows in qo_indptr {} in cuda graph mode cannot "
+                    "exceed the number of rows set during initialization {}.".format(
+                        total_num_rows, self._total_num_rows
+                    )
+                )
             if batch_size != self._fixed_batch_size:
                 raise ValueError(
                     "The batch size should be fixed during the lifecycle of the wrapper in "
@@ -1031,6 +1046,7 @@ def plan(
                 # NOTE(Zihao): qk_indptr has the same length as qo_indptr
                 self._qk_indptr_buf.copy_(qk_indptr, non_blocking=non_blocking)
         else:
+            self._total_num_rows = total_num_rows
             self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=non_blocking)
             self._paged_kv_indptr_buf = paged_kv_indptr.to(
                 self.device, non_blocking=non_blocking
@@ -1049,10 +1065,6 @@ def plan(
                 self.device, non_blocking=non_blocking
             )
 
-        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
-        qo_indptr_host = qo_indptr.to("cpu")
-        paged_kv_indptr_host = paged_kv_indptr.to("cpu")
-
         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
         self._cached_module = get_batch_prefill_module(
@@ -1073,6 +1085,7 @@ def plan(
             self._pin_memory_int_workspace_buffer,
             qo_indptr_host,
             paged_kv_indptr_host,
+            total_num_rows,
            batch_size,
             num_qo_heads,
             num_kv_heads,
@@ -1463,6 +1476,7 @@ def __init__(
         self._kv_indptr_buf = kv_indptr_buf
         self._custom_mask_buf = custom_mask_buf
         self._qk_indptr_buf = qk_indptr_buf
+        self._total_num_rows = None
 
     @property
     def is_cuda_graph_enabled(self) -> bool:
@@ -1610,7 +1624,21 @@ def plan(
                 bitorder="little",
             )
 
+        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
+        qo_indptr_host = qo_indptr.to("cpu")
+        kv_indptr_host = kv_indptr.to("cpu")
+        total_num_rows = qo_indptr_host[-1]
+
         if self.is_cuda_graph_enabled:
+            if self._total_num_rows is None:
+                self._total_num_rows = total_num_rows
+            if total_num_rows > self._total_num_rows:
+                raise ValueError(
+                    "The total number of rows in qo_indptr {} in cuda graph mode cannot "
+                    "exceed the number of rows set during initialization {}.".format(
+                        total_num_rows, self._total_num_rows
+                    )
+                )
             if batch_size != self._fixed_batch_size:
                 raise ValueError(
                     "The batch size should be fixed in cudagraph mode, the runtime batch size {} "
@@ -1632,16 +1660,13 @@ def plan(
                 self._custom_mask_buf[: len(packed_custom_mask)] = packed_custom_mask
                 self._qk_indptr_buf.copy_(qk_indptr)
         else:
+            self._total_num_rows = total_num_rows
             self._qo_indptr_buf = qo_indptr.to(self.device)
             self._kv_indptr_buf = kv_indptr.to(self.device)
             if packed_custom_mask is not None:
                 self._custom_mask_buf = packed_custom_mask.to(self.device)
                 self._qk_indptr_buf = qk_indptr.to(self.device)
 
-        # NOTE(Zihao): only required if qo_indptr/paged_kv_indptr are device tensors
-        qo_indptr_host = qo_indptr.to("cpu")
-        kv_indptr_host = kv_indptr.to("cpu")
-
         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
         self._cached_module = get_batch_prefill_module(
@@ -1662,6 +1687,7 @@ def plan(
             self._pin_memory_int_workspace_buffer,
             qo_indptr_host,
             kv_indptr_host,
+            total_num_rows,
             batch_size,
             num_qo_heads,
             num_kv_heads,
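
On the Python side the new contract is enforced inside `plan`: the first call in CUDA-graph mode records the total number of rows implied by `qo_indptr`, and later calls may use fewer rows but never more. A stripped-down sketch of that bookkeeping, outside the wrapper class (class and method names here are illustrative, and a plain list stands in for the host-side indptr tensor):

class PlanRowGuard:
    """Minimal stand-in for the _total_num_rows bookkeeping added to plan()."""

    def __init__(self):
        self._total_num_rows = None  # fixed by the first plan() in graph mode

    def check(self, qo_indptr_host):
        # The last entry of the host-side indptr is the total number of rows.
        total_num_rows = int(qo_indptr_host[-1])
        if self._total_num_rows is None:
            self._total_num_rows = total_num_rows
        if total_num_rows > self._total_num_rows:
            raise ValueError(
                "The total number of rows in qo_indptr {} in cuda graph mode cannot "
                "exceed the number of rows set during initialization {}.".format(
                    total_num_rows, self._total_num_rows
                )
            )
        return total_num_rows


guard = PlanRowGuard()
guard.check([0, 512, 1024])     # capture-time plan() fixes 1024 rows
guard.check([0, 128, 256])      # replaying with fewer tokens is fine
# guard.check([0, 1024, 2048])  # would raise ValueError: exceeds the fixed maximum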

src/bench_batch_decode.cu

Lines changed: 5 additions & 5 deletions
@@ -144,11 +144,11 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
   size_t int_workspace_size_in_bytes = 8 * 1024 * 1024;
   thrust::device_vector<char> int_buffer(int_workspace_size_in_bytes);
 
-  handler.Plan<T, int32_t>((void*)thrust::raw_pointer_cast(float_buffer.data()),
-                           float_workspace_size_in_bytes,
-                           (void*)thrust::raw_pointer_cast(int_buffer.data()),
-                           int_workspace_size_in_bytes, qo_indptr_h.data(), kv_indptr_host.data(),
-                           batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
+  handler.Plan<T, int32_t>(
+      (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
+      (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
+      qo_indptr_h.data(), kv_indptr_host.data(), qo_indptr_h.back(), batch_size, num_qo_heads,
+      num_kv_heads, head_dim, page_size);
 
   state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {
     cudaError_t status = BatchPrefillWithPagedKVCacheWrapper<T, TKV, T, int32_t>(

src/bench_batch_prefill.cu

Lines changed: 6 additions & 5 deletions
@@ -73,11 +73,12 @@ void bench_flashinfer_batch_prefill_with_ragged_kv(nvbench::state& state) {
 
   BatchPrefillHandler handler;
 
-  handler.Plan<dtype_out>(
-      thrust::raw_pointer_cast(float_workspace.data()), float_workspace_size_in_bytes,
-      thrust::raw_pointer_cast(int_workspace.data()), int_workspace_size_in_bytes,
-      qo_indptr_h.data(), kv_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim,
-      /*page_size=*/1);
+  handler.Plan<dtype_out>(thrust::raw_pointer_cast(float_workspace.data()),
+                          float_workspace_size_in_bytes,
+                          thrust::raw_pointer_cast(int_workspace.data()),
+                          int_workspace_size_in_bytes, qo_indptr_h.data(), kv_indptr_h.data(),
+                          qo_indptr_h.back(), batch_size, num_qo_heads, num_kv_heads, head_dim,
+                          /*page_size=*/1);
 
   state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
     timer.start();

src/bench_cascade.cu

Lines changed: 4 additions & 4 deletions
@@ -256,8 +256,8 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) {
   cascade_handler.Plan<T, int32_t>(
       (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
       (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
-      qo_indptr_h.data(), kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads,
-      head_dim, page_size);
+      qo_indptr_h.data(), kv_indptr_unique_h.data(), qo_indptr_h.back(), batch_size, num_qo_heads,
+      num_kv_heads, head_dim, page_size);
   state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
     timer.start();
     cudaError_t status = SinglePrefillWithKVCache(
@@ -317,8 +317,8 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) {
   baseline_handler.Plan<T, int32_t>(
       (void*)thrust::raw_pointer_cast(float_buffer.data()), float_workspace_size_in_bytes,
       (void*)thrust::raw_pointer_cast(int_buffer.data()), int_workspace_size_in_bytes,
-      qo_indptr_h.data(), kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads,
-      head_dim, page_size);
+      qo_indptr_h.data(), kv_indptr_combined_h.data(), qo_indptr_h.back(), batch_size,
+      num_qo_heads, num_kv_heads, head_dim, page_size);
   state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
     timer.start();
     cudaError_t status = BatchPrefillWithPagedKVCacheWrapper<T, T, T, int32_t>(
