@@ -267,7 +267,7 @@ class BatchPrefillHandler {
 template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
           typename IdType>
 cudaError_t BatchDecodeWithPagedKVCacheWrapper(
-    BatchDecodeHandler* handler, DTypeIn* q,
+    BatchDecodeHandler* handler, DTypeIn* q, IdType* q_rope_position,
     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f,
     float rope_theta = 1e4, cudaStream_t stream = nullptr) {
@@ -293,15 +293,15 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(
     throw std::runtime_error(err_msg.str());
   }
   return BatchDecodeWithPagedKVCache<page_storage, kv_layout, DTypeIn, DTypeOut, IdType>(
-      q, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale,
-      rope_theta, stream);
+      q, q_rope_position, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode,
+      rope_scale, rope_theta, stream);
 }

 template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
           RotaryMode ROTARY_MODE, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn,
           typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     uint32_t num_qo_heads, float rope_scale = 1.f, float rope_theta = 1e4,
     cudaStream_t stream = nullptr) {
@@ -328,14 +328,14 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
       return BatchPrefillWithPagedKVCacheFallbackDispatched<
           page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
           ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-          q, request_indices, tile_indices, qo_indptr, paged_kv, o, tmp, lse, num_qo_tiles,
-          rope_scale, rope_theta, stream);
+          q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse,
+          num_qo_tiles, rope_scale, rope_theta, stream);
     } else {
       return BatchPrefillWithPagedKVCacheDispatched<
           page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
           ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-          q, request_indices, tile_indices, qo_indptr, paged_kv, o, tmp, lse, num_qo_tiles,
-          rope_scale, rope_theta, stream);
+          q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse,
+          num_qo_tiles, rope_scale, rope_theta, stream);
     }
   })});
   return cudaSuccess;
@@ -344,7 +344,7 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
 template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
           typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheWrapper(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
     bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4,
@@ -363,8 +363,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper(
         return BatchPrefillWithPagedKVCacheWrapperDispatched<
             page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
             ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-            handler, q, qo_indptr, paged_kv, o, lse, num_qo_heads,
-            rope_scale, rope_theta, stream);
+            handler, q, qo_indptr, q_rope_position, paged_kv, o, lse,
+            num_qo_heads, rope_scale, rope_theta, stream);
       })})})})});
   return cudaSuccess;
 }
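
A note on the new q_rope_position argument: the sketch below is not part of this patch; it shows one plausible way a caller could fill the array on the host for a chunked-prefill batch, assuming q_rope_position holds the absolute RoPE position of every query token, qo_indptr is the usual CSR-style offset array, and past_kv_len (a hypothetical input) is the number of tokens each request already has in the KV cache. The result would be copied to the device before calling BatchPrefillWithPagedKVCacheWrapper.

#include <cstdint>
#include <vector>

// Illustrative host-side helper, not from this patch. Builds one RoPE position id per
// query token: past_kv_len[r] (tokens already cached for request r, an assumed input)
// plus the token's offset inside the current chunk.
std::vector<int32_t> BuildQRopePosition(const std::vector<int32_t>& qo_indptr,
                                        const std::vector<int32_t>& past_kv_len) {
  std::vector<int32_t> q_rope_position(qo_indptr.back());
  for (size_t r = 0; r + 1 < qo_indptr.size(); ++r) {
    for (int32_t i = qo_indptr[r]; i < qo_indptr[r + 1]; ++i) {
      // Query token (i - qo_indptr[r]) of request r sits at this absolute position.
      q_rope_position[i] = past_kv_len[r] + (i - qo_indptr[r]);
    }
  }
  return q_rope_position;  // copy to the GPU (e.g. cudaMemcpy) before launching the wrapper
}
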
@@ -374,9 +374,9 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT, RotaryMod
           typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
     BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
-    IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size,
-    const uint32_t num_kv_heads, const float rope_scale = 1.f, const float rope_theta = 1e4,
-    cudaStream_t stream = nullptr) {
+    IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse,
+    const uint32_t batch_size, const uint32_t num_kv_heads, const float rope_scale = 1.f,
+    const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
   float* tmp = nullptr;
   IdType* request_indices = nullptr;
   IdType* tile_indices = nullptr;
@@ -398,18 +398,19 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
     return BatchPrefillWithRaggedKVCacheDispatched<NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, KV_LAYOUT,
                                                    ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL,
                                                    DTypeIn, DTypeOut, IdType>(
-        q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, o, tmp, lse, batch_size,
-        num_qo_tiles, num_kv_heads, rope_scale, rope_theta, stream);
+        q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_rope_position,
+        k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, rope_scale,
+        rope_theta, stream);
   });
   return cudaSuccess;
 }

 template <typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
     BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
-    IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size,
-    const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t head_dim,
-    bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
+    IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse,
+    const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads,
+    const uint32_t head_dim, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
     bool allow_fp16_qk_reduction = false, const float rope_scale = 1.f,
     const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
   constexpr QKVLayout KV_LAYOUT = QKVLayout::kNHD;
@@ -425,8 +426,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
         return BatchPrefillWithRaggedKVCacheWrapperDispatched<
             GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE,
             ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-            handler, q, qo_indptr, k, v, kv_indptr, o, lse, batch_size,
-            num_kv_heads, rope_scale, rope_theta, stream);
+            handler, q, qo_indptr, k, v, kv_indptr, q_rope_position,
+            k_rope_pos_offset, o, lse, batch_size, num_kv_heads,
+            rope_scale, rope_theta, stream);
       })})})})});
   return cudaSuccess;
 }
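
To make the updated ragged-cache signature concrete, here is a minimal forwarding sketch rather than anything from the patch itself: the include path, the flashinfer namespace, and the use of RotaryMode::kLlama are assumptions, and the argument order simply mirrors the declaration in the hunk above, with q_rope_position and k_rope_pos_offset expected to be device arrays.

#include <flashinfer/handler.cuh>  // assumed include path; adjust to the actual header

// Sketch only: forwards already-allocated device buffers to the updated wrapper.
template <typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t RunRaggedPrefill(flashinfer::BatchPrefillHandler* handler, DTypeIn* q,
                             IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr,
                             IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o,
                             float* lse, uint32_t batch_size, uint32_t num_qo_heads,
                             uint32_t num_kv_heads, uint32_t head_dim, cudaStream_t stream) {
  // q_rope_position: one RoPE position id per query token (assumed semantics).
  // k_rope_pos_offset: starting key position per request (assumed semantics).
  return flashinfer::BatchPrefillWithRaggedKVCacheWrapper<DTypeIn, DTypeOut, IdType>(
      handler, q, qo_indptr, k, v, kv_indptr, q_rope_position, k_rope_pos_offset, o, lse,
      batch_size, num_qo_heads, num_kv_heads, head_dim,
      /*causal=*/true, flashinfer::RotaryMode::kLlama,
      /*allow_fp16_qk_reduction=*/false, /*rope_scale=*/1.f, /*rope_theta=*/1e4, stream);
}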