@@ -38,44 +38,60 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
3838 uint32_t kv_len, float sm_scale, float rope_scale,
3939 float rope_theta, cudaStream_t stream);
4040
41- template <uint32_t NUM_FRAGS_X , uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
42- QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE , bool ALLOW_FP16_QK_REDUCTION,
41+ template <uint32_t num_frags_x , uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
42+ QKVLayout KV_LAYOUT, PosEncodingMode pos_encoding_mode , bool ALLOW_FP16_QK_REDUCTION,
4343 MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
4444cudaError_t BatchPrefillWithRaggedKVCacheDispatched (
45- DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k,
46- DTypeIn* v, IdType* kv_indptr, uint8_t * custom_mask, IdType* qk_indptr, IdType* q_offset,
47- IdType* k_rope_pos_offset, DTypeOut* o, float * tmp, float * lse, uint32_t batch_size,
48- uint32_t num_qo_tiles, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale,
49- float rope_scale, float rope_theta, cudaStream_t stream = nullptr );
45+ DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices,
46+ IdType* q_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, uint8_t * custom_mask,
47+ IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, IdType* o_indptr, DTypeOut* o,
48+ DTypeOut* tmp_v, float * tmp_s, float * lse, IdType* merge_indptr, bool * block_valid_mask,
49+ IdType* kv_chunk_size_ptr, const uint32_t total_num_rows, const uint32_t num_qo_heads,
50+ const uint32_t padded_batch_size, const uint32_t num_kv_heads, const float sm_scale,
51+ const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr );
5052
51- template <PageStorage PAGE_STORAGE , uint32_t NUM_FRAGS_X , uint32_t HEAD_DIM,
52- LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOUT , PosEncodingMode POS_ENCODING_MODE ,
53+ template <PageStorage page_storage , uint32_t num_frags_x , uint32_t HEAD_DIM,
54+ LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout , PosEncodingMode pos_encoding_mode ,
5355 bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut,
5456 typename IdType>
5557cudaError_t BatchPrefillWithPagedKVCacheDispatched (
56- DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset,
57- paged_kv_t <PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t * custom_mask,
58- IdType* qk_indptr, DTypeOut* o, float * tmp, float * lse, uint32_t num_qo_tiles,
59- uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream);
58+ DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices,
59+ IdType* q_indptr, IdType* q_offset,
60+ paged_kv_t <page_storage, kv_layout, DTypeIn, IdType> paged_kv, uint8_t * custom_mask,
61+ IdType* qk_indptr, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, float * tmp_s, float * lse,
62+ IdType* merge_indptr, bool * block_valid_mask, IdType* kv_chunk_size_ptr,
63+ uint32_t total_num_rows, uint32_t num_qo_heads, uint32_t padded_batch_size, float sm_scale,
64+ float rope_scale, float rope_theta, cudaStream_t stream);
6065
6166template <PageStorage PAGE_STORAGE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
6267 QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
6368 MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
6469cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched (
65- BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr , IdType* q_offset,
70+ BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr , IdType* q_offset,
6671 paged_kv_t <PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t * custom_mask,
6772 IdType* qk_indptr, DTypeOut* o, float * lse, uint32_t num_qo_heads, float sm_scale,
6873 float rope_scale, float rope_theta, cudaStream_t stream) {
69- float * tmp = nullptr ;
70- IdType* request_indices = nullptr ;
71- IdType* tile_indices = nullptr ;
74+ DTypeOut* tmp_v = nullptr ;
75+ float * tmp_s = nullptr ;
76+ IdType *request_indices = nullptr , *qo_tile_indices = nullptr , *kv_tile_indices = nullptr ,
77+ *o_indptr = nullptr , *merge_indptr = nullptr , *kv_chunk_size_ptr = nullptr ;
78+ bool * block_valid_mask = nullptr ;
7279 uint32_t num_frags_x = 0U ;
73- uint32_t num_qo_tiles = 0U ;
80+ uint32_t padded_batch_size = 0U ;
81+ uint32_t total_num_rows = 0U ;
7482 if (handler->IsForwardStarted ()) {
83+ tmp_v = handler->GetTempV <DTypeOut>();
84+ tmp_s = handler->GetTempS ();
7585 request_indices = handler->GetRequestIndices <IdType>();
76- tile_indices = handler->GetTileIndices <IdType>();
86+ qo_tile_indices = handler->GetQOTileIndices <IdType>();
87+ kv_tile_indices = handler->GetKVTileIndices <IdType>();
88+ block_valid_mask = handler->GetBlockValidMask ();
89+ o_indptr = handler->GetOIndptr <IdType>();
90+ merge_indptr = handler->GetMergeIndptr <IdType>();
91+ kv_chunk_size_ptr = handler->GetKVChunkSizePtr <IdType>();
7792 num_frags_x = handler->GetNumFragsX ();
78- num_qo_tiles = handler->GetNumQOTiles ();
93+ padded_batch_size = handler->GetPaddedBatchSize ();
94+ total_num_rows = handler->GetTotalNumRows ();
7995 } else {
8096 std::ostringstream err_msg;
8197 err_msg << " Please call BatchPrefillHandler's BeginForward() before calling "
@@ -87,8 +103,10 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
87103 return BatchPrefillWithPagedKVCacheDispatched<
88104 PAGE_STORAGE, NUM_FRAGS_X, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE,
89105 ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>(
90- q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, custom_mask, qk_indptr, o,
91- tmp, lse, num_qo_heads, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream);
106+ q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, q_offset, paged_kv,
107+ custom_mask, qk_indptr, o_indptr, o, tmp_v, tmp_s, lse, merge_indptr, block_valid_mask,
108+ kv_chunk_size_ptr, total_num_rows, num_qo_heads, padded_batch_size, sm_scale, rope_scale,
109+ rope_theta, stream);
92110 });
93111 return cudaSuccess;
94112}
@@ -97,21 +115,32 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
97115 PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
98116 typename DTypeIn, typename DTypeOut, typename IdType>
99117cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched (
100- BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr , DTypeIn* k, DTypeIn* v,
118+ BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr , DTypeIn* k, DTypeIn* v,
101119 IdType* kv_indptr, uint8_t * custom_mask, IdType* qk_indptr, IdType* q_offset,
102- IdType* k_rope_pos_offset, DTypeOut* o, float * lse, uint32_t batch_size, uint32_t num_qo_heads,
120+ IdType* k_rope_pos_offset, DTypeOut* o, float * lse, uint32_t num_qo_heads,
103121 uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta,
104122 cudaStream_t stream) {
105- float * tmp = nullptr ;
106- IdType* request_indices = nullptr ;
107- IdType* tile_indices = nullptr ;
123+ DTypeOut* tmp_v = nullptr ;
124+ float * tmp_s = nullptr ;
125+ IdType *request_indices = nullptr , *qo_tile_indices = nullptr , *kv_tile_indices = nullptr ,
126+ *o_indptr = nullptr , *merge_indptr = nullptr , *kv_chunk_size_ptr = nullptr ;
127+ bool * block_valid_mask = nullptr ;
108128 uint32_t num_frags_x = 0U ;
109- uint32_t num_qo_tiles = 0U ;
129+ uint32_t padded_batch_size = 0U ;
130+ uint32_t total_num_rows = 0U ;
110131 if (handler->IsForwardStarted ()) {
132+ tmp_v = handler->GetTempV <DTypeOut>();
133+ tmp_s = handler->GetTempS ();
111134 request_indices = handler->GetRequestIndices <IdType>();
112- tile_indices = handler->GetTileIndices <IdType>();
135+ qo_tile_indices = handler->GetQOTileIndices <IdType>();
136+ kv_tile_indices = handler->GetKVTileIndices <IdType>();
137+ block_valid_mask = handler->GetBlockValidMask ();
138+ o_indptr = handler->GetOIndptr <IdType>();
139+ merge_indptr = handler->GetMergeIndptr <IdType>();
140+ kv_chunk_size_ptr = handler->GetKVChunkSizePtr <IdType>();
113141 num_frags_x = handler->GetNumFragsX ();
114- num_qo_tiles = handler->GetNumQOTiles ();
142+ padded_batch_size = handler->GetPaddedBatchSize ();
143+ total_num_rows = handler->GetTotalNumRows ();
115144 } else {
116145 std::ostringstream err_msg;
117146 err_msg << " Please call BatchPrefillHandler's BeginForward() before calling "
@@ -123,9 +152,10 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
123152 return BatchPrefillWithRaggedKVCacheDispatched<
124153 NUM_FRAGS_X, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE,
125154 ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>(
126- q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, custom_mask, qk_indptr,
127- q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_heads, num_qo_tiles,
128- num_kv_heads, sm_scale, rope_scale, rope_theta, stream);
155+ q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, k, v, kv_indptr,
156+ custom_mask, qk_indptr, q_offset, k_rope_pos_offset, o_indptr, o, tmp_v, tmp_s, lse,
157+ merge_indptr, block_valid_mask, kv_chunk_size_ptr, total_num_rows, num_qo_heads,
158+ padded_batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, stream);
129159 });
130160 return cudaSuccess;
131161}
0 commit comments