2424 template cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched< \
2525 PageStorage::kIndices , LAYOUT, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, \
2626 CAUSAL, T, T, int32_t >(BatchPrefillHandler * handler, T* q, int32_t * qo_indptr, \
27+ int32_t * q_rope_position, \
2728 paged_kv_t <PageStorage::kIndices , LAYOUT, T, int32_t > paged_kv, T* o, \
2829 float * lse, float rope_scale, float rope_theta, cudaStream_t stream); \
2930 }
3031
31- #define INST_BatchPrefillRaggedWrapper (T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \
32- LAYOUT, ROTARY_MODE) \
33- namespace flashinfer { \
34- template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched< \
35- GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t >( \
36- BatchPrefillHandler * handler, T* q, int32_t * qo_indptr, T* k, T* v, int32_t * kv_indptr, \
37- T* o, float * lse, uint32_t batch_size, uint32_t num_kv_heads, float rope_scale, \
38- float rope_theta, cudaStream_t stream); \
32+ #define INST_BatchPrefillRaggedWrapper (T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \
33+ LAYOUT, ROTARY_MODE) \
34+ namespace flashinfer { \
35+ template cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched< \
36+ GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t >( \
37+ BatchPrefillHandler * handler, T* q, int32_t * qo_indptr, T* k, T* v, int32_t * kv_indptr, \
38+ int32_t * q_rope_position, int32_t * k_rope_pos_offset, T* o, float * lse, uint32_t batch_size, \
39+ uint32_t num_kv_heads, float rope_scale, float rope_theta, cudaStream_t stream); \
3940 }
4041
4142#define INST_SinglePrefill (T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, LAYOUT, \
@@ -56,15 +57,15 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT, RotaryMod
5657 typename IdType>
5758cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched (
5859 BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
59- IdType* kv_indptr, DTypeOut* o, float * lse, const uint32_t batch_size ,
60- const uint32_t num_kv_heads , const float rope_scale , const float rope_theta ,
61- cudaStream_t stream);
60+ IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float * lse ,
61+ const uint32_t batch_size , const uint32_t num_kv_heads , const float rope_scale ,
62+ const float rope_theta, cudaStream_t stream);
6263
6364template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
6465 RotaryMode ROTARY_MODE, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn,
6566 typename DTypeOut, typename IdType>
6667cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched (
67- BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
68+ BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
6869 paged_kv_t <page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float * lse,
6970 float rope_scale, float rope_theta, cudaStream_t stream);
7071
0 commit comments