@@ -53,7 +53,7 @@ constexpr bool is_invalid_configuration(uint32_t num_frags_x, uint32_t num_frags
                                         uint32_t num_warps_z) {
   return ((num_frags_y < 4) || (num_frags_y == 4 && num_frags_z % 2 == 1) ||
           (num_frags_y > 4 && num_frags_y % (2 * num_warps_x) != 0) ||
-          (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 200));
+          (num_frags_x * (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= 256));
 }
 
 /*!
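
A rough reading of the relaxed bound (an aside, not part of the diff): the expression estimates the per-thread 32-bit register footprint of the accumulator fragments, roughly 8 * num_frags_x * num_frags_y registers for the output plus 2 * sizeof(DTypeQKAccum) * num_frags_x * num_frags_z for the QK scores, and a configuration is now rejected once that estimate reaches 256 instead of 200. A worked example with HEAD_DIM = 128 (num_frags_y = 8) and fp32 accumulation:

// Illustrative only: same arithmetic as the predicate above.
constexpr uint32_t frag_reg_estimate(uint32_t num_frags_x, uint32_t num_frags_y,
                                     uint32_t num_frags_z, uint32_t acc_bytes) {
  return num_frags_x * (8 * num_frags_y + 2 * acc_bytes * num_frags_z);
}
static_assert(frag_reg_estimate(2, 8, 4, sizeof(float)) == 192, "");  // < 256: accepted
static_assert(frag_reg_estimate(2, 8, 8, sizeof(float)) == 256, "");  // >= 256: still rejected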
@@ -207,30 +207,20 @@ template <bool produce_v, uint32_t num_warps_x, uint32_t num_warps_z, uint32_t n
 __device__ __forceinline__ void page_produce_kv(
     smem_t smem, uint32_t* smem_offset,
     paged_kv_t<page_storage, kv_layout, DType, IdType>& paged_kv, const uint32_t kv_idx_base,
-    const uint32_t packed_page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
+    const size_t* kv_offset, const uint32_t kv_len) {
   constexpr SharedMemFillMode fill_mode =
       produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
   constexpr uint32_t head_dim = num_frags_y * 16;
   constexpr uint32_t num_warps = num_warps_x * num_warps_z;
   constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b<DType>();
   const uint32_t warp_idx = get_warp_idx<num_warps_x, num_warps_z>(), lane_idx = threadIdx.x;
-  const uint32_t kv_head_idx = blockIdx.z;
   uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8;
   // NOTE(Zihao): num_frags_z * 4 / num_warps_x = num_warps_z * num_frags_z * 4 / num_warps
   static_assert(num_frags_z * 4 % num_warps_x == 0);
 #pragma unroll
   for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
-    uint32_t page_iter, entry_idx;
-    paged_kv.page_size.divmod(
-        packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps * i, page_iter,
-        entry_idx);
-    DType* gptr = produce_v
-                      ? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                                      (lane_idx % 8) * num_elems_per_128b<DType>(),
-                                                      last_indptr)
-                      : paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
-                                                      (lane_idx % 8) * num_elems_per_128b<DType>(),
-                                                      last_indptr);
+    DType* gptr = produce_v ? paged_kv.data + paged_kv.kv_offset_delta() + kv_offset[i]
+                            : paged_kv.data + kv_offset[i];
 #pragma unroll
     for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
       smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
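
An aside on the refactor (not part of the diff): page_produce_kv no longer resolves pages itself. The caller now precomputes, for each of the num_frags_z * 4 / num_warps_x rows a thread loads, a flat element offset into paged_kv.data, and the producer only adds kv_offset_delta() when fetching V. A minimal sketch of that resolution step under an assumed per-page K layout of [num_heads, page_size, head_dim]; the struct and helper names here are hypothetical, and the real mapping lives in paged_kv_t::get_k_elem_offset:

#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for illustration only, not the flashinfer paged_kv_t.
struct PagedKVSketch {
  uint32_t page_size, num_heads, head_dim;
  const int32_t* indices;  // packed page iterator -> physical page id
  // Flat element offset of (page, head, entry, feat) under the assumed layout.
  size_t k_elem_offset_sketch(uint32_t page, uint32_t head, uint32_t entry, uint32_t feat) const {
    return ((size_t(page) * num_heads + head) * page_size + entry) * head_dim + feat;
  }
};

// Resolve one packed KV position (page-major index over the request's pages).
inline size_t resolve_offset_sketch(const PagedKVSketch& kv, uint32_t packed_pos, uint32_t head,
                                    uint32_t feat) {
  uint32_t page_iter = packed_pos / kv.page_size;  // the kernel uses uint_fastdiv::divmod
  uint32_t entry_idx = packed_pos % kv.page_size;
  return kv.k_elem_offset_sketch(kv.indices[page_iter], head, entry_idx, feat);
}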
@@ -800,9 +790,21 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
                                                              const uint32_t lane_idx) {
   // only necessary when blockDim.z > 1
   if constexpr (num_warps_z > 1) {
-    float2* smem_md = (float2*)smem_workspace;
-    // o: [num_warps, warp_size, 8]
-    // md: [num_warps, num_frags_x, 2, warp_size, 2 (m/d)]
+    float2* smem_md = (float2*)(smem_workspace + num_frags_x * num_frags_y * num_warps_x *
+                                                     num_warps_z * warp_size * 8);
+    // o: [num_warps, num_frags_x, num_frags_y, warp_size(32), 8]
+    // md: [num_warps, num_frags_x, 2, warp_size(32), 2 (m/d)]
+#pragma unroll
+    for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
+#pragma unroll
+      for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
+        vec_t<float, 8>::memcpy(
+            smem_workspace +
+                (((warp_idx * num_frags_x + fx) * num_frags_y + fy) * warp_size + lane_idx) * 8,
+            o_frag[fx][fy]);
+      }
+    }
+
 #pragma unroll
     for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
 #pragma unroll
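
As an aside (not part of the diff), the new smem_md base simply skips past the enlarged o staging region. A minimal sketch of the workspace sizes implied by the layout comments above, assuming warp_size = 32 and float elements:

// Sketch only: element counts implied by the layout comments, not a flashinfer API.
constexpr uint32_t o_workspace_elems(uint32_t num_warps, uint32_t num_frags_x,
                                     uint32_t num_frags_y) {
  return num_warps * num_frags_x * num_frags_y * 32 * 8;  // o: [warps, fx, fy, 32, 8] floats
}
constexpr uint32_t md_workspace_elems(uint32_t num_warps, uint32_t num_frags_x) {
  return num_warps * num_frags_x * 2 * 32 * 2;            // md: [warps, fx, 2, 32, 2] floats
}
// smem_md starts right past the o region, matching the pointer arithmetic in the added lines.
// e.g. 4 warps, num_frags_x = 1, num_frags_y = 8 (HEAD_DIM = 128):
static_assert(o_workspace_elems(4, 1, 8) == 8192 && md_workspace_elems(4, 1) == 512, "");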
@@ -851,23 +853,22 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
       }
     }
 
-    __syncthreads();
-
     // the following code saves shared memory usage.
 #pragma unroll
     for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
 #pragma unroll
       for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
         vec_t<float, 8> o_new;
         o_new.fill(0.f);
-        vec_t<float, 8>::memcpy(smem_workspace + (warp_idx * warp_size + lane_idx) * 8,
-                                o_frag[fx][fy]);
-        __syncthreads();
 #pragma unroll
         for (uint32_t i = 0; i < num_warps_z; ++i) {
           vec_t<float, 8> oi;
           oi.load(smem_workspace +
-                  ((i * num_warps_x + get_warp_idx_x<num_warps_x, num_warps_z>()) * warp_size +
+                  ((((i * num_warps_x + get_warp_idx_x<num_warps_x, num_warps_z>()) * num_frags_x +
+                     fx) *
+                        num_frags_y +
+                    fy) *
+                       warp_size +
                    lane_idx) *
                       8);
 #pragma unroll
@@ -876,7 +877,6 @@ __device__ __forceinline__ void threadblock_sync_mdo_states(float (*o_frag)[num_
           }
         }
         o_new.store(o_frag[fx][fy]);
-        __syncthreads();
       }
     }
   }
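
For context (an aside, not part of the diff): threadblock_sync_mdo_states merges per-warp partial attention states across the num_warps_z warps; the summation above is the output half of that merge, while the m/d exchange through smem_md happens in the elided middle of the function. A minimal sketch of the standard merge of two partial states, assuming one 8-float unnormalized output fragment per state:

// Sketch only: m = running max, d = softmax denominator, o = unnormalized weighted sum.
struct StateSketch { float m, d, o[8]; };

__device__ StateSketch merge_sketch(StateSketch a, StateSketch b) {
  StateSketch out;
  out.m = fmaxf(a.m, b.m);
  float sa = __expf(a.m - out.m), sb = __expf(b.m - out.m);  // rescale to the new max
  out.d = a.d * sa + b.d * sb;
  for (int i = 0; i < 8; ++i) out.o[i] = a.o[i] * sa + b.o[i] * sb;
  return out;
}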
@@ -1592,6 +1592,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
   smem_t k_smem(smem + (num_warps_x * num_frags_x) * 16 * head_dim * sizeof(DTypeIn)),
       v_smem(smem + (num_warps_x * num_frags_x + num_warps_z * num_frags_z) * 16 * head_dim *
                  sizeof(DTypeIn));
+  size_t kv_offset[num_frags_z * 4 / num_warps_x];
 
   uint32_t k_smem_offset_r = smem_t::get_permuted_offset<channel_size_128b_in>(
       get_warp_idx_z<num_warps_x, num_warps_z>() * num_frags_z * 16 + 8 * (lane_idx / 16) +
@@ -1605,13 +1606,22 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
   const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];
 
   uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size + chunk_start;
+  for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
+    uint32_t page_iter, entry_idx;
+    paged_kv.page_size.divmod(
+        packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i,
+        page_iter, entry_idx);
+    kv_offset[i] =
+        page_iter < last_indptr
+            ? paged_kv.get_k_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx,
+                                         entry_idx, (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
+            : 0;
+  }
   page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
-      k_smem, &kv_smem_offset_w, paged_kv, chunk_start, packed_page_iter_base, chunk_end,
-      last_indptr);
+      k_smem, &kv_smem_offset_w, paged_kv, chunk_start, kv_offset, chunk_end);
   cp_async::commit_group();
   page_produce_kv<true, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
-      v_smem, &kv_smem_offset_w, paged_kv, chunk_start, packed_page_iter_base, chunk_end,
-      last_indptr);
+      v_smem, &kv_smem_offset_w, paged_kv, chunk_start, kv_offset, chunk_end);
   cp_async::commit_group();
 
   const uint32_t num_iterations =
@@ -1631,8 +1641,20 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
                                      : chunk_end - chunk_start) /
           (16 * num_warps_z * num_frags_z);
 
-#pragma unroll
+#pragma unroll 1
   for (uint32_t iter = 0; iter < num_iterations; ++iter) {
+    packed_page_iter_base += 16 * num_warps_z * num_frags_z;
+    for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
+      uint32_t page_iter, entry_idx;
+      paged_kv.page_size.divmod(
+          packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i,
+          page_iter, entry_idx);
+      kv_offset[i] = page_iter < last_indptr
+                         ? paged_kv.get_k_elem_offset(
+                               __ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
+                               (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
+                         : 0;
+    }
     cp_async::wait_group<1>();
     block.sync();
 
@@ -1677,11 +1699,9 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
     update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(s_frag, o_frag, m, d);
 
     block.sync();
-    packed_page_iter_base += 16 * num_warps_z * num_frags_z;
     page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
         k_smem, &kv_smem_offset_w, paged_kv,
-        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, packed_page_iter_base, chunk_end,
-        last_indptr);
+        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, kv_offset, chunk_end);
     cp_async::commit_group();
     cp_async::wait_group<1>();
     block.sync();
@@ -1693,8 +1713,7 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
     block.sync();
     page_produce_kv<true, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
         v_smem, &kv_smem_offset_w, paged_kv,
-        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, packed_page_iter_base, chunk_end,
-        last_indptr);
+        chunk_start + (iter + 1) * 16 * num_warps_z * num_frags_z, kv_offset, chunk_end);
     cp_async::commit_group();
   }
   cp_async::wait_group<0>();
@@ -1764,10 +1783,15 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
   const uint_fastdiv group_size_fastdiv(group_size);
   constexpr uint32_t num_frags_y = HEAD_DIM / 16;
   WarpLayout warp_layout;
-  if (qo_len * group_size > 64 && HEAD_DIM < 256) {
+  int64_t unpacked_qo_len = qo_len * group_size;
+  if (unpacked_qo_len > 64 && HEAD_DIM < 256) {
     warp_layout = WarpLayout::k4x1x2;
   } else {
-    warp_layout = WarpLayout::k4x1x1;
+    if (unpacked_qo_len > 16) {
+      warp_layout = WarpLayout::k4x1x1;
+    } else {
+      warp_layout = WarpLayout::k1x4x1;
+    }
   }
 
   DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, {
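
As a recap (an aside, not part of the diff), the new warp-layout selection keys off the unpacked query length qo_len * group_size. Restated as a standalone sketch that reuses the WarpLayout enumerators from the branch above; what each enumerator maps to is decided by DISPATCH_WARP_LAYOUT elsewhere in the file:

// Sketch of the selection logic above; select_warp_layout_sketch is not a flashinfer API.
inline WarpLayout select_warp_layout_sketch(int64_t unpacked_qo_len, uint32_t head_dim) {
  if (unpacked_qo_len > 64 && head_dim < 256) return WarpLayout::k4x1x2;
  if (unpacked_qo_len > 16) return WarpLayout::k4x1x1;
  return WarpLayout::k1x4x1;  // short-query case introduced by this change
}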