@@ -419,8 +419,30 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in
   return cudaSuccess;
 }
 
+inline uint32_t DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) {
+  if (avg_packed_qo_len > 64 && head_dim < 256) {
+    return 128;
+  } else {
+    auto compute_capacity = GetCudaComputeCapability();
+    if (compute_capacity.first >= 8) {
+      // Ampere or newer
+      if (avg_packed_qo_len > 16) {
+        // avg_packed_qo_len <= 64
+        return 64;
+      } else {
+        // avg_packed_qo_len <= 16
+        return 16;
+      }
+    } else {
+      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
+      return 64;
+    }
+  }
+}
+
 template <typename IdType>
-inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
+inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
+                                   uint32_t total_num_rows, uint32_t batch_size,
                                    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
                                    uint32_t page_size, uint32_t max_batch_size_if_split,
                                    bool enable_cuda_graph) {
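The new DetermineCtaTileQ helper factors out the tile-size heuristic that previously lived inline in PrefillSplitQOKVIndptr (the branch removed further down). For reference, a minimal sketch of the values it returns for a few representative inputs, assuming this header and <cassert> are included and the device reports compute capability >= 8 (the "Ampere or newer" branch); the inputs are made-up examples, not values from the commit:

// Hedged sketch: expected CTA tile sizes on an sm80+ device.
assert(DetermineCtaTileQ(/*avg_packed_qo_len=*/8, /*head_dim=*/128) == 16);    // <= 16 tokens
assert(DetermineCtaTileQ(/*avg_packed_qo_len=*/32, /*head_dim=*/128) == 64);   // 16 < len <= 64
assert(DetermineCtaTileQ(/*avg_packed_qo_len=*/256, /*head_dim=*/128) == 128); // long prefill, head_dim < 256
assert(DetermineCtaTileQ(/*avg_packed_qo_len=*/256, /*head_dim=*/256) == 64);  // head_dim == 256 falls back to 64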
@@ -429,11 +451,9 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
   o_indptr.push_back(0);
 
   const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
-  uint32_t total_num_rows = qo_indptr_h[batch_size];
 
-  // step 1: compute qo_chunk_size
+  // step 1: determine packed_qo_len_arr and verify qo_indptr contents.
   std::vector<int64_t> packed_qo_len_arr(batch_size), kv_len_arr(batch_size);
-  int64_t sum_packed_qo_len = 0;
   for (uint32_t i = 0; i < batch_size; ++i) {
     packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size);
     if (packed_qo_len_arr[i] < 0) {
@@ -449,41 +469,39 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
               << kv_indptr_h[i] << " should be non-negative";
       FLASHINFER_ERROR(err_msg.str());
     }
-    sum_packed_qo_len += packed_qo_len_arr[i];
   }
-  int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
+
+  // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q
   uint32_t cta_tile_q;
-  if (avg_packed_qo_len > 64 && head_dim < 256) {
-    cta_tile_q = 128;
+  uint32_t total_num_tiles_q;
+  bool split_kv;
+  int64_t kv_chunk_size, new_batch_size;
+  if (enable_cuda_graph) {
+    // When CUDA graphs are enabled, the lengths of sequences determined by
+    // qo_indptr_h can vary. We assume that the dummy data based on which
+    // the CUDA graph is created fixes the maximum number of tokens.
+    uint64_t max_qo_len = uint64_t(total_num_rows) * gqa_group_size;
+
+    cta_tile_q = DetermineCtaTileQ(max_qo_len, head_dim);
+    total_num_tiles_q = ceil_div(max_qo_len, cta_tile_q) * batch_size;
+
+    split_kv = true;
+    kv_chunk_size = max_batch_size_if_split;
+    new_batch_size = max_batch_size_if_split;
   } else {
-    auto compute_capacity = GetCudaComputeCapability();
-    if (compute_capacity.first >= 8) {
-      // Ampere or newer
-      if (avg_packed_qo_len > 16) {
-        // avg_packed_qo_len <= 64
-        cta_tile_q = 64;
-      } else {
-        // avg_packed_qo_len <= 16
-        cta_tile_q = 16;
-      }
-    } else {
-      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
-      cta_tile_q = 64;
+    total_num_tiles_q = 0;
+    int64_t sum_packed_qo_len = 0;
+    for (uint32_t i = 0; i < batch_size; ++i) {
+      total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q);
+      sum_packed_qo_len += packed_qo_len_arr[i];
     }
-  }
-
-  uint32_t total_num_tiles_q = 0;
-  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
-    total_num_tiles_q += ceil_div(packed_qo_len_arr[request_idx], cta_tile_q);
-  }
 
-  // step 2: determine kv_chunk_size
-  auto [split_kv, kv_chunk_size, new_batch_size] = PrefillBinarySearchKVChunkSize(
-      max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q,
-      /*min_kv_chunk_size=*/std::max((128 / page_size), 1U));
+    const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
+    cta_tile_q = DetermineCtaTileQ(avg_packed_qo_len, head_dim);
 
-  if (enable_cuda_graph) {
-    split_kv = total_num_tiles_q < max_batch_size_if_split;
+    std::tie(split_kv, kv_chunk_size, new_batch_size) = PrefillBinarySearchKVChunkSize(
+        max_batch_size_if_split, packed_qo_len_arr, kv_len_arr, cta_tile_q, page_size,
+        /*min_kv_chunk_size=*/std::max((128 / page_size), 1U));
   }
 
   // step 3: split qo_indptr and kv_indptr
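In the non-graph path, split_kv, kv_chunk_size, and new_batch_size are now declared before the branch and assigned through std::tie, so both the CUDA-graph and eager branches can fill the same variables. For intuition about the values they come from, here is a minimal, self-contained sketch of the general idea behind a kv_chunk_size binary search; BinarySearchKVChunkSizeSketch, its return convention, and the work-unit formula are illustrative assumptions, not FlashInfer's actual PrefillBinarySearchKVChunkSize (page_size handling is also omitted):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <vector>

// Hedged sketch: pick the smallest kv chunk size whose total (q-tile, kv-chunk)
// work-unit count still fits within the CTA budget max_batch_size_if_split.
inline std::tuple<bool, int64_t, int64_t> BinarySearchKVChunkSizeSketch(
    int64_t max_batch_size_if_split, const std::vector<int64_t>& packed_qo_len_arr,
    const std::vector<int64_t>& kv_len_arr, uint32_t cta_tile_q, int64_t min_kv_chunk_size) {
  int64_t max_kv_len = 1;
  for (int64_t kv_len : kv_len_arr) max_kv_len = std::max(max_kv_len, kv_len);

  // Work units produced by a candidate chunk size (non-increasing in kv_chunk_size).
  auto num_work_units = [&](int64_t kv_chunk_size) {
    int64_t total = 0;
    for (std::size_t i = 0; i < packed_qo_len_arr.size(); ++i) {
      int64_t num_q_tiles = (packed_qo_len_arr[i] + cta_tile_q - 1) / cta_tile_q;
      int64_t num_kv_chunks =
          std::max<int64_t>((kv_len_arr[i] + kv_chunk_size - 1) / kv_chunk_size, 1);
      total += num_q_tiles * num_kv_chunks;
    }
    return total;
  };

  // Binary search the smallest chunk size that fits in the budget.
  int64_t lo = min_kv_chunk_size, hi = max_kv_len;
  while (lo < hi) {
    int64_t mid = lo + (hi - lo) / 2;
    if (num_work_units(mid) > max_batch_size_if_split) {
      lo = mid + 1;  // too many work units: enlarge the chunk
    } else {
      hi = mid;
    }
  }
  const int64_t kv_chunk_size = lo;
  const bool split_kv = kv_chunk_size < max_kv_len;  // at least one request is chunked
  return {split_kv, kv_chunk_size, num_work_units(kv_chunk_size)};
}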
@@ -511,7 +529,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
   kv_chunk_size *= page_size;
 
   return std::make_tuple(split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size,
-                         total_num_rows, std::move(request_indices), std::move(qo_tile_indices),
+                         std::move(request_indices), std::move(qo_tile_indices),
                          std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr));
 }
 
@@ -597,10 +615,10 @@ template <typename IdType>
 inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes,
                                void* int_buffer, void* page_locked_int_buffer,
                                size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info,
-                               IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
-                               uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
-                               uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o,
-                               cudaStream_t stream) {
+                               IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows,
+                               uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
+                               uint32_t head_dim, uint32_t page_size, bool enable_cuda_graph,
+                               uint32_t sizeof_dtype_o, cudaStream_t stream) {
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
     err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
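With the widened signature, callers now pass total_num_rows explicitly instead of the planner reading qo_indptr_h[batch_size] itself (the line removed from PrefillSplitQOKVIndptr above). A minimal caller-side sketch, assuming the header that declares PrefillPlan and PrefillPlanInfo is included; PlanWithExplicitRows and its buffer arguments are hypothetical placeholders:

// Hedged sketch: compute total_num_rows on the host and forward it.
template <typename IdType>
cudaError_t PlanWithExplicitRows(void* float_buffer, size_t float_workspace_size_in_bytes,
                                 void* int_buffer, void* page_locked_int_buffer,
                                 size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info,
                                 IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
                                 uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
                                 uint32_t page_size, bool enable_cuda_graph,
                                 uint32_t sizeof_dtype_o, cudaStream_t stream) {
  // Eager path: this is exactly what the planner used to read internally.
  // Under CUDA graph capture, pass the maximum number of rows the graph is
  // planned for (taken from the dummy/warmup inputs) instead.
  uint32_t total_num_rows = qo_indptr_h[batch_size];
  return PrefillPlan<IdType>(float_buffer, float_workspace_size_in_bytes, int_buffer,
                             page_locked_int_buffer, int_workspace_size_in_bytes, plan_info,
                             qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads,
                             num_kv_heads, head_dim, page_size, enable_cuda_graph, sizeof_dtype_o,
                             stream);
}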
@@ -618,17 +636,18 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
   uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads;
 
   // step 2: determine kv_chunk_size
-  auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, total_num_rows,
-        request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec,
-        o_indptr_vec] =
-      PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, batch_size, num_qo_heads, num_kv_heads,
-                             head_dim, page_size, max_batch_size_if_split, enable_cuda_graph);
+  auto [split_kv, total_num_tiles_q, new_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec,
+        qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] =
+      PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads,
+                             num_kv_heads, head_dim, page_size, max_batch_size_if_split,
+                             enable_cuda_graph);
   plan_info.cta_tile_q = cta_tile_q;
   plan_info.total_num_rows = total_num_rows;
 
   plan_info.enable_cuda_graph = enable_cuda_graph;
   size_t padded_batch_size =
       enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size;
+
   plan_info.padded_batch_size = padded_batch_size;
   plan_info.split_kv = split_kv;
 
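A small worked example of the padding rule above, with invented numbers: when CUDA graphs are enabled, the captured grid must cover the worst case, so the batch dimension is padded to the larger of the split budget and the q-tile count; without graphs it is exactly new_batch_size.

#include <cstdint>
// Hedged numeric example; the two inputs are made up.
constexpr uint32_t kMaxBatchSizeIfSplit = 512;  // e.g. max_grid_size / num_kv_heads
constexpr uint32_t kTotalNumTilesQ = 640;       // sum of ceil_div(packed_qo_len, cta_tile_q)
constexpr uint32_t kPaddedBatchSize =
    kTotalNumTilesQ > kMaxBatchSizeIfSplit ? kTotalNumTilesQ : kMaxBatchSizeIfSplit;
static_assert(kPaddedBatchSize == 640, "graph path pads to the larger of the two");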
@@ -679,6 +698,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
         sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "batch_prefill_merge_indptr");
     plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset(
         sizeof(bool) * padded_batch_size, 16, "batch_prefill_block_valid_mask");
+
     IdType* merge_indptr_h =
         GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.merge_indptr_offset);
     bool* block_valid_mask_h =
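The allocations above follow an offset-based sub-allocation pattern: regions are carved out of one pre-allocated workspace at 16-byte alignment, and only byte offsets are stored in plan_info so pointers can later be re-derived against either the device buffer or its page-locked host mirror. A minimal sketch of that pattern, with hypothetical names (this is not the library's actual allocator or GetPtrFromBaseOffset):

#include <cstddef>

// Hedged sketch of a bump allocator that hands out aligned offsets.
struct OffsetAllocatorSketch {
  std::size_t cursor = 0;

  // Returns the aligned offset of a fresh region of `size` bytes; a real
  // allocator would also track capacity and report overflow with a name tag.
  std::size_t aligned_alloc_offset(std::size_t size, std::size_t alignment) {
    std::size_t aligned = (cursor + alignment - 1) / alignment * alignment;
    cursor = aligned + size;
    return aligned;
  }
};

template <typename T>
T* PtrFromBaseOffsetSketch(void* base, std::size_t offset) {
  // The same offset works for any base buffer, which is why offsets rather
  // than raw pointers are kept in the plan info.
  return reinterpret_cast<T*>(static_cast<char*>(base) + offset);
}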