4 changes: 2 additions & 2 deletions include/flashinfer/attention/decode.cuh
@@ -487,7 +487,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
cp_async::commit_group();
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_offset_delta();
DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_ptr_delta();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
tx * vec_size,
@@ -554,7 +554,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
// load v tiles
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_offset_delta();
DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_ptr_delta();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim +
tx * vec_size,
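Both hunks in this file make the same one-line change: the value-cache pointer is now derived from the already-computed key-cache pointer via the renamed `kv_ptr_delta()` helper. A standalone sketch of that pointer arithmetic, using a kPointer-style fused page and toy sizes (plain host C++, all names local to this example):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t num_heads = 8, page_size = 16, head_dim = 128;
  // A kPointer-style fused page: [2, num_heads, page_size, head_dim],
  // the K half followed by the V half.
  std::vector<float> page(2 * num_heads * page_size * head_dim, 0.0f);
  const int64_t kv_ptr_delta = num_heads * page_size * head_dim;

  float* k_ptr = page.data() + 42;      // some K element inside the page
  float* v_ptr = k_ptr + kv_ptr_delta;  // the matching V element: one add away
  assert(v_ptr - page.data() == 42 + kv_ptr_delta);
  return 0;
}
```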
17 changes: 8 additions & 9 deletions include/flashinfer/attention/prefill.cuh
@@ -216,8 +216,7 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offs
static_assert(num_frags_z * 4 % num_warps_x == 0);
#pragma unroll
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) {
DType* gptr = produce_v ? paged_kv.data + paged_kv.kv_offset_delta() + kv_offset[i]
: paged_kv.data + kv_offset[i];
DType* gptr = produce_v ? paged_kv.v_data + kv_offset[i] : paged_kv.k_data + kv_offset[i];
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
@@ -1608,8 +1607,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
page_iter, entry_idx);
kv_offset[i] =
page_iter < last_indptr
? paged_kv.get_k_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx,
entry_idx, (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
? paged_kv.get_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
(lane_idx % 8) * num_elems_per_128b<DTypeIn>())
: 0;
}
page_produce_kv<false, num_warps_x, num_warps_z, num_frags_y, num_frags_z>(
@@ -1645,11 +1644,11 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage
paged_kv.page_size.divmod(
packed_page_iter_base + warp_idx * 4 + lane_idx / 8 + 4 * num_warps_x * num_warps_z * i,
page_iter, entry_idx);
kv_offset[i] = page_iter < last_indptr
? paged_kv.get_k_elem_offset(
__ldg(paged_kv.indices + page_iter), kv_head_idx, entry_idx,
(lane_idx % 8) * num_elems_per_128b<DTypeIn>())
: 0;
kv_offset[i] =
page_iter < last_indptr
? paged_kv.get_elem_offset(__ldg(paged_kv.indices + page_iter), kv_head_idx,
entry_idx, (lane_idx % 8) * num_elems_per_128b<DTypeIn>())
: 0;
}
cp_async::wait_group<1>();
block.sync();
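Both prefill hunks replace the K-specific `get_k_elem_offset` with the unified `get_elem_offset`; the `page_size.divmod(...)` call above splits a packed position along the KV sequence into a page iterator and an in-page entry index. A standalone sketch of that index arithmetic, assuming an NHD layout and purely illustrative values:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t page_size = 16, num_heads = 8, head_dim = 128;
  // NHD layout: entries are the outer in-page dimension.
  const size_t stride_page = size_t(num_heads) * page_size * head_dim;
  const size_t stride_n = size_t(num_heads) * head_dim;
  const size_t stride_h = head_dim;

  uint32_t packed = 37;                     // packed position along the kv sequence
  uint32_t page_iter = packed / page_size;  // which entry of indices[] to read
  uint32_t entry_idx = packed % page_size;  // which row inside that page

  size_t page_idx = 5;  // stand-in for __ldg(indices + page_iter)
  size_t head_idx = 2, feat_idx = 0;
  size_t offset = page_idx * stride_page + head_idx * stride_h +
                  entry_idx * stride_n + feat_idx;
  std::printf("page_iter=%u entry_idx=%u offset=%zu\n", page_iter, entry_idx, offset);
  return 0;
}
```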
145 changes: 81 additions & 64 deletions include/flashinfer/page.cuh
@@ -74,18 +74,20 @@ struct paged_kv_t {
uint32_t num_heads;
uint32_t head_dim;
uint32_t batch_size;
uint32_t stride_page;
uint32_t stride_n;
uint32_t stride_h;

// The flattened key / value cache buffers, used when page_storage == kIndices
// Internal layout:
// [max_num_pages, 2, num_heads, page_size, head_dim] if layout == HND
// [max_num_pages, 2, page_size, num_heads, head_dim] if layout == NHD
DType* data;
// [max_num_pages, num_heads, page_size, head_dim] if layout == HND
// [max_num_pages, page_size, num_heads, head_dim] if layout == NHD
DType* k_data;
DType* v_data;
// [nnz_pages] The page indices array, used when page_storage == kIndices
IdType* indices;
// [nnz_pages] The page pointers array, used when page_storage == kPointer
DType** ptrs;
DType** kv_ptrs;

// [batch_size + 1] The page indptr array, with the first element 0, the last element nnz_pages
IdType* indptr;
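With `k_data` and `v_data` split, each page now holds a single K (or V) tile of shape `[num_heads, page_size, head_dim]` (HND) or `[page_size, num_heads, head_dim]` (NHD), and the new `stride_page` together with the existing `stride_n`/`stride_h` fully describes the layout. A minimal sketch of how the three strides fall out of the two layouts:

```cpp
#include <cstdint>
#include <cstdio>
#include <initializer_list>

enum class QKVLayout { kNHD, kHND };

int main() {
  const uint32_t num_heads = 8, page_size = 16, head_dim = 128;
  for (QKVLayout layout : {QKVLayout::kNHD, QKVLayout::kHND}) {
    // One K (or V) page; the fused kPointer layout doubles this.
    uint32_t stride_page = num_heads * page_size * head_dim;
    uint32_t stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
    uint32_t stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
    std::printf("%s: stride_page=%u stride_n=%u stride_h=%u\n",
                layout == QKVLayout::kHND ? "HND" : "NHD",
                stride_page, stride_n, stride_h);
  }
  return 0;
}
```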
@@ -102,11 +104,13 @@ struct paged_kv_t {
page_size(0),
head_dim(0),
batch_size(0),
stride_page(0),
stride_n(0),
stride_h(0),
data(nullptr),
k_data(nullptr),
v_data(nullptr),
indices(nullptr),
ptrs(nullptr),
kv_ptrs(nullptr),
indptr(nullptr),
last_page_len(nullptr),
rope_pos_offset(nullptr) {}
@@ -118,26 +122,29 @@ struct paged_kv_t {
* \param head_dim The dimension of each head
* \param batch_size The batch size
* \param layout The layout of last 3 dimensions in KV-Cache.
* \param data The flattened key-value cache
* \param k_data The flattened key cache
* \param v_data The flattened value cache
* \param indices The page indices array
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
* \param rope_pos_offset The start position of each request in the batch.
* \note This constructor should only be used when page_storage == kIndices
*/
__host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
uint32_t batch_size, QKVLayout layout, DType* data,
IdType* indices, IdType* indptr, IdType* last_page_len,
IdType* rope_pos_offset = nullptr)
uint32_t batch_size, QKVLayout layout, DType* k_data,
DType* v_data, IdType* indices, IdType* indptr,
IdType* last_page_len, IdType* rope_pos_offset = nullptr)
: num_heads(num_heads),
page_size(page_size),
head_dim(head_dim),
batch_size(batch_size),
data(data),
k_data(k_data),
v_data(v_data),
indices(indices),
indptr(indptr),
last_page_len(last_page_len),
rope_pos_offset(rope_pos_offset) {
stride_page = num_heads * page_size * head_dim;
stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
}
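For reference, a minimal host-side sketch of calling this updated kIndices constructor. It assumes the flashinfer headers are on the include path and that the device buffers were allocated and filled elsewhere; the buffer names and the `float`/`int32_t` instantiation are hypothetical, chosen only for illustration:

```cpp
#include <cstdint>

#include <flashinfer/page.cuh>

using namespace flashinfer;

paged_kv_t<PageStorage::kIndices, float, int32_t> make_cache(
    uint32_t num_heads, uint32_t page_size, uint32_t head_dim, uint32_t batch_size,
    float* k_buf, float* v_buf, int32_t* indices, int32_t* indptr,
    int32_t* last_page_len) {
  // K and V now live in two independent tensors; the constructor derives
  // stride_page / stride_n / stride_h from the runtime layout argument.
  return paged_kv_t<PageStorage::kIndices, float, int32_t>(
      num_heads, page_size, head_dim, batch_size, QKVLayout::kNHD,
      k_buf, v_buf, indices, indptr, last_page_len);
}
```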
@@ -149,92 +156,100 @@ struct paged_kv_t {
* \param head_dim The dimension of each head
* \param batch_size The batch size
* \param layout The layout of last 3 dimensions in KV-Cache.
* \param ptrs The array of pointers to each active page
* \param kv_data The flattened key-value cache
* \param indices The page indices array
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
* \param rope_pos_offset The start position of each request in the batch.
* \note This constructor should only be used when page_storage == kIndices
*/
__host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
uint32_t batch_size, QKVLayout layout, DType** ptrs,
IdType* indptr, IdType* last_page_len,
uint32_t batch_size, QKVLayout layout, DType* kv_data,
IdType* indices, IdType* indptr, IdType* last_page_len,
IdType* rope_pos_offset = nullptr)
: num_heads(num_heads),
page_size(page_size),
head_dim(head_dim),
batch_size(batch_size),
ptrs(ptrs),
k_data(kv_data),
v_data(kv_data + num_heads * page_size * head_dim),
indices(indices),
indptr(indptr),
last_page_len(last_page_len),
rope_pos_offset(rope_pos_offset) {
stride_page = 2 * num_heads * page_size * head_dim;
stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
}

/*!
* \brief Compute the offset of k element in the allocated buffer.
* \param page_idx The page index
* \param head_idx The head index
* \param entry_idx The page entry index
* \param feat_idx The feature index
* \note This function should only be used when page_storage == kIndices
* \brief Construct a paged key-value cache
* \param num_heads The number of heads
* \param page_size The size of each page
* \param head_dim The dimension of each head
* \param batch_size The batch size
* \param layout The layout of last 3 dimensions in KV-Cache.
* \param kv_ptrs The array of pointers to each active kv page
* \param indptr The page indptr array
* \param last_page_len The offset of the last page for each request in the batch
* \param rope_pos_offset The start position of each request in the batch.
* \note This constructor should only be used when page_storage == kPointer
*/
__host__ __device__ __forceinline__ size_t get_k_elem_offset(size_t page_idx, size_t head_idx,
size_t entry_idx,
size_t feat_idx) const {
return page_idx * 2 * page_size * num_heads * head_dim + head_idx * stride_h +
entry_idx * stride_n + feat_idx;
__host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
uint32_t batch_size, QKVLayout layout, DType** kv_ptrs,
IdType* indptr, IdType* last_page_len,
IdType* rope_pos_offset = nullptr)
: num_heads(num_heads),
page_size(page_size),
head_dim(head_dim),
batch_size(batch_size),
kv_ptrs(kv_ptrs),
indptr(indptr),
last_page_len(last_page_len),
rope_pos_offset(rope_pos_offset) {
stride_page = 2 * num_heads * page_size * head_dim;
stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
}

/*!
* \brief Compute the offset of k element inside the page.
* \param head_idx The head index
* \param entry_idx The page entry index
* \param feat_idx The feature index
*/
__host__ __device__ __forceinline__ size_t get_k_elem_offset_in_page(size_t head_idx,
size_t entry_idx,
size_t feat_idx) const {
return head_idx * stride_h + entry_idx * stride_n + feat_idx;
__host__ __device__ __forceinline__ int64_t kv_ptr_delta() const {
return page_storage == PageStorage::kPointer
? num_heads * page_size * head_dim
: (int64_t(v_data) - int64_t(k_data)) / sizeof(DType);
}

/*!
* \brief Compute the offset of v element in the allocated buffer.
* \brief Compute the offset of element in the allocated buffer.
* \param page_idx The page index
* \param head_idx The head index
* \param entry_idx The page entry index
* \param feat_idx The feature index
* \note This function should only be used when page_storage == kIndices
*/
__host__ __device__ __forceinline__ size_t get_v_elem_offset(size_t page_idx, size_t head_idx,
size_t entry_idx,
size_t feat_idx) const {
return (page_idx * 2 + 1) * page_size * num_heads * head_dim + head_idx * stride_h +
entry_idx * stride_n + feat_idx;
__host__ __device__ __forceinline__ size_t get_elem_offset(size_t page_idx, size_t head_idx,
size_t entry_idx,
size_t feat_idx) const {
return page_idx * stride_page + head_idx * stride_h + entry_idx * stride_n + feat_idx;
}

/*!
* \brief Compute the offset of v element inside the page.
* \brief Compute the offset of element inside the page.
* \param head_idx The head index
* \param entry_idx The page entry index
* \param feat_idx The feature index
*/
__host__ __device__ __forceinline__ size_t get_v_elem_offset_in_page(size_t head_idx,
size_t entry_idx,
size_t feat_idx) const {
__host__ __device__ __forceinline__ size_t get_elem_offset_in_page(size_t head_idx,
size_t entry_idx,
size_t feat_idx) const {
return head_idx * stride_h + entry_idx * stride_n + feat_idx;
}

__host__ __device__ __forceinline__ uint32_t kv_offset_delta() const {
return num_heads * page_size * head_dim;
}

__device__ __forceinline__ DType* get_k_ptr(IdType page_iter, uint32_t head_idx,
uint32_t entry_idx, uint32_t feat_idx) const {
if constexpr (page_storage == PageStorage::kIndices) {
return data + get_k_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return ptrs[page_iter] + get_k_elem_offset_in_page(head_idx, entry_idx, feat_idx);
return kv_ptrs[page_iter] + get_elem_offset_in_page(head_idx, entry_idx, feat_idx);
}
}
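The accessors above make `kv_ptr_delta()` the single bridge from K to V: a fixed half-page stride for kPointer storage, and the element distance between the two base tensors for kIndices. A standalone sketch of both cases (mirroring the logic, not the library code itself):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t num_heads = 4, page_size = 8, head_dim = 64;
  const int64_t half_page = num_heads * page_size * head_dim;

  // kPointer case: K and V are the two halves of one fused page, so the
  // delta is the fixed half-page size in elements.
  std::vector<float> fused(2 * half_page, 0.0f);
  assert((fused.data() + half_page) - fused.data() == half_page);

  // kIndices case: K and V are separate tensors, so the delta is the element
  // distance between the base pointers (both carved out of one backing buffer
  // here purely so the pointer difference is well defined).
  std::vector<float> backing(10 * half_page, 0.0f);
  float* k_data = backing.data();
  float* v_data = backing.data() + 3 * half_page;  // arbitrary placement
  int64_t delta = v_data - k_data;
  assert(k_data + delta == v_data);
  return 0;
}
```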

@@ -243,25 +258,26 @@ struct paged_kv_t {
IdType last_indptr) const {
if constexpr (page_storage == PageStorage::kIndices) {
if (page_iter < last_indptr) {
return data + get_k_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return data;
return k_data;
}
} else {
if (page_iter < last_indptr) {
return ptrs[page_iter] + get_k_elem_offset_in_page(head_idx, entry_idx, feat_idx);
return kv_ptrs[page_iter] + get_elem_offset_in_page(head_idx, entry_idx, feat_idx);
} else {
return *ptrs;
return *kv_ptrs;
}
}
}

__device__ __forceinline__ DType* get_v_ptr(IdType page_iter, uint32_t head_idx,
uint32_t entry_idx, uint32_t feat_idx) const {
if constexpr (page_storage == PageStorage::kIndices) {
return data + get_v_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return ptrs[page_iter] + get_v_elem_offset_in_page(head_idx, entry_idx, feat_idx);
return (kv_ptrs[page_iter] + kv_ptr_delta()) +
get_elem_offset_in_page(head_idx, entry_idx, feat_idx);
}
}

Expand All @@ -270,15 +286,16 @@ struct paged_kv_t {
IdType last_indptr) const {
if constexpr (page_storage == PageStorage::kIndices) {
if (page_iter < last_indptr) {
return data + get_v_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
} else {
return data;
return v_data;
}
} else {
if (page_iter < last_indptr) {
return ptrs[page_iter] + get_v_elem_offset_in_page(head_idx, entry_idx, feat_idx);
return (kv_ptrs[page_iter] + kv_ptr_delta()) +
get_elem_offset_in_page(head_idx, entry_idx, feat_idx);
} else {
return *ptrs;
return *kv_ptrs;
}
}
}
@@ -312,7 +329,7 @@ __global__ void AppendPagedKVCacheDecodeKernel(paged_kv_t<page_storage, DType, I
uint32_t entry_idx = (seq_len - 1) % paged_kv.page_size;

DType* k_ptr = paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
DType* v_ptr = k_ptr + paged_kv.kv_offset_delta();
DType* v_ptr = paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
vec_t<DType, vec_size>::memcpy(
k_ptr, key + (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size);

@@ -355,7 +372,7 @@ __global__ void AppendPagedKVCachePrefillKernel(paged_kv_t<page_storage, DType,
uint32_t entry_idx = page_seq_idx % paged_kv.page_size;

DType* k_ptr = paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
DType* v_ptr = k_ptr + paged_kv.kv_offset_delta();
DType* v_ptr = paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
vec_t<DType, vec_size>::memcpy(
k_ptr,
key + ((append_indptr[batch_idx] + j) * num_heads + head_idx) * head_dim + tx * vec_size);
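Both append kernels now obtain the V pointer through `get_v_ptr` instead of offsetting the K pointer by hand, which stays correct under either storage scheme. A simplified host-side sketch of what the decode-append path computes per batch element and head, assuming NHD layout with kIndices storage (not the kernel itself):

```cpp
#include <cstdint>
#include <vector>

// Append one decode-step token per request into split K/V page buffers.
void append_decode_token_cpu(std::vector<float>& k_data, std::vector<float>& v_data,
                             const std::vector<int32_t>& indices,
                             const std::vector<int32_t>& indptr,
                             const std::vector<int32_t>& last_page_len,
                             const std::vector<float>& key, const std::vector<float>& value,
                             uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
                             uint32_t batch_size) {
  const size_t stride_page = size_t(num_heads) * page_size * head_dim;
  const size_t stride_n = size_t(num_heads) * head_dim;  // NHD
  const size_t stride_h = head_dim;
  for (uint32_t b = 0; b < batch_size; ++b) {
    int32_t page_iter = indptr[b + 1] - 1;      // last page of request b
    uint32_t entry_idx = last_page_len[b] - 1;  // last used row in that page
    size_t page_idx = indices[page_iter];
    for (uint32_t h = 0; h < num_heads; ++h) {
      size_t off = page_idx * stride_page + h * stride_h + entry_idx * stride_n;
      for (uint32_t f = 0; f < head_dim; ++f) {
        k_data[off + f] = key[(b * num_heads + h) * head_dim + f];
        v_data[off + f] = value[(b * num_heads + h) * head_dim + f];
      }
    }
  }
}
```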
4 changes: 2 additions & 2 deletions src/cpu_reference.h
@@ -180,10 +180,10 @@ void append_paged_kv_cache(paged_kv_t<PageStorage::kIndices, T, IdxType> page_cp
for (size_t h = 0; h < num_heads; ++h) {
std::copy(ki.begin() + (j * num_heads + h) * head_dim,
ki.begin() + (j * num_heads + h + 1) * head_dim,
page_cpu.data + page_cpu.get_k_elem_offset(page_idx, h, entry_idx, 0));
page_cpu.k_data + page_cpu.get_elem_offset(page_idx, h, entry_idx, 0));
std::copy(vi.begin() + (j * num_heads + h) * head_dim,
vi.begin() + (j * num_heads + h + 1) * head_dim,
page_cpu.data + page_cpu.get_v_elem_offset(page_idx, h, entry_idx, 0));
page_cpu.v_data + page_cpu.get_elem_offset(page_idx, h, entry_idx, 0));
}
}
}
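The CPU reference follows the same unified offset arithmetic: one `std::copy` per head into the K tensor and one into the V tensor. A minimal self-contained sketch of that pattern with toy sizes (not the test harness itself):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

int main() {
  const size_t num_heads = 2, page_size = 4, head_dim = 8;
  const size_t stride_page = num_heads * page_size * head_dim;
  const size_t stride_n = num_heads * head_dim, stride_h = head_dim;  // NHD

  std::vector<float> k_data(3 * stride_page), v_data(3 * stride_page);
  std::vector<float> ki(num_heads * head_dim, 1.0f), vi(num_heads * head_dim, 2.0f);

  const size_t page_idx = 1, entry_idx = 2;
  for (size_t h = 0; h < num_heads; ++h) {
    size_t off = page_idx * stride_page + h * stride_h + entry_idx * stride_n;
    std::copy(ki.begin() + h * head_dim, ki.begin() + (h + 1) * head_dim,
              k_data.begin() + off);
    std::copy(vi.begin() + h * head_dim, vi.begin() + (h + 1) * head_dim,
              v_data.begin() + off);
  }
  return 0;
}
```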