Commit 5b189f5

Support RoPE position info in batch prefill/decode kernels
This PR adds q/k position information to the batch prefill/decode kernels. Specifically, the kernels now accept two additional arrays:

* `q_rope_position`, with shape `(total_q_len,)`, giving the in-sequence position of each entry in the input q.
* `k_rope_pos_offset`, with shape `(num_sequence,)`, giving the start position of each sequence in k.

These two arrays allow RoPE to be computed on the fly in multi-level cases. Tests `test_batch_prefill` and `test_batch_decode` pass. Performance has not been validated yet; per discussion with Zihao, this change is unlikely to incur a significant perf regression.
1 parent 08aee43 commit 5b189f5
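
To make the semantics of the two new arrays concrete, here is a minimal host-side sketch (our illustration, not code from this commit) that fills them for a batch where request `i` already has `prefix_len[i]` tokens cached in k and appends `append_len[i]` new query tokens; `prefix_len` and `append_len` are hypothetical names:

```cpp
// Illustrative sketch only: fill q_rope_position and k_rope_pos_offset on the
// host, assuming request i has prefix_len[i] keys already cached and appends
// append_len[i] new query tokens.
#include <cstdint>
#include <vector>

void fill_rope_position_arrays(const std::vector<int32_t>& prefix_len,
                               const std::vector<int32_t>& append_len,
                               std::vector<int32_t>* q_rope_position,    // (total_q_len,)
                               std::vector<int32_t>* k_rope_pos_offset)  // (num_sequence,)
{
  q_rope_position->clear();
  k_rope_pos_offset->clear();
  for (size_t i = 0; i < prefix_len.size(); ++i) {
    // The keys held at this level of the cache start at position prefix_len[i]
    // of the full sequence (0 if this level holds the whole sequence).
    k_rope_pos_offset->push_back(prefix_len[i]);
    // Each query token carries its absolute in-sequence position.
    for (int32_t j = 0; j < append_len[i]; ++j) {
      q_rope_position->push_back(prefix_len[i] + j);
    }
  }
}
```

With both arrays filled this way, q and k are rotated by absolute positions, which stays consistent even when a request's KV cache is split across levels (e.g. a shared prefix plus per-request suffixes).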

File tree

10 files changed: +286 −127 lines


include/flashinfer/decode.cuh

Lines changed: 23 additions & 26 deletions
```diff
@@ -497,7 +497,8 @@ template <bool partition_kv, RotaryMode rotary_mode, uint32_t num_stages_smem,
           PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
           typename IdType>
 __global__ void BatchDecodeWithPagedKVCacheKernel(
-    DTypeIn* __restrict__ q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
+    DTypeIn* __restrict__ q, IdType* __restrict__ q_rope_position,
+    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
     kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
     DTypeOut* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float rope_rcp_scale,
     float rope_rcp_theta) {
@@ -520,6 +521,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
           : 0;
   const uint32_t seq_len =
       partition_kv ? kv_partition_info.seq_lens_before_partition[batch_idx] : kv_chunk_len;
+  const uint32_t mapped_batch_idx =
+      partition_kv ? kv_partition_info.batch_idx_map[batch_idx] : batch_idx;
 
   extern __shared__ uint8_t smem[];
   DTypeIn* k_smem = (DTypeIn*)smem;
@@ -541,23 +544,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
                        float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
     }
     // apply rotary embedding to q matrix
-    if constexpr (partition_kv) {
-      q_vec = vec_apply_llama_rope<vec_size, bdx>(
-          q + (kv_partition_info.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * head_dim,
-          freq, seq_len - 1);
-    } else {
-      q_vec = vec_apply_llama_rope<vec_size, bdx>(
-          q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, seq_len - 1);
-    }
+    q_vec = vec_apply_llama_rope<vec_size, bdx>(
+        q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq,
+        q_rope_position == nullptr ? (seq_len - 1) : q_rope_position[mapped_batch_idx]);
   } else {
     // do not apply rotary embedding to q matrix
-    if constexpr (partition_kv) {
-      q_vec.cast_load(
-          q + (kv_partition_info.batch_idx_map[batch_idx] * num_qo_heads + qo_head_idx) * head_dim +
-          tx * vec_size);
-    } else {
-      q_vec.cast_load(q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
-    }
+    q_vec.cast_load(q + (mapped_batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
   }
   block.sync();
 
@@ -627,7 +619,9 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
     block.sync();
     compute_qk<rotary_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
         k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
-        freq, cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
+        freq,
+        (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
+            cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
         iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, sm_scale, s, st);
     block.sync();
 
@@ -1120,7 +1114,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation(
 template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
           RotaryMode ROTARY_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchDecodeWithPagedKVCacheDispatched(
-    DTypeIn* q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
+    DTypeIn* q, IdType* q_rope_position,
+    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
     kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse,
     float rope_scale, float rope_theta, cudaStream_t stream) {
   const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM));
@@ -1153,6 +1148,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
     FLASHINFER_CUDA_CALL(
         cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     void* args[] = {(void*)&q,
+                    (void*)&q_rope_position,
                     (void*)&paged_kv,
                     (void*)&kv_partition_info,
                     (void*)&o,
@@ -1171,6 +1167,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
     FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
         partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     void* args[] = {(void*)&q,
+                    (void*)&q_rope_position,
                     (void*)&paged_kv,
                     (void*)&kv_partition_info,
                     (void*)&o,
@@ -1212,7 +1209,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
 template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
           typename IdType>
 cudaError_t BatchDecodeWithPagedKVCache(
-    DTypeIn* q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
+    DTypeIn* q, IdType* q_rope_position,
+    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
     kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse,
     uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f,
     float rope_theta = 1e4, cudaStream_t stream = nullptr) {
@@ -1228,13 +1226,12 @@ cudaError_t BatchDecodeWithPagedKVCache(
 
   SWITCH_GQA_GROUP_SIZE(
       num_qo_heads / num_kv_heads, GROUP_SIZE,
-      {SWITCH_HEAD_DIM(
-          head_dim, HEAD_DIM, {SWITCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, {
-            return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
-                                                         kv_layout, ROTARY_MODE, DTypeIn, DTypeOut,
-                                                         IdType>(
-                q, paged_kv, kv_partition_info, o, tmp, lse, rope_scale, rope_theta, stream);
-          })})});
+      {SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {SWITCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, {
+        return BatchDecodeWithPagedKVCacheDispatched<
+            GROUP_SIZE, HEAD_DIM, page_storage, kv_layout, ROTARY_MODE, DTypeIn,
+            DTypeOut, IdType>(q, q_rope_position, paged_kv, kv_partition_info, o,
+                              tmp, lse, rope_scale, rope_theta, stream);
+      })})});
 
   return cudaSuccess;
 }
```
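
The net effect of the decode-side change is a pair of nullptr fallbacks: with `q_rope_position == nullptr` the query is rotated at `seq_len - 1` as before, and with `paged_kv.rope_pos_offset == nullptr` key positions start at 0, so existing callers keep the old behavior. A standalone distillation of that selection logic (our sketch, not kernel code):

```cpp
// Distilled from the diff above, for illustration only: passing nullptr for
// the new arrays reproduces the previous position numbering.
#include <cstdint>

// Position used to rotate the query vector in the decode kernel.
uint32_t q_rotary_position(const int32_t* q_rope_position, uint32_t mapped_batch_idx,
                           uint32_t seq_len) {
  // Old behavior: the single decode query sits at the last position, seq_len - 1.
  return q_rope_position == nullptr
             ? (seq_len - 1)
             : static_cast<uint32_t>(q_rope_position[mapped_batch_idx]);
}

// Base position added to each key's in-chunk index before applying RoPE.
uint32_t k_rotary_base(const int32_t* rope_pos_offset, uint32_t mapped_batch_idx) {
  // Old behavior: keys are numbered from 0 within the sequence.
  return rope_pos_offset == nullptr
             ? 0u
             : static_cast<uint32_t>(rope_pos_offset[mapped_batch_idx]);
}
```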

include/flashinfer/handler.cuh

Lines changed: 23 additions & 21 deletions
```diff
@@ -267,7 +267,7 @@ class BatchPrefillHandler {
 template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
           typename IdType>
 cudaError_t BatchDecodeWithPagedKVCacheWrapper(
-    BatchDecodeHandler* handler, DTypeIn* q,
+    BatchDecodeHandler* handler, DTypeIn* q, IdType* q_rope_position,
     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f,
     float rope_theta = 1e4, cudaStream_t stream = nullptr) {
@@ -293,15 +293,15 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(
     throw std::runtime_error(err_msg.str());
   }
   return BatchDecodeWithPagedKVCache<page_storage, kv_layout, DTypeIn, DTypeOut, IdType>(
-      q, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale,
-      rope_theta, stream);
+      q, q_rope_position, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode,
+      rope_scale, rope_theta, stream);
 }
 
 template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
           RotaryMode ROTARY_MODE, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn,
           typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     uint32_t num_qo_heads, float rope_scale = 1.f, float rope_theta = 1e4,
     cudaStream_t stream = nullptr) {
@@ -328,14 +328,14 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
       return BatchPrefillWithPagedKVCacheFallbackDispatched<
           page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
           ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-          q, request_indices, tile_indices, qo_indptr, paged_kv, o, tmp, lse, num_qo_tiles,
-          rope_scale, rope_theta, stream);
+          q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse,
+          num_qo_tiles, rope_scale, rope_theta, stream);
     } else {
       return BatchPrefillWithPagedKVCacheDispatched<
           page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
           ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-          q, request_indices, tile_indices, qo_indptr, paged_kv, o, tmp, lse, num_qo_tiles,
-          rope_scale, rope_theta, stream);
+          q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse,
+          num_qo_tiles, rope_scale, rope_theta, stream);
     }
   })});
   return cudaSuccess;
@@ -344,7 +344,7 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
 template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
           typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheWrapper(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
     paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
     uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
     bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4,
@@ -363,8 +363,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper(
         return BatchPrefillWithPagedKVCacheWrapperDispatched<
             page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
             ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-            handler, q, qo_indptr, paged_kv, o, lse, num_qo_heads,
-            rope_scale, rope_theta, stream);
+            handler, q, qo_indptr, q_rope_position, paged_kv, o, lse,
+            num_qo_heads, rope_scale, rope_theta, stream);
       })})})})});
   return cudaSuccess;
 }
@@ -374,9 +374,9 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT, RotaryMod
           typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
     BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
-    IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size,
-    const uint32_t num_kv_heads, const float rope_scale = 1.f, const float rope_theta = 1e4,
-    cudaStream_t stream = nullptr) {
+    IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse,
+    const uint32_t batch_size, const uint32_t num_kv_heads, const float rope_scale = 1.f,
+    const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
   float* tmp = nullptr;
   IdType* request_indices = nullptr;
   IdType* tile_indices = nullptr;
@@ -398,18 +398,19 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
     return BatchPrefillWithRaggedKVCacheDispatched<NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, KV_LAYOUT,
                                                    ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL,
                                                    DTypeIn, DTypeOut, IdType>(
-        q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, o, tmp, lse, batch_size,
-        num_qo_tiles, num_kv_heads, rope_scale, rope_theta, stream);
+        q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_rope_position,
+        k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, rope_scale,
+        rope_theta, stream);
   });
   return cudaSuccess;
 }
 
 template <typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
     BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
-    IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size,
-    const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t head_dim,
-    bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
+    IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse,
+    const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads,
+    const uint32_t head_dim, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
     bool allow_fp16_qk_reduction = false, const float rope_scale = 1.f,
     const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
   constexpr QKVLayout KV_LAYOUT = QKVLayout::kNHD;
@@ -425,8 +426,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
         return BatchPrefillWithRaggedKVCacheWrapperDispatched<
             GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE,
             ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
-            handler, q, qo_indptr, k, v, kv_indptr, o, lse, batch_size,
-            num_kv_heads, rope_scale, rope_theta, stream);
+            handler, q, qo_indptr, k, v, kv_indptr, q_rope_position,
+            k_rope_pos_offset, o, lse, batch_size, num_kv_heads,
+            rope_scale, rope_theta, stream);
       })})})})});
   return cudaSuccess;
 }
```
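
For callers, the change is two extra arguments threaded through each wrapper. A hypothetical call-site sketch for the ragged-KV prefill wrapper (the helper name, the `flashinfer::` qualification, and `RotaryMode::kLlama` are our assumptions; buffers are taken as already-allocated device pointers):

```cpp
// Hypothetical call site; not code from this commit.
#include <cuda_fp16.h>
#include <flashinfer/handler.cuh>

cudaError_t run_ragged_prefill(flashinfer::BatchPrefillHandler* handler, half* q,
                               int32_t* qo_indptr, half* k, half* v, int32_t* kv_indptr,
                               int32_t* q_rope_position, int32_t* k_rope_pos_offset, half* o,
                               uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
                               uint32_t head_dim, cudaStream_t stream) {
  // q_rope_position: (total_q_len,) absolute position of each query token.
  // k_rope_pos_offset: (num_sequence,) start position of each sequence's keys.
  return flashinfer::BatchPrefillWithRaggedKVCacheWrapper<half, half, int32_t>(
      handler, q, qo_indptr, k, v, kv_indptr, q_rope_position, k_rope_pos_offset, o,
      /*lse=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim,
      /*causal=*/true, flashinfer::RotaryMode::kLlama,
      /*allow_fp16_qk_reduction=*/false, /*rope_scale=*/1.f, /*rope_theta=*/1e4, stream);
}
```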

include/flashinfer/page.cuh

Lines changed: 15 additions & 5 deletions
```diff
@@ -88,6 +88,8 @@ struct paged_kv_t {
   IdType* indptr;
   // [batch_size] The offset of the last page for each request in the batch
   IdType* last_page_len;
+  // [batch_size] The start position of each request in the batch.
+  IdType* rope_pos_offset;
 
   /*!
    * \brief Construct an empty paged key-value cache
@@ -101,7 +103,8 @@ struct paged_kv_t {
         indices(nullptr),
         ptrs(nullptr),
         indptr(nullptr),
-        last_page_len(nullptr) {}
+        last_page_len(nullptr),
+        rope_pos_offset(nullptr) {}
 
   /*!
    * \brief Construct a paged key-value cache
@@ -113,20 +116,23 @@ struct paged_kv_t {
    * \param indices The page indices array
    * \param indptr The page indptr array
    * \param last_page_len The offset of the last page for each request in the batch
+   * \param rope_pos_offset The start position of each request in the batch.
    * \note This constructor should only be used when page_storage == kIndices
    */
   __host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size,
                                                  uint32_t head_dim, uint32_t batch_size,
                                                  DType* data, IdType* indices, IdType* indptr,
-                                                 IdType* last_page_len)
+                                                 IdType* last_page_len,
+                                                 IdType* rope_pos_offset = nullptr)
       : num_heads(num_heads),
         page_size(page_size),
         head_dim(head_dim),
         batch_size(batch_size),
         data(data),
         indices(indices),
         indptr(indptr),
-        last_page_len(last_page_len) {}
+        last_page_len(last_page_len),
+        rope_pos_offset(rope_pos_offset) {}
 
   /*!
    * \brief Construct a paged key-value cache
@@ -137,18 +143,22 @@ struct paged_kv_t {
    * \param ptrs The array of pointers to each active page
    * \param indptr The page indptr array
    * \param last_page_len The offset of the last page for each request in the batch
+   * \param rope_pos_offset The start position of each request in the batch.
    * \note This constructor should only be used when page_storage == kIndices
    */
   __host__ __device__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size,
                                                  uint32_t head_dim, uint32_t batch_size,
                                                  DType** ptrs, IdType* indptr,
-                                                 IdType* last_page_len)
+                                                 IdType* last_page_len,
+                                                 IdType* rope_pos_offset = nullptr)
       : num_heads(num_heads),
         page_size(page_size),
         head_dim(head_dim),
         batch_size(batch_size),
         ptrs(ptrs),
-        indptr(indptr) {}
+        indptr(indptr),
+        last_page_len(last_page_len),
+        rope_pos_offset(rope_pos_offset) {}
 
   /*!
    * \brief Compute the offset of k element in the allocated buffer.
```
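
Since `rope_pos_offset` is a new trailing parameter defaulting to `nullptr`, existing `paged_kv_t` construction sites compile unchanged; note also that the pointer-based constructor previously left `last_page_len` out of its initializer list, which this diff fixes in passing. A hypothetical construction sketch (the helper name is ours; `PageStorage::kIndices` and `QKVLayout::kNHD` are taken from the surrounding code):

```cpp
// Hypothetical helper wrapping pre-allocated device buffers; not code from
// this commit. rope_pos_offset has shape (batch_size,) and may be nullptr to
// keep the old zero-offset key numbering.
#include <cuda_fp16.h>
#include <flashinfer/page.cuh>

using flashinfer::PageStorage;
using flashinfer::QKVLayout;

flashinfer::paged_kv_t<PageStorage::kIndices, QKVLayout::kNHD, half, int32_t> make_paged_kv(
    uint32_t num_kv_heads, uint32_t page_size, uint32_t head_dim, uint32_t batch_size,
    half* data, int32_t* indices, int32_t* indptr, int32_t* last_page_len,
    int32_t* rope_pos_offset) {
  return {num_kv_heads, page_size, head_dim,     batch_size,
          data,         indices,   indptr,       last_page_len, rope_pos_offset};
}
```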
