@@ -419,21 +419,41 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in
   return cudaSuccess;
 }
 
+inline uint32_t DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
+  if (avg_packed_qo_len > 64 && head_dim < 256) {
+    return 128;
+  } else {
+    auto compute_capacity = GetCudaComputeCapability();
+    if (compute_capacity.first >= 8) {
+      // Ampere or newer
+      if (avg_packed_qo_len > 16) {
+        // avg_packed_qo_len <= 64
+        return 64;
+      } else {
+        // avg_packed_qo_len <= 16
+        return 16;
+      }
+    } else {
+      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
+      return 64;
+    }
+  }
+}
+
 template <typename IdType>
-inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
-                                   uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
-                                   uint32_t page_size, uint32_t max_batch_size_if_split,
-                                   bool enable_cuda_graph) {
+inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
+                                   uint32_t total_num_rows, uint32_t max_seq_len,
+                                   uint32_t batch_size, uint32_t num_qo_heads,
+                                   uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size,
+                                   uint32_t max_batch_size_if_split, bool enable_cuda_graph) {
   std::vector<IdType> request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr;
   merge_indptr.push_back(0);
   o_indptr.push_back(0);
 
   const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
-  uint32_t total_num_rows = qo_indptr_h[batch_size];
 
-  // step 1: compute qo_chunk_size
+  // step 1: determine packed_qo_len_arr and verify qo_indptr contents.
   std::vector<int64_t> packed_qo_len_arr(batch_size), kv_len_arr(batch_size);
-  int64_t sum_packed_qo_len = 0;
   for (uint32_t i = 0; i < batch_size; ++i) {
     packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size);
     if (packed_qo_len_arr[i] < 0) {
@@ -449,41 +469,43 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
               << kv_indptr_h[i] << " should be non-negative";
       FLASHINFER_ERROR(err_msg.str());
     }
-    sum_packed_qo_len += packed_qo_len_arr[i];
   }
-  int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
+
+  // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q
   uint32_t cta_tile_q;
-  if (avg_packed_qo_len > 64 && head_dim < 256) {
-    cta_tile_q = 128;
+  uint32_t total_num_tiles_q;
+  bool split_kv;
+  int64_t kv_chunk_size, new_batch_size;
+  if (enable_cuda_graph) {
+    // When CUDA graphs are enabled, the lengths of sequences determined by
+    // qo_indptr_h can vary. We assume that the dummy data based on which
+    // the CUDA graph is created fixes the maximum number of tokens.
+    uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size;
+    cta_tile_q = DetermineCtaTileQ(max_qo_len, head_dim);
+
+    // Find an upper bound for the number of tiles, derived from the total
+    // number of rows and the batch size. The sum of qo lengths rounded
+    // up to cta_tile_q will not exceed this number derived from the total
+    // number of rows.
+    total_num_tiles_q = ceil_div(total_num_rows, cta_tile_q) + batch_size;
+
+    split_kv = true;
+    kv_chunk_size = max_batch_size_if_split;
+    new_batch_size = max_batch_size_if_split;
   } else {
-    auto compute_capacity = GetCudaComputeCapability();
-    if (compute_capacity.first >= 8) {
-      // Ampere or newer
-      if (avg_packed_qo_len > 16) {
-        // avg_packed_qo_len <= 64
-        cta_tile_q = 64;
-      } else {
-        // avg_packed_qo_len <= 16
-        cta_tile_q = 16;
-      }
-    } else {
-      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
-      cta_tile_q = 64;
+    total_num_tiles_q = 0;
+    int64_t sum_packed_qo_len = 0;
+    for (uint32_t i = 0; i < batch_size; ++i) {
+      total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q);
+      sum_packed_qo_len += packed_qo_len_arr[i];
     }
-  }
 
-  uint32_t total_num_tiles_q = 0;
-  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
-    total_num_tiles_q += ceil_div(packed_qo_len_arr[request_idx], cta_tile_q);
-  }
+    const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
+    cta_tile_q = DetermineCtaTileQ(avg_packed_qo_len, head_dim);
 
-  // step 2: determine kv_chunk_size
-  auto [split_kv, kv_chunk_size, new_batch_size] = PrefillBinarySearchKVChunkSize(
-      max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q,
-      /*min_kv_chunk_size=*/std::max((128 / page_size), 1U));
-
-  if (enable_cuda_graph) {
-    split_kv = total_num_tiles_q < max_batch_size_if_split;
+    std::tie(split_kv, kv_chunk_size, new_batch_size) = PrefillBinarySearchKVChunkSize(
+        max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q,
+        /*min_kv_chunk_size=*/std::max((128 / page_size), 1U));
   }
 
   // step 3: split qo_indptr and kv_indptr
@@ -511,7 +533,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
   kv_chunk_size *= page_size;
 
   return std::make_tuple(split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size,
-                         total_num_rows, std::move(request_indices), std::move(qo_tile_indices),
+                         std::move(request_indices), std::move(qo_tile_indices),
                          std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr));
 }
 
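For reference, a minimal standalone sketch (separate from the diff above) of the scheduling heuristic: the average packed query length selects cta_tile_q, the per-request tile counts are then accumulated with ceil_div, and a bound of the form ceil_div(total, cta_tile_q) + batch_size stays above that sum because rounding up adds at most one extra tile per request, which is the reasoning the CUDA-graph branch relies on for total_num_tiles_q. The helper name determine_cta_tile_q_sm80 and the toy lengths are illustrative only, and the sketch assumes an Ampere-or-newer GPU so the compute-capability branch collapses.

#include <cstdint>
#include <iostream>
#include <vector>

// Same rounding helper the scheduler uses.
inline int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Mirrors DetermineCtaTileQ from the patch, with the GetCudaComputeCapability()
// query folded into the Ampere-or-newer (>= sm_80) case.
uint32_t determine_cta_tile_q_sm80(int64_t avg_packed_qo_len, uint32_t head_dim) {
  if (avg_packed_qo_len > 64 && head_dim < 256) return 128;
  if (avg_packed_qo_len > 16) return 64;
  return 16;
}

int main() {
  const uint32_t num_qo_heads = 32, num_kv_heads = 8, head_dim = 128;  // toy values
  const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;         // 4
  const std::vector<int64_t> qo_len = {7, 1, 33, 120};                 // toy per-request lengths

  // Pick the tile size from the average packed query length.
  int64_t sum_packed_qo_len = 0;
  for (int64_t len : qo_len) sum_packed_qo_len += len * gqa_group_size;
  const int64_t avg_packed_qo_len = sum_packed_qo_len / int64_t(qo_len.size());
  const uint32_t cta_tile_q = determine_cta_tile_q_sm80(avg_packed_qo_len, head_dim);

  // Count tiles per request once the tile size is known.
  int64_t total_num_tiles_q = 0;
  for (int64_t len : qo_len) total_num_tiles_q += ceil_div(len * gqa_group_size, cta_tile_q);

  // Upper bound of the same shape as the CUDA-graph branch: round the packed
  // total up once, then allow one extra tile per request for per-request rounding.
  const int64_t bound = ceil_div(sum_packed_qo_len, cta_tile_q) + int64_t(qo_len.size());

  std::cout << "avg_packed_qo_len=" << avg_packed_qo_len << " cta_tile_q=" << cta_tile_q
            << " total_num_tiles_q=" << total_num_tiles_q << " bound=" << bound << std::endl;
  // Prints: avg_packed_qo_len=161 cta_tile_q=128 total_num_tiles_q=8 bound=10
  return 0;
}
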
@@ -597,9 +619,10 @@ template <typename IdType>
 inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes,
                                void* int_buffer, void* page_locked_int_buffer,
                                size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info,
-                               IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
-                               uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
-                               uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o,
+                               IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows,
+                               uint32_t max_seq_len, uint32_t batch_size, uint32_t num_qo_heads,
+                               uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size,
+                               bool enable_cuda_graph, uint32_t sizeof_dtype_o,
                                cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
@@ -618,17 +641,18 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
   uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads;
 
   // step 2: determine kv_chunk_size
-  auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, total_num_rows,
-        request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec,
-        o_indptr_vec] =
-      PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, num_kv_heads,
-                             head_dim, page_size, max_batch_size_if_split, enable_cuda_graph);
+  auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec,
+        qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] =
+      PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, max_seq_len, batch_size,
+                             num_qo_heads, num_kv_heads, head_dim, page_size,
+                             max_batch_size_if_split, enable_cuda_graph);
   plan_info.cta_tile_q = cta_tile_q;
   plan_info.total_num_rows = total_num_rows;
 
   plan_info.enable_cuda_graph = enable_cuda_graph;
   size_t padded_batch_size =
       enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size;
+
   plan_info.padded_batch_size = padded_batch_size;
   plan_info.split_kv = split_kv;
 
@@ -679,6 +703,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
         sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "batch_prefill_merge_indptr");
     plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset(
         sizeof(bool) * padded_batch_size, 16, "batch_prefill_block_valid_mask");
+
     IdType* merge_indptr_h =
         GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.merge_indptr_offset);
     bool* block_valid_mask_h =
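With this change the scheduler no longer derives total_num_rows from qo_indptr_h[batch_size] internally; the caller of PrefillPlan supplies total_num_rows and max_seq_len alongside the indptr arrays. A hedged caller-side sketch, assuming max_seq_len means the largest per-request query length and using made-up indptr values:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // qo_indptr has batch_size + 1 entries; entry i marks where request i's rows begin.
  const std::vector<int32_t> qo_indptr = {0, 5, 5, 37, 64};  // made-up example
  const uint32_t batch_size = static_cast<uint32_t>(qo_indptr.size()) - 1;

  // Total number of query rows across the batch: the last indptr entry.
  const uint32_t total_num_rows = static_cast<uint32_t>(qo_indptr[batch_size]);

  // Longest single request; under CUDA graphs this fixes cta_tile_q via
  // max_qo_len = max_seq_len * gqa_group_size in PrefillSplitQOKVIndptr.
  uint32_t max_seq_len = 0;
  for (uint32_t i = 0; i < batch_size; ++i) {
    max_seq_len = std::max<uint32_t>(max_seq_len, qo_indptr[i + 1] - qo_indptr[i]);
  }

  // Here total_num_rows == 64 and max_seq_len == 32; both are then forwarded to
  // PrefillPlan(..., qo_indptr_h, kv_indptr_h, total_num_rows, max_seq_len,
  //             batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, ...).
  std::cout << "total_num_rows=" << total_num_rows << " max_seq_len=" << max_seq_len << std::endl;
  return 0;
}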