Commit cf77d96

refactor: simplify kernel interface (#312)
We don't need to distinguish between `tmp_v`/`o` and `tmp_s`/`lse` in the kernel arguments.
Parent: 3d43dc9

13 files changed: +52 -72 lines
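
The gist of the refactor, as a minimal sketch (illustrative names only, not the exact FlashInfer API): the kernels keep a single pair of output arguments, `o` and `lse`. For a KV-partitioned launch the host simply points them into the user-provided workspace, laid out as per-chunk outputs followed by per-chunk log-sum-exp values, and merges afterwards; otherwise `o` is the final output and `lse` may be null.

#include <cstdint>

// Hypothetical helper showing the workspace carving that the updated dispatch code performs inline.
template <typename DTypeOut>
struct PartitionOutputs {
  DTypeOut* o;  // per-(chunk, head) partial outputs
  float* lse;   // per-(chunk, head) log-sum-exp values
};

template <typename DTypeOut>
PartitionOutputs<DTypeOut> CarveWorkspace(DTypeOut* tmp, uint32_t num_chunks,
                                          uint32_t num_qo_heads, uint32_t head_dim) {
  // Outputs first, then the float LSE region, matching the pointer arithmetic in this commit.
  float* tmp_lse = reinterpret_cast<float*>(tmp + num_chunks * num_qo_heads * head_dim);
  return {tmp, tmp_lse};
}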

include/flashinfer/attention/decode.cuh

Lines changed: 18 additions & 31 deletions
@@ -18,6 +18,8 @@
 #include <cooperative_groups.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
+
+#include <cstddef>
 #ifdef FLASHINFER_ENABLE_FP8
 #include <cuda_fp8.h>
 #endif
@@ -195,7 +197,6 @@ __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, f
  * \param k [seq_len, num_kv_heads, head_dim] The key matrix in kv-cache
  * \param v [seq_len, num_kv_heads, head_dim] The value matrix in kv-cache
  * \param o [num_qo_heads, head_dim] The output matrix
- * \param tmp Used-allocated temporary buffer
  * \param info The tensor info of k/v matrices
  * \param sm_scale A float indicates the scale applied to pre-softmax logits
  * \param head_dim A integer indicates the head dimension
@@ -212,7 +213,7 @@ template <LogitsPostHook logits_post_hook, QKVLayout kv_layout, bool partition_k
           typename DTypeKV, typename DTypeOut>
 __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k,
                                               DTypeKV* __restrict__ v, DTypeOut* __restrict__ o,
-                                              DTypeOut* __restrict__ tmp,
+                                              float* __restrict__ lse,
                                               tensor_info_t<kv_layout, bdx * vec_size> info,
                                               float sm_scale, float rope_rcp_scale,
                                               float rope_rcp_theta, uint32_t kv_chunk_size) {
@@ -224,7 +225,6 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
   uint32_t kv_head_idx = blockIdx.y;
   uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y;
   uint32_t kv_chunk_idx = blockIdx.x;
-  uint32_t num_kv_chunks = gridDim.x;
   uint32_t num_qo_heads = info.num_qo_heads;
   const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e;
   uint32_t seq_len = info.kv_len;
@@ -350,14 +350,9 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
   sync_state<vec_size, bdx, bdy, bdz>(st_local, reinterpret_cast<float*>(smem), smem_md);
   st_local.normalize();

-  if constexpr (partition_kv) {
-    // update tmp buffer
-    st_local.o.cast_store(tmp + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim +
-                          tx * vec_size);
-    float* tmp_lse = (float*)(tmp + num_kv_chunks * num_qo_heads * head_dim);
-    tmp_lse[kv_chunk_idx * num_qo_heads + qo_head_idx] = st_local.get_lse();
-  } else {
-    st_local.o.cast_store(o + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
+  st_local.o.cast_store(o + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
+  if (lse != nullptr) {
+    lse[kv_chunk_idx * num_qo_heads + qo_head_idx] = st_local.get_lse();
   }
 }

@@ -528,9 +523,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
     DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
     paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
     kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
-    DTypeOut* __restrict__ tmp_v, float* __restrict__ tmp_s, float* __restrict__ lse,
-    bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale,
-    float rope_rcp_theta) {
+    float* __restrict__ lse, bool* __restrict__ block_valid_mask, float sm_scale,
+    float rope_rcp_scale, float rope_rcp_theta) {
   auto block = cg::this_thread_block();
   sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);

@@ -710,15 +704,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
   sync_state<vec_size, bdx, bdy, bdz>(st, reinterpret_cast<float*>(smem), smem_md);
   st.normalize();

-  if constexpr (partition_kv) {
-    st.o.cast_store(tmp_v + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
-    tmp_s[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse();
-  } else {
-    st.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
-    // write lse
-    if (lse != nullptr) {
-      lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse();
-    }
+  st.o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
+  // write lse
+  if (lse != nullptr) {
+    lse[batch_idx * num_qo_heads + qo_head_idx] = st.get_lse();
   }
 }

@@ -800,11 +789,12 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,

   dim3 nblks = dim3(1, num_kv_heads);
   dim3 nthrs = dim3(bdx, bdy, bdz);
+  float* lse = nullptr;
   void* args[] = {(void*)&q,
                   (void*)&k,
                   (void*)&v,
                   (void*)&o,
-                  (void*)&tmp,
+                  (void*)&lse,
                   (void*)&info,
                   (void*)&sm_scale,
                   (void*)&rope_rcp_scale,
@@ -838,19 +828,20 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
     throw std::runtime_error(err_msg.str());
   }
   dim3 nthrs = dim3(bdx, bdy, bdz);
+  float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM);
   void* args[] = {(void*)&q,
                   (void*)&k,
                   (void*)&v,
-                  (void*)&o,
                   (void*)&tmp,
+                  (void*)&tmp_lse,
                   (void*)&info,
                   (void*)&sm_scale,
                   (void*)&rope_rcp_scale,
                   (void*)&rope_rcp_theta,
                   (void*)&kv_chunk_size};
   FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
-  FLASHINFER_CUDA_CALL(MergeStates(tmp, (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM), o,
-                                   nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream));
+  FLASHINFER_CUDA_CALL(
+      MergeStates(tmp, tmp_lse, o, nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream));
 }
 });
 return cudaSuccess;
@@ -897,8 +888,6 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
                   (void*)&paged_kv,
                   (void*)&kv_partition_info,
                   (void*)&o,
-                  (void*)&tmp_v,
-                  (void*)&tmp_s,
                   (void*)&lse,
                   (void*)&block_valid_mask,
                   (void*)&sm_scale,
@@ -918,10 +907,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
                   (void*)&q_offset,
                   (void*)&paged_kv,
                   (void*)&kv_partition_info,
-                  (void*)&o,
                   (void*)&tmp_v,
                   (void*)&tmp_s,
-                  (void*)&lse,
                   (void*)&block_valid_mask,
                   (void*)&sm_scale,
                   (void*)&rope_rcp_scale,
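
For the partitioned single-decode path above, the dispatch code carves `tmp` into `num_chunks * num_qo_heads * HEAD_DIM` output elements followed by `num_chunks * num_qo_heads` floats before calling `MergeStates`. A rough size helper implied by that layout (illustrative only; the dispatcher picks `num_chunks` internally, so treat it as an input here):

#include <cstddef>
#include <cstdint>

// Bytes needed for the partitioned single-decode workspace: per-chunk outputs,
// then the per-chunk LSE region that tmp_lse points at.
template <typename DTypeOut>
size_t SingleDecodeWorkspaceBytes(uint32_t num_chunks, uint32_t num_qo_heads, uint32_t head_dim) {
  size_t out_bytes = size_t(num_chunks) * num_qo_heads * head_dim * sizeof(DTypeOut);
  size_t lse_bytes = size_t(num_chunks) * num_qo_heads * sizeof(float);
  return out_bytes + lse_bytes;
}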

include/flashinfer/attention/handler.cuh

Lines changed: 2 additions & 3 deletions
@@ -38,9 +38,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
     DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
     paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
     kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
-    DTypeOut* __restrict__ tmp_v, float* __restrict__ tmp_s, float* __restrict__ lse,
-    bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale,
-    float rope_rcp_theta);
+    float* __restrict__ lse, bool* __restrict__ block_valid_mask, float sm_scale,
+    float rope_rcp_scale, float rope_rcp_theta);

 /*!
  * \brief Compute the maximum number of pages per batch and the new batch size

include/flashinfer/attention/prefill.cuh

Lines changed: 12 additions & 20 deletions
@@ -889,14 +889,11 @@ template <LogitsPostHook logits_post_hook, bool partition_kv, MaskMode mask_mode
           QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, uint32_t num_frags_x,
           uint32_t num_frags_y, uint32_t num_frags_z, uint32_t num_warps, typename DTypeIn,
           typename DTypeQKAccum, typename DTypeOut>
-__global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k,
-                                               DTypeIn* __restrict__ v,
-                                               uint8_t* __restrict__ custom_mask,
-                                               DTypeOut* __restrict__ o, void* __restrict__ tmp,
-                                               float* __restrict__ lse, const uint32_t qo_len,
-                                               const uint32_t kv_len, const uint_fastdiv group_size,
-                                               float sm_scale, const float log2_rope_rcp_scale,
-                                               const float log2_rope_rcp_theta) {
+__global__ void SinglePrefillWithKVCacheKernel(
+    DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v,
+    uint8_t* __restrict__ custom_mask, DTypeOut* __restrict__ o, float* __restrict__ lse,
+    const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, float sm_scale,
+    const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) {
   static_assert(sizeof(DTypeIn) == 2);
   static_assert(sizeof(DTypeOut) == 2);
   sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);
@@ -940,7 +937,7 @@ __global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
   DTypeIn* q_ptr_base = q + qkv_info.get_qo_elem_offset(0, kv_head_idx * group_size,
                                                         (tx % 8) * num_elems_per_128b<DTypeIn>());
   DTypeOut* o_ptr_base =
-      partition_kv ? ((DTypeOut*)tmp) + chunk_idx * num_qo_heads * head_dim +
+      partition_kv ? o + chunk_idx * num_qo_heads * head_dim +
                          qkv_info.get_qo_elem_offset(0, kv_head_idx * group_size,
                                                      (tx % 8) * num_elems_per_128b<DTypeOut>())
                    : o + qkv_info.get_qo_elem_offset(0, kv_head_idx * group_size,
@@ -1087,9 +1084,7 @@ __global__ void SinglePrefillWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
       const uint32_t qo_idx = q;
       if (qo_idx < qo_len) {
         if constexpr (partition_kv) {
-          float* tmp_lse =
-              (float*)(((DTypeOut*)tmp) + qo_len * num_chunks * num_qo_heads * head_dim);
-          tmp_lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] =
+          lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] =
               math::ptx_log2(d[fx][j]) + float(m[fx][j]);
         } else {
           lse[qo_idx * num_qo_heads + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]);
@@ -1534,7 +1529,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
           PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
           typename DTypeIn, typename DTypeOut>
 cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v,
-                                               uint8_t* custom_mask, DTypeOut* o, float* tmp,
+                                               uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp,
                                                float* lse, uint32_t num_qo_heads,
                                                uint32_t num_kv_heads, uint32_t qo_len,
                                                uint32_t kv_len, float sm_scale, float rope_scale,
@@ -1625,7 +1620,6 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
                   (void*)&v,
                   (void*)&custom_mask,
                   (void*)&o,
-                  (void*)&tmp,
                   (void*)&lse,
                   (void*)&qo_len,
                   (void*)&kv_len,
@@ -1641,13 +1635,13 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
         cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
   } else {
     // Use cooperative groups to increase occupancy
+    float* tmp_lse = (float*)(tmp + num_chunks * qo_len * num_qo_heads * HEAD_DIM);
     void* args[] = {(void*)&q,
                     (void*)&k,
                     (void*)&v,
                     (void*)&custom_mask,
-                    (void*)&o,
                     (void*)&tmp,
-                    (void*)&lse,
+                    (void*)&tmp_lse,
                     (void*)&qo_len,
                     (void*)&kv_len,
                     (void*)&group_size_fastdiv,
@@ -1658,10 +1652,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
     dim3 nthrs(32, num_warps);
     FLASHINFER_CUDA_CALL(
         cudaLaunchKernel((void*)partition_kv_kernel, nblks, nthrs, args, smem_size, stream));
-    FLASHINFER_CUDA_CALL(MergeStates(
-        (DTypeOut*)tmp,
-        (float*)(((DTypeOut*)tmp) + num_chunks * qo_len * num_qo_heads * HEAD_DIM), o, lse,
-        num_chunks, qo_len, num_qo_heads, HEAD_DIM, stream));
+    FLASHINFER_CUDA_CALL(MergeStates(tmp, tmp_lse, o, lse, num_chunks, qo_len, num_qo_heads,
+                                     HEAD_DIM, stream));
   }
 }
 })
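
With this change the partitioned prefill kernel only emits per-chunk `(o, lse)` pairs and leaves the reduction to `MergeStates`. Conceptually that reduction is the standard log-sum-exp merge of partial softmax states; a two-state, CPU-side illustration (natural log for readability, whereas the `lse` stored above is base 2 via `ptx_log2`; this is not FlashInfer's kernel code):

#include <cmath>
#include <cstddef>
#include <vector>

// Merge two partial attention states (o1, lse1) and (o2, lse2) over the same head_dim vector.
void MergeTwoStates(const std::vector<float>& o1, float lse1, const std::vector<float>& o2,
                    float lse2, std::vector<float>& o_merged, float& lse_merged) {
  // lse_merged = log(exp(lse1) + exp(lse2)), computed stably
  lse_merged = std::max(lse1, lse2) + std::log1p(std::exp(-std::fabs(lse1 - lse2)));
  float w1 = std::exp(lse1 - lse_merged);  // softmax weight of chunk 1
  float w2 = std::exp(lse2 - lse_merged);  // softmax weight of chunk 2
  o_merged.resize(o1.size());
  for (size_t i = 0; i < o1.size(); ++i) o_merged[i] = w1 * o1[i] + w2 * o2[i];
}

Because the weights depend only on the per-chunk LSE values, the chunks can be merged in any order (or pairwise in a tree), which is why storing just `(o, lse)` per chunk is sufficient.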

include/flashinfer/prefill_attention_decl.cuh

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
           PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
           typename DTypeIn, typename DTypeOut>
 cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v,
-                                               uint8_t* custom_mask, DTypeOut* o, float* tmp,
+                                               uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp,
                                                float* lse, uint32_t num_qo_heads,
                                                uint32_t num_kv_heads, uint32_t qo_len,
                                                uint32_t kv_len, float sm_scale, float rope_scale,

python/csrc/single_prefill.cu

Lines changed: 2 additions & 2 deletions
@@ -78,7 +78,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache(
         static_cast<c_type*>(k.data_ptr()),
         static_cast<c_type*>(v.data_ptr()),
         /*custom_mask=*/nullptr, static_cast<c_type*>(o.data_ptr()),
-        static_cast<float*>(tmp.data_ptr()),
+        static_cast<c_type*>(tmp.data_ptr()),
         /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
         num_qo_heads, num_kv_heads, qo_len, kv_len, sm_scale, rope_scale,
         rope_theta, torch_current_stream);
@@ -159,7 +159,7 @@ std::vector<torch::Tensor> single_prefill_with_kv_cache_custom_mask(
         static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
         static_cast<c_type*>(v.data_ptr()),
         static_cast<uint8_t*>(packed_custom_mask.data_ptr()),
-        static_cast<c_type*>(o.data_ptr()), static_cast<float*>(tmp.data_ptr()),
+        static_cast<c_type*>(o.data_ptr()), static_cast<c_type*>(tmp.data_ptr()),
         /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
         num_qo_heads, num_kv_heads, qo_len, kv_len, sm_scale, rope_scale,
         rope_theta, torch_current_stream);
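
Because the binding now casts `tmp` to `c_type`, the workspace tensor handed in from Python must carry the same dtype as q/k/v/o rather than float32. A hypothetical caller-side allocation consistent with that cast (the real allocation lives in the FlashInfer Python package; the 16M-element size merely mirrors the updated benchmarks):

#include <torch/torch.h>

torch::Tensor MakePrefillWorkspace(const torch::Tensor& q) {
  // Same dtype and device as q (i.e. c_type), not torch::kFloat32.
  return torch::empty({16 * 1024 * 1024}, q.options());
}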

python/generate_single_prefill_inst.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def get_cu_file_str(

 template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {logits_hook}, {kv_layout}, {pos_encoding_mode}, {allow_fp16_qk_reduction}, {mask_mode}, {dtype_in}, {dtype_out}>(
     {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, uint8_t* custom_mask, {dtype_out}* o,
-    float* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len,
+    {dtype_out}* tmp, float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len,
     float sm_scale, float rope_scale,
     float rope_theta, cudaStream_t stream);


src/bench_cascade.cu

Lines changed: 5 additions & 4 deletions
@@ -96,8 +96,9 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) {
   if (use_cascade) {
     thrust::device_vector<T> shared_k_d(shared_k_h), shared_v_d(shared_v_h),
         o_cascade_0_d(q_h.size()), o_cascade_1_d(q_h.size());
-    thrust::device_vector<float> tmp_0_d(8 * 1024 * 1024),
-        lse_cascade_0_d(batch_size * num_qo_heads), lse_cascade_1_d(batch_size * num_qo_heads);
+    thrust::device_vector<T> tmp_0_d(16 * 1024 * 1024);
+    thrust::device_vector<float> lse_cascade_0_d(batch_size * num_qo_heads),
+        lse_cascade_1_d(batch_size * num_qo_heads);
     thrust::device_vector<int32_t> kv_indptr_unique_d(kv_indptr_unique_h),
         kv_indices_unique_d(kv_indices_unique_h),
         kv_last_page_len_unique_d(kv_last_page_len_unique_h);
@@ -231,8 +232,8 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) {
   if (use_cascade) {
     thrust::device_vector<T> shared_k_d(shared_k_h), shared_v_d(shared_v_h),
         o_cascade_0_d(q_h.size()), o_cascade_1_d(q_h.size());
-    thrust::device_vector<float> tmp_0_d(8 * 1024 * 1024),
-        lse_cascade_0_d((batch_size * qo_append_length) * num_qo_heads),
+    thrust::device_vector<T> tmp_0_d(8 * 1024 * 1024);
+    thrust::device_vector<float> lse_cascade_0_d((batch_size * qo_append_length) * num_qo_heads),
         lse_cascade_1_d((batch_size * qo_append_length) * num_qo_heads);
     thrust::device_vector<int32_t> kv_indptr_unique_d(kv_indptr_unique_h),
         kv_indices_unique_d(kv_indices_unique_h),

src/bench_single_decode.cu

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ void bench_flashinfer_single_decode_with_prefill(nvbench::state& state) {
   thrust::device_vector<dtype_in> K(seq_len * num_kv_heads * head_dim);
   thrust::device_vector<dtype_in> V(seq_len * num_kv_heads * head_dim);
   thrust::device_vector<dtype_out> O(num_qo_heads * head_dim);
-  thrust::device_vector<float> tmp(8 * 1024 * 1024);
+  thrust::device_vector<dtype_out> tmp(16 * 1024 * 1024);

   // Provide throughput information:
   state.add_global_memory_reads<dtype_in>(

src/bench_single_prefill.cu

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ void bench_flashinfer_single_prefill(nvbench::state& state) {
   thrust::device_vector<dtype_in> V(kv_len * num_kv_heads * head_dim);
   thrust::device_vector<uint8_t> mask(ceil_div(qo_len * kv_len, 8));
   thrust::device_vector<dtype_out> O(qo_len * num_qo_heads * head_dim);
-  thrust::device_vector<float> tmp(8 * 1024 * 1024);
+  thrust::device_vector<dtype_out> tmp(16 * 1024 * 1024);

   // Provide throughput information:
   state.add_global_memory_reads<dtype_in>(

src/flashinfer_ops.cuh

Lines changed: 3 additions & 3 deletions
@@ -25,8 +25,8 @@ namespace flashinfer {

 template <typename DTypeIn, typename DTypeOut>
 cudaError_t SinglePrefillWithKVCacheCustomMask(
-    DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeOut* o, float* tmp, float* lse,
-    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len,
+    DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeOut* o, DTypeOut* tmp,
+    float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len,
     uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD,
     PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
     bool allow_fp16_qk_reduction = false, std::optional<float> maybe_sm_scale = std::nullopt,
@@ -72,7 +72,7 @@ cudaError_t SinglePrefillWithKVCacheCustomMask(
  * \return status Indicates whether CUDA calls are successful
  */
 template <typename DTypeIn, typename DTypeOut>
-cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp,
+cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, DTypeOut* tmp,
                                      float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads,
                                      uint32_t qo_len, uint32_t kv_len, uint32_t head_dim,
                                      bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD,
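
A usage sketch against the updated `SinglePrefillWithKVCache` signature, mirroring the benchmark changes above (half in/out; the include path and the remaining default arguments are assumptions):

#include <cstdint>
#include <cuda_fp16.h>
#include <thrust/device_vector.h>

#include "flashinfer_ops.cuh"

void RunSinglePrefill(uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len,
                      uint32_t kv_len, uint32_t head_dim) {
  thrust::device_vector<half> Q(qo_len * num_qo_heads * head_dim);
  thrust::device_vector<half> K(kv_len * num_kv_heads * head_dim);
  thrust::device_vector<half> V(kv_len * num_kv_heads * head_dim);
  thrust::device_vector<half> O(qo_len * num_qo_heads * head_dim);
  thrust::device_vector<half> tmp(16 * 1024 * 1024);  // DTypeOut workspace, as in the benchmarks
  cudaError_t status = flashinfer::SinglePrefillWithKVCache(
      thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(K.data()),
      thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(O.data()),
      thrust::raw_pointer_cast(tmp.data()), /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len,
      kv_len, head_dim, /*causal=*/true);
  (void)status;  // error handling omitted in this sketch
}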
