57 changes: 26 additions & 31 deletions CMakeLists.txt
@@ -37,7 +37,6 @@ flashinfer_option(FLASHINFER_TVM_SOURCE_DIR "The path to tvm for building tvm bi

# The following configurations can impact the binary
# size of the generated library
flashinfer_option(FLASHINFER_GEN_PAGE_SIZES "Prefill page sizes to enable" 1 16 32)
flashinfer_option(FLASHINFER_GEN_HEAD_DIMS "Head dims to enable" 64 128 256)
flashinfer_option(FLASHINFER_GEN_KV_LAYOUTS "KV layouts to enable" 0 1)
flashinfer_option(FLASHINFER_GEN_LOGITS_POST_HOOKS "Logits post hooks" 0 1)
@@ -80,7 +79,6 @@ if(FLASHINFER_ENABLE_BF16)
endif(FLASHINFER_ENABLE_BF16)

# generate kernel inst
set (PAGE_SIZES ${FLASHINFER_GEN_PAGE_SIZES})
set (HEAD_DIMS ${FLASHINFER_GEN_HEAD_DIMS})
set (LOGITS_POST_HOOKS ${FLASHINFER_GEN_LOGITS_POST_HOOKS})
set (KV_LAYOUTS ${FLASHINFER_GEN_KV_LAYOUTS})
@@ -103,7 +101,6 @@ if(FLASHINFER_ENABLE_BF16)
endif(FLASHINFER_ENABLE_BF16)

# log options
message(STATUS "FLASHINFER_PAGE_SIZES=${PAGE_SIZES}")
message(STATUS "FLASHINFER_HEAD_DIMS=${HEAD_DIMS}")
message(STATUS "FLASHINFER_KV_LAYOUTS=${KV_LAYOUTS}")
message(STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES}")
@@ -115,7 +112,7 @@ file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated)
set(dispatch_inc_file ${PROJECT_SOURCE_DIR}/src/dispatch.inc)
add_custom_command(
OUTPUT ${dispatch_inc_file}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --logits_post_hooks ${LOGITS_POST_HOOKS} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --logits_post_hooks ${LOGITS_POST_HOOKS} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES}
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py
COMMENT "Generating additional source file ${generated_dispatch_inc}"
VERBATIM
@@ -249,33 +246,31 @@ foreach(head_dim IN LISTS HEAD_DIMS)
endforeach(head_dim)

# batch paged prefill kernel inst generation
foreach(page_size IN LISTS PAGE_SIZES)
foreach(head_dim IN LISTS HEAD_DIMS)
foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS)
foreach(kv_layout IN LISTS KV_LAYOUTS)
foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
foreach(mask_mode IN LISTS MASK_MODES)
foreach(dtype IN LISTS PREFILL_DTYPES)
foreach(idtype IN LISTS IDTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_page_${page_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
endforeach(idtype)
endforeach(dtype)
endforeach(mask_mode)
endforeach(allow_fp16_qk_reduction)
endforeach(pos_encoding_mode)
endforeach(kv_layout)
endforeach(logits_post_hook)
endforeach(head_dim)
endforeach(page_size)
foreach(head_dim IN LISTS HEAD_DIMS)
foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS)
foreach(kv_layout IN LISTS KV_LAYOUTS)
foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
foreach(mask_mode IN LISTS MASK_MODES)
foreach(dtype IN LISTS PREFILL_DTYPES)
foreach(idtype IN LISTS IDTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
endforeach(idtype)
endforeach(dtype)
endforeach(mask_mode)
endforeach(allow_fp16_qk_reduction)
endforeach(pos_encoding_mode)
endforeach(kv_layout)
endforeach(logits_post_hook)
endforeach(head_dim)

# batch ragged prefill kernel inst generation
foreach(head_dim IN LISTS HEAD_DIMS)
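With the page-size loop gone, each batch paged prefill kernel is generated once instead of once per entry in FLASHINFER_GEN_PAGE_SIZES; with the default list 1 16 32 (and the other dispatch lists unchanged), that is a 3x reduction in generated batch_paged_prefill_*.cu files and in the template instantiations compiled from them. Page size becomes a runtime value instead, handled by the fastdiv-based kernel changes below.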
1 change: 0 additions & 1 deletion cmake/config.cmake
@@ -23,7 +23,6 @@ set(FLASHINFER_DISTRIBUTED ON)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
set(FLASHINFER_GEN_PAGE_SIZES 1 16 32)
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
20 changes: 9 additions & 11 deletions include/flashinfer/attention/decode.cuh
@@ -601,9 +601,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
static_assert(num_stages_smem <= bdx);
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = paged_kv.protective_get_k_ptr(
cur_page_indptr_begin + (((j * bdz + tz) * bdy + ty) * bdx + tx) / paged_kv.page_size,
kv_head_idx, (((j * bdz + tz) * bdy + ty) * bdx + tx) % paged_kv.page_size, 0, last_indptr);
uint32_t q, r;
paged_kv.page_size.divmod(((j * bdz + tz) * bdy + ty) * bdx + tx, q, r);
k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
paged_kv.protective_get_k_ptr(cur_page_indptr_begin + q, kv_head_idx, r, 0, last_indptr);
}
block.sync();

@@ -643,15 +644,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
if ((iter + num_stages_smem) % bdx == 0) {
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
uint32_t q, r;
paged_kv.page_size.divmod(((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz +
((j * bdz + tz) * bdy + ty) * bdx + tx),
q, r);
k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = paged_kv.protective_get_k_ptr(
cur_page_indptr_begin + ((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz +
((j * bdz + tz) * bdy + ty) * bdx + tx) /
paged_kv.page_size,
kv_head_idx,
((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz +
((j * bdz + tz) * bdy + ty) * bdx + tx) %
paged_kv.page_size,
0, last_indptr);
cur_page_indptr_begin + q, kv_head_idx, r, 0, last_indptr);
}
}
// compute qk
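For readers unfamiliar with the paged KV layout, here is a minimal host-side sketch of the indexing that the divmod calls above compute, using plain / and % for clarity; in the kernel, paged_kv.page_size is a uint32_fastdiv, so the quotient/remainder pair is produced without a hardware integer divide. The function name and standalone form are illustrative, not part of the codebase.

#include <cstdint>
#include <utility>

// A flat KV position within a request splits into (page offset, in-page entry).
// The kernel adds the page offset to cur_page_indptr_begin to pick the page,
// and uses the entry index to address the slot inside that page.
std::pair<uint32_t, uint32_t> locate_kv_entry(uint32_t flat_kv_pos, uint32_t page_size) {
  uint32_t page_offset = flat_kv_pos / page_size;  // -> q in the divmod above
  uint32_t entry_idx = flat_kv_pos % page_size;    // -> r in the divmod above
  return {page_offset, entry_idx};
}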
2 changes: 1 addition & 1 deletion include/flashinfer/attention/handler.cuh
@@ -133,7 +133,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
max_grid_size = num_blocks_per_sm * num_sm;
if (batch_size * num_kv_heads >= num_sm) {
if (batch_size * num_kv_heads >= max_grid_size) {
tmp_size = 0;
new_batch_size = batch_size;
} else {
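A sketch of what the changed comparison decides, with simplified, illustrative names; the real logic lives in BatchDecodeWithPagedKVCacheWorkEstimationDispatched and also sizes the temporary buffer.

#include <cstdint>

// The decode grid launches roughly one block per (request, kv_head).
// Partitioning the KV cache into chunks creates extra blocks, which only
// helps when that grid would otherwise leave the GPU under-occupied.
// The old check compared against num_sm, treating one block per SM as
// "enough"; the fix compares against the true concurrent capacity,
// num_blocks_per_sm * num_sm, as reported by the occupancy query above.
bool should_partition_kv(uint32_t batch_size, uint32_t num_kv_heads,
                         uint32_t num_blocks_per_sm, uint32_t num_sm) {
  const uint32_t max_grid_size = num_blocks_per_sm * num_sm;
  return batch_size * num_kv_heads < max_grid_size;
}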
109 changes: 41 additions & 68 deletions include/flashinfer/attention/prefill.cuh
@@ -23,9 +23,6 @@
#endif
#include <cuda_runtime.h>

#include <optional>
#include <tuple>

#include "../cp_async.cuh"
#include "../fastdiv.cuh"
#include "../layout.cuh"
@@ -175,65 +172,41 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* smem_offset, T
*smem_offset -= num_frags_z * 16 * channel_size_128b_in;
}

template <bool produce_v, uint32_t page_size, uint32_t num_warps, uint32_t num_frags_y,
uint32_t num_frags_z, PageStorage page_storage, QKVLayout kv_layout, typename DType,
typename IdType>
template <bool produce_v, uint32_t num_warps, uint32_t num_frags_y, uint32_t num_frags_z,
PageStorage page_storage, QKVLayout kv_layout, typename DType, typename IdType>
__device__ __forceinline__ void page_produce_kv(
smem_t smem, uint32_t* smem_offset,
paged_kv_t<page_storage, kv_layout, DType, IdType>& paged_kv, const uint32_t kv_idx_base,
const uint32_t page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
const uint32_t packed_page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
constexpr SharedMemFillMode fill_mode =
produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b<DType>();
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
const uint32_t kv_head_idx = blockIdx.z;
uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8;
if constexpr (page_size % 4 == 0) {
#pragma unroll
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4) / page_size;
const uint32_t entry_idx = (4 * num_warps * i + ty * 4) % page_size + tx / 8;
DType* gptr =
produce_v
? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * num_elems_per_128b<DType>(), last_indptr)
: paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * num_elems_per_128b<DType>(), last_indptr);
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
gptr += 8 * num_elems_per_128b<DType>();
}
kv_idx += num_warps * 4;
*smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
2 * num_frags_y;
}
*smem_offset -= num_frags_z * 16 * channel_size_128b_in;
} else {
#pragma unroll
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4 + tx / 8) / page_size;
const uint32_t entry_idx = (4 * num_warps * i + ty * 4 + tx / 8) % page_size;
DType* gptr =
produce_v
? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * num_elems_per_128b<DType>(), last_indptr)
: paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * num_elems_per_128b<DType>(), last_indptr);
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
gptr += 8 * num_elems_per_128b<DType>();
}
kv_idx += num_warps * 4;
*smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
2 * num_frags_y;
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
uint32_t page_iter, entry_idx;
paged_kv.page_size.divmod(packed_page_iter_base + ty * 4 + tx / 8 + 4 * num_warps * i,
page_iter, entry_idx);
DType* gptr =
produce_v
? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * num_elems_per_128b<DType>(), last_indptr)
: paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * num_elems_per_128b<DType>(), last_indptr);
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
gptr += 8 * num_elems_per_128b<DType>();
}
*smem_offset -= num_frags_z * 16 * channel_size_128b_in;
kv_idx += num_warps * 4;
*smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
2 * num_frags_y;
}
*smem_offset -= num_frags_z * 16 * channel_size_128b_in;
}

template <uint32_t num_frags_y>
@@ -1342,10 +1315,10 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel(
}
}

template <LogitsPostHook logits_post_hook, uint32_t page_size, MaskMode mask_mode,
PosEncodingMode pos_encoding_mode, uint32_t num_frags_x, uint32_t num_frags_y,
uint32_t num_frags_z, uint32_t num_warps, PageStorage page_storage, QKVLayout kv_layout,
typename DTypeIn, typename DTypeQKAccum, typename DTypeOut, typename IdType>
template <LogitsPostHook logits_post_hook, MaskMode mask_mode, PosEncodingMode pos_encoding_mode,
uint32_t num_frags_x, uint32_t num_frags_y, uint32_t num_frags_z, uint32_t num_warps,
PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeQKAccum,
typename DTypeOut, typename IdType>
__global__ void BatchPrefillWithPagedKVCacheKernel(
IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices,
DTypeIn* __restrict__ q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
@@ -1448,12 +1421,12 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
smem_t::get_permuted_offset<channel_size_128b_in>(ty * 4 + tx / 8, tx % 8);
const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];

uint32_t page_iter_base = paged_kv.indptr[request_idx];
page_produce_kv<false, page_size, num_warps, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, 0, page_iter_base, kv_len, last_indptr);
uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size;
page_produce_kv<false, num_warps, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, 0, packed_page_iter_base, kv_len, last_indptr);
cp_async::commit_group();
page_produce_kv<true, page_size, num_warps, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv, 0, page_iter_base, kv_len, last_indptr);
page_produce_kv<true, num_warps, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv, 0, packed_page_iter_base, kv_len, last_indptr);
cp_async::commit_group();

const uint32_t num_iterations = ceil_div(
@@ -1508,10 +1481,10 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(s_frag, o_frag, m, d);

block.sync();
page_iter_base += 16 * num_frags_z / page_size;
page_produce_kv<false, page_size, num_warps, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, page_iter_base, kv_len,
last_indptr);
packed_page_iter_base += 16 * num_frags_z;
page_produce_kv<false, num_warps, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, packed_page_iter_base,
kv_len, last_indptr);
cp_async::commit_group();
cp_async::wait_group<1>();
block.sync();
@@ -1521,9 +1494,9 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
o_frag, d);

block.sync();
page_produce_kv<true, page_size, num_warps, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, page_iter_base, kv_len,
last_indptr);
page_produce_kv<true, num_warps, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, packed_page_iter_base,
kv_len, last_indptr);
cp_async::commit_group();
}
cp_async::wait_group<0>();
@@ -1776,7 +1749,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
return cudaSuccess;
}

template <PageStorage page_storage, uint32_t num_frags_x, uint32_t PAGE_SIZE, uint32_t HEAD_DIM,
template <PageStorage page_storage, uint32_t num_frags_x, uint32_t HEAD_DIM,
LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut,
typename IdType>
@@ -1831,8 +1804,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
throw std::invalid_argument(err_msg.str());
} else {
auto kernel = BatchPrefillWithPagedKVCacheKernel<
LOGITS_POST_HOOK, PAGE_SIZE, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y,
num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>;
LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z,
num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>;
uint32_t smem_size =
(num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn);
FLASHINFER_CUDA_CALL(
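The prefill producer now keeps a single packed counter in units of KV entries (packed_page_iter_base, starting at indptr[request_idx] * page_size and advancing by a fixed 16 * num_frags_z per iteration) and decomposes it per access with page_size.divmod. Below is a rough host-side reference of that traversal, with plain division and hypothetical names, just to show why the compile-time page_size parameter (and the page_size % 4 == 0 specialization) is no longer needed.

#include <cstdint>
#include <cstdio>

// Walks a request's KV cache one tile of entries at a time, for any runtime
// page_size. entries_per_tile plays the role of 16 * num_frags_z.
void walk_kv_tiles(uint32_t indptr_begin, uint32_t page_size, uint32_t kv_len,
                   uint32_t entries_per_tile) {
  uint32_t packed = indptr_begin * page_size;  // packed_page_iter_base
  for (uint32_t consumed = 0; consumed < kv_len; consumed += entries_per_tile) {
    for (uint32_t e = 0; e < entries_per_tile; ++e) {
      if (consumed + e >= kv_len) break;  // the kernel predicates these loads instead
      uint32_t page_iter = (packed + e) / page_size;  // index into paged_kv.indices
      uint32_t entry_idx = (packed + e) % page_size;  // slot within that page
      std::printf("kv %u -> page slot %u, entry %u\n", consumed + e, page_iter, entry_idx);
    }
    packed += entries_per_tile;  // matches packed_page_iter_base += 16 * num_frags_z
  }
}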
1 change: 1 addition & 0 deletions include/flashinfer/fastdiv.cuh
@@ -19,6 +19,7 @@
*/
#ifndef FLASHINFER_FASTDIV_CUH_
#define FLASHINFER_FASTDIV_CUH_
#include <cstdint>

namespace flashinfer {

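The only change to fastdiv.cuh itself is the <cstdint> include, presumably so the header no longer relies on a transitive include for the fixed-width types it uses. For context on what the uint32_fastdiv wrapper buys the kernels above, here is a sketch of division by a runtime-constant divisor via a precomputed reciprocal. The actual uint32_fastdiv precomputes its own constants (which is why the divisor is wrapped once on the host and then reused everywhere); the version below uses a 64-bit reciprocal and the GCC/Clang __int128 extension purely to show the principle: one real division at construction, then each divmod is a multiply, a shift, and a subtract.

#include <cassert>
#include <cstdint>

struct fastdiv_sketch {
  uint64_t c;  // ceil(2^64 / d), valid for d >= 2
  uint32_t d;
  explicit fastdiv_sketch(uint32_t divisor) : d(divisor) {
    assert(divisor != 0);
    c = UINT64_MAX / divisor + 1;
  }
  uint32_t div(uint32_t n) const {
    if (d == 1) return n;  // ceil(2^64 / 1) does not fit in 64 bits
    // high 64 bits of the 128-bit product; exact for all 32-bit n and d
    return static_cast<uint32_t>((static_cast<unsigned __int128>(c) * n) >> 64);
  }
  void divmod(uint32_t n, uint32_t& q, uint32_t& r) const {
    q = div(n);
    r = n - q * d;
  }
};

For example, fastdiv_sketch page_size(16); page_size.divmod(37, q, r); yields q = 2 and r = 5.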