 #endif
 #include <cuda_runtime.h>

-#include <optional>
-#include <tuple>
-
 #include "../cp_async.cuh"
 #include "../fastdiv.cuh"
 #include "../layout.cuh"
@@ -175,65 +172,41 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* smem_offset, T
   *smem_offset -= num_frags_z * 16 * channel_size_128b_in;
 }

-template <bool produce_v, uint32_t page_size, uint32_t num_warps, uint32_t num_frags_y,
-          uint32_t num_frags_z, PageStorage page_storage, QKVLayout kv_layout, typename DType,
-          typename IdType>
+template <bool produce_v, uint32_t num_warps, uint32_t num_frags_y, uint32_t num_frags_z,
+          PageStorage page_storage, QKVLayout kv_layout, typename DType, typename IdType>
 __device__ __forceinline__ void page_produce_kv(
     smem_t smem, uint32_t* smem_offset,
     paged_kv_t<page_storage, kv_layout, DType, IdType>& paged_kv, const uint32_t kv_idx_base,
-    const uint32_t page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
+    const uint32_t packed_page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
   constexpr SharedMemFillMode fill_mode =
       produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
   constexpr uint32_t head_dim = num_frags_y * 16;
   constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b<DType>();
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   const uint32_t kv_head_idx = blockIdx.z;
   uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8;
-  if constexpr (page_size % 4 == 0) {
-#pragma unroll
-    for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
-      const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4) / page_size;
-      const uint32_t entry_idx = (4 * num_warps * i + ty * 4) % page_size + tx / 8;
-      DType* gptr =
-          produce_v
-              ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                              (tx % 8) * num_elems_per_128b<DType>(), last_indptr)
-              : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
-                                              (tx % 8) * num_elems_per_128b<DType>(), last_indptr);
-#pragma unroll
-      for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
-        smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
-        *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
-        gptr += 8 * num_elems_per_128b<DType>();
-      }
-      kv_idx += num_warps * 4;
-      *smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
-                     2 * num_frags_y;
-    }
-    *smem_offset -= num_frags_z * 16 * channel_size_128b_in;
-  } else {
 #pragma unroll
-    for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
-      const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4 + tx / 8) / page_size;
-      const uint32_t entry_idx = (4 * num_warps * i + ty * 4 + tx / 8) % page_size;
-      DType* gptr =
-          produce_v
-              ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                              (tx % 8) * num_elems_per_128b<DType>(), last_indptr)
-              : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
-                                              (tx % 8) * num_elems_per_128b<DType>(), last_indptr);
-#pragma unroll
-      for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
-        smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
-        *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
-        gptr += 8 * num_elems_per_128b<DType>();
-      }
-      kv_idx += num_warps * 4;
-      *smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
-                     2 * num_frags_y;
+  for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
+    uint32_t page_iter, entry_idx;
+    paged_kv.page_size.divmod(packed_page_iter_base + ty * 4 + tx / 8 + 4 * num_warps * i,
+                              page_iter, entry_idx);
+    DType* gptr =
+        produce_v
+            ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
+                                            (tx % 8) * num_elems_per_128b<DType>(), last_indptr)
+            : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
+                                            (tx % 8) * num_elems_per_128b<DType>(), last_indptr);
+#pragma unroll
+    for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
+      smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
+      *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
+      gptr += 8 * num_elems_per_128b<DType>();
     }
-    *smem_offset -= num_frags_z * 16 * channel_size_128b_in;
+    kv_idx += num_warps * 4;
+    *smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
+                   2 * num_frags_y;
   }
+  *smem_offset -= num_frags_z * 16 * channel_size_128b_in;
 }

 template <uint32_t num_frags_y>
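Note on the hunk above: the page iterator and the in-page entry are now tracked as a single packed KV row counter, and `paged_kv.page_size.divmod(...)` splits it back into `(page_iter, entry_idx)` at run time, so `page_size` no longer has to be a compile-time template parameter. A minimal host-side sketch of that addressing (illustrative only; `unpack_kv_position` is a hypothetical helper, and plain `/` and `%` stand in for the fastdiv-based `divmod`):

```cpp
// Illustrative sketch only: unpack_kv_position is a hypothetical helper, and
// plain '/' and '%' stand in for the fastdiv-based paged_kv.page_size.divmod.
#include <cstdint>
#include <cstdio>

// Split a packed KV row index into (page index, row within the page).
void unpack_kv_position(uint32_t packed, uint32_t page_size, uint32_t& page_iter,
                        uint32_t& entry_idx) {
  page_iter = packed / page_size;  // index into the request's page list
  entry_idx = packed % page_size;  // row inside that page
}

int main() {
  // Example: the request's pages start at indptr[request_idx] == 3 and page_size == 5,
  // so packed_page_iter_base == 3 * 5 == 15.
  const uint32_t page_size = 5;
  const uint32_t packed_page_iter_base = 3 * page_size;
  for (uint32_t offset = 0; offset < 8; ++offset) {
    uint32_t page_iter, entry_idx;
    unpack_kv_position(packed_page_iter_base + offset, page_size, page_iter, entry_idx);
    std::printf("row offset %u -> page_iter %u, entry_idx %u\n", offset, page_iter, entry_idx);
  }
  return 0;
}
```

With `packed_page_iter_base = indptr[request_idx] * page_size`, offset 0 maps to the first entry of the request's first page, exactly as the per-page arithmetic in the two removed branches did.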
@@ -1342,10 +1315,10 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel(
   }
 }

-template <LogitsPostHook logits_post_hook, uint32_t page_size, MaskMode mask_mode,
-          PosEncodingMode pos_encoding_mode, uint32_t num_frags_x, uint32_t num_frags_y,
-          uint32_t num_frags_z, uint32_t num_warps, PageStorage page_storage, QKVLayout kv_layout,
-          typename DTypeIn, typename DTypeQKAccum, typename DTypeOut, typename IdType>
+template <LogitsPostHook logits_post_hook, MaskMode mask_mode, PosEncodingMode pos_encoding_mode,
+          uint32_t num_frags_x, uint32_t num_frags_y, uint32_t num_frags_z, uint32_t num_warps,
+          PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeQKAccum,
+          typename DTypeOut, typename IdType>
 __global__ void BatchPrefillWithPagedKVCacheKernel(
     IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices,
     DTypeIn* __restrict__ q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
@@ -1448,12 +1421,12 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
       smem_t::get_permuted_offset<channel_size_128b_in>(ty * 4 + tx / 8, tx % 8);
   const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];

-  uint32_t page_iter_base = paged_kv.indptr[request_idx];
-  page_produce_kv<false, page_size, num_warps, num_frags_y, num_frags_z>(
-      k_smem, &kv_smem_offset_w, paged_kv, 0, page_iter_base, kv_len, last_indptr);
+  uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size;
+  page_produce_kv<false, num_warps, num_frags_y, num_frags_z>(
+      k_smem, &kv_smem_offset_w, paged_kv, 0, packed_page_iter_base, kv_len, last_indptr);
   cp_async::commit_group();
-  page_produce_kv<true, page_size, num_warps, num_frags_y, num_frags_z>(
-      v_smem, &kv_smem_offset_w, paged_kv, 0, page_iter_base, kv_len, last_indptr);
+  page_produce_kv<true, num_warps, num_frags_y, num_frags_z>(
+      v_smem, &kv_smem_offset_w, paged_kv, 0, packed_page_iter_base, kv_len, last_indptr);
   cp_async::commit_group();

   const uint32_t num_iterations = ceil_div(
@@ -1508,10 +1481,10 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
     update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(s_frag, o_frag, m, d);

     block.sync();
-    page_iter_base += 16 * num_frags_z / page_size;
-    page_produce_kv<false, page_size, num_warps, num_frags_y, num_frags_z>(
-        k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, page_iter_base, kv_len,
-        last_indptr);
+    packed_page_iter_base += 16 * num_frags_z;
+    page_produce_kv<false, num_warps, num_frags_y, num_frags_z>(
+        k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, packed_page_iter_base,
+        kv_len, last_indptr);
     cp_async::commit_group();
     cp_async::wait_group<1>();
     block.sync();
@@ -1521,9 +1494,9 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
         o_frag, d);

     block.sync();
-    page_produce_kv<true, page_size, num_warps, num_frags_y, num_frags_z>(
-        v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, page_iter_base, kv_len,
-        last_indptr);
+    page_produce_kv<true, num_warps, num_frags_y, num_frags_z>(
+        v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, packed_page_iter_base,
+        kv_len, last_indptr);
     cp_async::commit_group();
   }
   cp_async::wait_group<0>();
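The call sites in these hunks follow the same packing scheme: each pipeline stage prefetches the next `16 * num_frags_z` KV rows, so the packed counter simply advances by that row count, whereas the old code advanced a page iterator by `16 * num_frags_z / page_size`, which implicitly assumed the page size divides the tile height. A small sketch with illustrative values (hypothetical `main`, not library code) showing why the packed advance stays exact for an arbitrary page size:

```cpp
// Illustrative values only (num_frags_z == 2, page_size == 5); hypothetical main,
// not library code.
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t rows_per_iter = 16 * 2;  // 16 * num_frags_z
  const uint32_t page_size = 5;           // does not divide 32
  uint32_t packed = 7 * page_size;        // packed_page_iter_base = indptr[request_idx] * page_size

  for (uint32_t iter = 0; iter < 4; ++iter) {
    // The packed counter always identifies the exact page and row of the first
    // KV entry staged in this iteration.
    std::printf("iter %u: page_iter %u, entry_idx %u\n", iter, packed / page_size,
                packed % page_size);
    packed += rows_per_iter;  // mirrors: packed_page_iter_base += 16 * num_frags_z;
  }
  // A per-page advance of rows_per_iter / page_size (== 6 pages == 30 rows here)
  // would fall 2 rows behind on every iteration.
  return 0;
}
```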
@@ -1776,7 +1749,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
   return cudaSuccess;
 }

-template <PageStorage page_storage, uint32_t num_frags_x, uint32_t PAGE_SIZE, uint32_t HEAD_DIM,
+template <PageStorage page_storage, uint32_t num_frags_x, uint32_t HEAD_DIM,
           LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
           bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut,
           typename IdType>
@@ -1831,8 +1804,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
       throw std::invalid_argument(err_msg.str());
     } else {
       auto kernel = BatchPrefillWithPagedKVCacheKernel<
-          LOGITS_POST_HOOK, PAGE_SIZE, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y,
-          num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>;
+          LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z,
+          num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>;
       uint32_t smem_size =
           (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn);
       FLASHINFER_CUDA_CALL(
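For orientation, a worked instance of the `smem_size` formula in this hunk, using illustrative values (the real `num_frags_*` and `num_warps` come from the dispatch logic above and depend on `HEAD_DIM` and the architecture):

```cpp
// Worked instance of the smem_size formula above; values are illustrative,
// not taken from the actual dispatch logic.
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t num_frags_x = 2, num_frags_z = 2, num_warps = 4;
  const uint32_t head_dim = 128;       // HEAD_DIM
  const uint32_t sizeof_dtype_in = 2;  // e.g. a 16-bit input type

  // (num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn)
  const uint32_t smem_size =
      (num_frags_x * num_warps + num_frags_z * 2) * 16 * head_dim * sizeof_dtype_in;
  std::printf("smem_size = %u bytes (%u KiB)\n", smem_size, smem_size / 1024);  // 49152, 48 KiB
  return 0;
}
```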