1818#include < cooperative_groups.h>
1919#include < cuda_bf16.h>
2020#include < cuda_fp16.h>
21+
22+ #include < cstddef>
2123#ifdef FLASHINFER_ENABLE_FP8
2224#include < cuda_fp8.h>
2325#endif
@@ -195,7 +197,6 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f
195197 * \param k [seq_len, num_kv_heads, head_dim] The key matrix in kv-cache
196198 * \param v [seq_len, num_kv_heads, head_dim] The value matrix in kv-cache
197199 * \param o [num_qo_heads, head_dim] The output matrix
198- * \param tmp Used-allocated temporary buffer
199200 * \param info The tensor info of k/v matrices
200201 * \param sm_scale A float indicates the scale applied to pre-softmax logits
201202 * \param head_dim A integer indicates the head dimension
@@ -212,7 +213,7 @@ template <LogitsPostHook logits_post_hook, QKVLayout kv_layout, bool partition_k
212213 typename DTypeKV, typename DTypeOut>
213214__global__ void SingleDecodeWithKVCacheKernel (DTypeQ* __restrict__ q, DTypeKV* __restrict__ k,
214215 DTypeKV* __restrict__ v, DTypeOut* __restrict__ o,
215- DTypeOut * __restrict__ tmp ,
216+ float * __restrict__ lse ,
216217 tensor_info_t <kv_layout, bdx * vec_size> info,
217218 float sm_scale, float rope_rcp_scale,
218219 float rope_rcp_theta, uint32_t kv_chunk_size) {
@@ -224,7 +225,6 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
224225 uint32_t kv_head_idx = blockIdx .y ;
225226 uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx .y ;
226227 uint32_t kv_chunk_idx = blockIdx .x ;
227- uint32_t num_kv_chunks = gridDim .x ;
228228 uint32_t num_qo_heads = info.num_qo_heads ;
229229 const float alibi_slope = get_alibi_slope (qo_head_idx, num_qo_heads) * math::log2e;
230230 uint32_t seq_len = info.kv_len ;
@@ -350,14 +350,9 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
350350 sync_state<vec_size, bdx, bdy, bdz>(st_local, reinterpret_cast <float *>(smem), smem_md);
351351 st_local.normalize ();
352352
353- if constexpr (partition_kv) {
354- // update tmp buffer
355- st_local.o .cast_store (tmp + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim +
356- tx * vec_size);
357- float * tmp_lse = (float *)(tmp + num_kv_chunks * num_qo_heads * head_dim);
358- tmp_lse[kv_chunk_idx * num_qo_heads + qo_head_idx] = st_local.get_lse ();
359- } else {
360- st_local.o .cast_store (o + info.get_qo_elem_offset (0 , qo_head_idx, tx * vec_size));
353+ st_local.o .cast_store (o + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
354+ if (lse != nullptr ) {
355+ lse[kv_chunk_idx * num_qo_heads + qo_head_idx] = st_local.get_lse ();
361356 }
362357}
363358
@@ -528,9 +523,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
528523 DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
529524 paged_kv_t <page_storage, kv_layout, DTypeKV, IdType> paged_kv,
530525 kv_partition_info_t <IdType> kv_partition_info, DTypeOut* __restrict__ o,
531- DTypeOut* __restrict__ tmp_v, float * __restrict__ tmp_s, float * __restrict__ lse,
532- bool * __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale,
533- float rope_rcp_theta) {
526+ float * __restrict__ lse, bool * __restrict__ block_valid_mask, float sm_scale,
527+ float rope_rcp_scale, float rope_rcp_theta) {
534528 auto block = cg::this_thread_block ();
535529 sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1 .f / 30 .f );
536530
@@ -710,15 +704,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
710704 sync_state<vec_size, bdx, bdy, bdz>(st, reinterpret_cast <float *>(smem), smem_md);
711705 st.normalize ();
712706
713- if constexpr (partition_kv) {
714- st.o .cast_store (tmp_v + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
715- tmp_s[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse ();
716- } else {
717- st.o .cast_store (o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
718- // write lse
719- if (lse != nullptr ) {
720- lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse ();
721- }
707+ st.o .cast_store (o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
708+ // write lse
709+ if (lse != nullptr ) {
710+ lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse ();
722711 }
723712}
724713
@@ -800,11 +789,12 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
800789
801790 dim3 nblks = dim3 (1 , num_kv_heads);
802791 dim3 nthrs = dim3 (bdx, bdy, bdz);
792+ float * lse = nullptr ;
803793 void * args[] = {(void *)&q,
804794 (void *)&k,
805795 (void *)&v,
806796 (void *)&o,
807- (void *)&tmp ,
797+ (void *)&lse ,
808798 (void *)&info,
809799 (void *)&sm_scale,
810800 (void *)&rope_rcp_scale,
@@ -838,19 +828,20 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
838828 throw std::runtime_error (err_msg.str ());
839829 }
840830 dim3 nthrs = dim3 (bdx, bdy, bdz);
831+ float * tmp_lse = (float *)(tmp + num_chunks * num_qo_heads * HEAD_DIM);
841832 void * args[] = {(void *)&q,
842833 (void *)&k,
843834 (void *)&v,
844- (void *)&o,
845835 (void *)&tmp,
836+ (void *)&tmp_lse,
846837 (void *)&info,
847838 (void *)&sm_scale,
848839 (void *)&rope_rcp_scale,
849840 (void *)&rope_rcp_theta,
850841 (void *)&kv_chunk_size};
851842 FLASHINFER_CUDA_CALL (cudaLaunchKernel ((void *)kernel, nblks, nthrs, args, smem_size, stream));
852- FLASHINFER_CUDA_CALL (MergeStates (tmp, ( float *)(tmp + num_chunks * num_qo_heads * HEAD_DIM), o,
853- nullptr , num_chunks, 1 , num_qo_heads, HEAD_DIM, stream));
843+ FLASHINFER_CUDA_CALL (
844+ MergeStates (tmp, tmp_lse, o, nullptr , num_chunks, 1 , num_qo_heads, HEAD_DIM, stream));
854845 }
855846 });
856847 return cudaSuccess;
@@ -897,8 +888,6 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
897888 (void *)&paged_kv,
898889 (void *)&kv_partition_info,
899890 (void *)&o,
900- (void *)&tmp_v,
901- (void *)&tmp_s,
902891 (void *)&lse,
903892 (void *)&block_valid_mask,
904893 (void *)&sm_scale,
@@ -918,10 +907,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
918907 (void *)&q_offset,
919908 (void *)&paged_kv,
920909 (void *)&kv_partition_info,
921- (void *)&o,
922910 (void *)&tmp_v,
923911 (void *)&tmp_s,
924- (void *)&lse,
925912 (void *)&block_valid_mask,
926913 (void *)&sm_scale,
927914 (void *)&rope_rcp_scale,
0 commit comments