@@ -166,51 +166,15 @@ cudaError_t BatchDecodeWithPagedKVCache(
  * \note This wrapper function should be only called after we call BeginForward function in the
  * BatchDecodeHandler.
  */
-template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
-          PosEncodingMode pos_encoding_mode, typename DTypeIn, typename DTypeOut, typename IdType>
-cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
-    BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
-    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
-    float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) {
-  paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> new_paged_kv = paged_kv;
-  kv_partition_info_t<IdType> kv_partition_info;
-  DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
-
-  if (handler->IsForwardStarted()) {
-    if (tmp != nullptr) {
-      // create auxiliary information for cooperative kernels
-      new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
-      new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
-      new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
-      kv_partition_info.batch_size_before_partition = handler->GetBatchSizeBeforePartition();
-      kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
-      kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
-      kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
-      kv_partition_info.seq_lens_before_partition = handler->GetSeqLengthsBeforePartition<IdType>();
-    }
-  } else {
-    std::ostringstream err_msg;
-    err_msg << "Please call BatchDecodeHandler's BeginForward() before calling "
-               "BatchDecodeWithPagedKVCacheWrapper()";
-    throw std::runtime_error(err_msg.str());
-  }
-
-  return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage, kv_layout,
-                                               pos_encoding_mode, DTypeIn, DTypeOut, IdType>(
-      q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale, rope_theta,
-      stream);
-  return cudaSuccess;
-}
-
-template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
+template <PageStorage page_storage, QKVLayout KV_LAYOUT, typename DTypeIn, typename DTypeOut,
           typename IdType>
 cudaError_t BatchDecodeWithPagedKVCacheWrapper(
     BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
-    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
+    paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
     std::optional<float> maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
     float rope_theta = 1e4, cudaStream_t stream = nullptr) {
-  const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim)));
+  float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim)));
   const uint32_t num_kv_heads = paged_kv.num_heads;
   if (num_qo_heads % num_kv_heads != 0) {
     std::ostringstream err_msg;
@@ -219,18 +183,42 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(
     throw std::invalid_argument(err_msg.str());
   }
 
-  // DISPATCH_GQA_GROUP_SIZE(
-  //     num_qo_heads / num_kv_heads, GROUP_SIZE,
-  //     {DISPATCH_HEAD_DIM(
-  //         paged_kv.head_dim, HEAD_DIM,
-  //         {DISPATCH_POS_ENCODING_MODE(
-  //             pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
-  //               return BatchDecodeWithPagedKVCacheWrapperDispatched<
-  //                   page_storage, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, DTypeIn,
-  //                   DTypeOut, IdType>(handler, q, q_offset, paged_kv, o, lse, sm_scale,
-  //                   rope_scale,
-  //                   rope_theta, stream);
-  //             })})})});
+  DISPATCH_GQA_GROUP_SIZE(
+      num_qo_heads / num_kv_heads, GROUP_SIZE,
+      {DISPATCH_HEAD_DIM(
+          paged_kv.head_dim, HEAD_DIM,
+          {DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, {
+            paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> new_paged_kv = paged_kv;
+            kv_partition_info_t<IdType> kv_partition_info;
+            DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
+
+            if (handler->IsForwardStarted()) {
+              if (tmp != nullptr) {
+                // create auxiliary information for cooperative kernels
+                new_paged_kv.batch_size = handler->GetBatchSizeAfterPartition();
+                new_paged_kv.indptr = handler->GetNewIndPtr<IdType>();
+                new_paged_kv.last_page_len = handler->GetNewLastPageLen<IdType>();
+                kv_partition_info.batch_size_before_partition =
+                    handler->GetBatchSizeBeforePartition();
+                kv_partition_info.chunk_indptr = handler->GetChunkIndPtr<IdType>();
+                kv_partition_info.batch_idx_map = handler->GetBatchIdxMap<IdType>();
+                kv_partition_info.chunk_start_pos = handler->GetChunkStartPos<IdType>();
+                kv_partition_info.seq_lens_before_partition =
+                    handler->GetSeqLengthsBeforePartition<IdType>();
+              }
+            } else {
+              std::ostringstream err_msg;
+              err_msg << "Please call BatchDecodeHandler's BeginForward() before calling "
+                         "BatchDecodeWithPagedKVCacheWrapper()";
+              throw std::runtime_error(err_msg.str());
+            }
+
+            return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
+                                                         KV_LAYOUT, POS_ENCODING_MODE, DTypeIn,
+                                                         DTypeOut, IdType>(
+                q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale,
+                rope_theta, stream);
+          })})});
   return cudaSuccess;
 }
 
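The change folds the body of the old BatchDecodeWithPagedKVCacheWrapperDispatched directly into the DISPATCH_* macro nest, so the runtime parameters (GQA group size, head dimension, positional-encoding mode) are mapped onto compile-time template parameters in one place before BatchDecodeWithPagedKVCacheDispatched is instantiated. Below is a minimal, self-contained sketch of that dispatch pattern in plain C++; the names TOY_DISPATCH_HEAD_DIM, ToyKernelDispatched, and ToyWrapper are hypothetical stand-ins for illustration only, not FlashInfer's actual macros or kernels.

// Sketch of runtime-to-compile-time dispatch, assuming toy names throughout.
#include <cstdint>
#include <cstdio>
#include <stdexcept>

// Each supported head_dim becomes a separate template instantiation; in a real
// kernel HEAD_DIM would size shared memory, unroll loops, etc.
template <uint32_t HEAD_DIM>
int ToyKernelDispatched(uint32_t seq_len) {
  return static_cast<int>(HEAD_DIM * seq_len);
}

// Toy stand-in for a DISPATCH_HEAD_DIM-style macro: expands `body` once per
// supported case, with HEAD_DIM bound to a constexpr value inside each branch.
#define TOY_DISPATCH_HEAD_DIM(runtime_head_dim, HEAD_DIM, body) \
  switch (runtime_head_dim) {                                   \
    case 64: {                                                  \
      constexpr uint32_t HEAD_DIM = 64;                         \
      body                                                      \
    } break;                                                    \
    case 128: {                                                 \
      constexpr uint32_t HEAD_DIM = 128;                        \
      body                                                      \
    } break;                                                    \
    default:                                                    \
      throw std::invalid_argument("unsupported head_dim");      \
  }

// Wrapper taking runtime values and dispatching to the templated path,
// mirroring how the wrapper above dispatches to *Dispatched inside the macros.
int ToyWrapper(uint32_t head_dim, uint32_t seq_len) {
  TOY_DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
    return ToyKernelDispatched<HEAD_DIM>(seq_len);
  });
  return 0;  // unreachable for supported head_dim values
}

int main() {
  std::printf("%d\n", ToyWrapper(128, 4));  // dispatches to HEAD_DIM = 128
  return 0;
}

In the real wrapper the same idea is applied one macro per runtime parameter (group size, head dimension, positional-encoding mode), which is why the handler/partition bookkeeping now lives inside the innermost macro body rather than in a separate dispatched helper.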