[Refactor] Formalize NHD/HND layout annotation (#85)
We support two different layout annotations (NHD and HND; N: sequence
length, H: number of heads, D: head dimension) for QKV matrices. The HND
layout is beneficial when $D$ is small or the data type bit-width is
small, in which case a contiguous length-$D$ vector in the NHD layout
cannot fill a cache line.
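
As a concrete sketch of the two addressing schemes (this standalone helper mirrors the stride logic in include/flashinfer/layout.cuh, but is illustrative rather than the library API):

#include <cstddef>
#include <cstdint>

enum class QKVLayout { kNHD, kHND };

// Element offset into an [N, H, D] (kNHD) or [H, N, D] (kHND) tensor.
// Under NHD, one head's length-D vector is followed by other heads' data,
// so when D * sizeof(dtype) is below the cache-line size, part of each
// fetched line is wasted; under HND, all rows of one head are contiguous.
inline size_t elem_offset(QKVLayout layout, uint32_t seq_idx, uint32_t head_idx,
                          uint32_t feat_idx, uint32_t seq_len, uint32_t num_heads,
                          uint32_t head_dim) {
  if (layout == QKVLayout::kNHD) {
    return (size_t(seq_idx) * num_heads + head_idx) * head_dim + feat_idx;
  }
  return (size_t(head_idx) * seq_len + seq_idx) * head_dim + feat_idx;  // kHND
}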

However, the HND layout is not useful for the query matrix, since we
access the query only once and pin its values in registers/shared memory.
The natural layout of the query matrix is NHD, the direct output of
$x \cdot W_q$, while the KV-Cache (whether a paged or a ragged tensor) may
use a different layout.

In this PR we formalize the use of NHD/HND layout annotations:
1. The query matrix always uses the NHD layout, so it needs no annotation.
2. The KV-Cache can use either the NHD or the HND layout; users should
specify which one.
3. Layout annotations can be hidden from users in the Python APIs because
the layout can be inferred from the tensor shape (see the sketch below).
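
As an illustration of point 3, a minimal sketch of shape-based inference, assuming a 3-D KV tensor with a known head count and head dimension (the helper name and error handling are hypothetical, not part of this PR):

#include <array>
#include <cstdint>
#include <stdexcept>

enum class QKVLayout { kNHD, kHND };

// Hypothetical helper: decide whether a KV tensor of the given shape is
// [seq_len, num_heads, head_dim] (kNHD) or [num_heads, seq_len, head_dim]
// (kHND). When seq_len == num_heads the shape is ambiguous and an explicit
// annotation would still be required.
inline QKVLayout infer_kv_layout(const std::array<int64_t, 3>& shape,
                                 int64_t num_heads, int64_t head_dim) {
  if (shape[2] != head_dim) throw std::invalid_argument("last dim must equal head_dim");
  if (shape[1] == num_heads && shape[0] != num_heads) return QKVLayout::kNHD;
  if (shape[0] == num_heads && shape[1] != num_heads) return QKVLayout::kHND;
  throw std::invalid_argument("ambiguous shape; specify the layout explicitly");
}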

This PR also adds support for the NHD paged-kv cache (before this PR, only
the HND paged-kv cache was supported).
yzh119 authored Jan 24, 2024
1 parent d10f082 commit 9f49803
Showing 28 changed files with 720 additions and 527 deletions.
103 changes: 52 additions & 51 deletions include/flashinfer/decode.cuh

Large diffs are not rendered by default.

58 changes: 31 additions & 27 deletions include/flashinfer/handler.cuh
@@ -82,14 +82,16 @@ class BatchDecodeHandler {
}
}

-template <PageStorage page_storage, typename DTypeIn, typename DTypeOut, typename IdType>
+template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
+typename IdType>
cudaError_t BeginForward(IdType* indptr, IdType* last_page_len, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
uint32_t page_size, RotaryMode rotary_mode) {
batch_size_before_partition_ = batch_size;
uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
auto work_estimation_func =
-BatchDecodeWithPagedKVCacheWorkEstimation<page_storage, DTypeIn, DTypeOut, IdType>;
+BatchDecodeWithPagedKVCacheWorkEstimation<page_storage, kv_layout, DTypeIn, DTypeOut,
+IdType>;
FLASHINFER_CUDA_CALL(work_estimation_func(
tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr,
num_qo_heads, num_kv_heads, head_dim, page_size, rotary_mode, stream_));
@@ -234,6 +236,7 @@ class BatchPrefillHandler {
* \brief Wrapper of BatchDecodeWithPagedKVCache function, and caches the temporary buffer
* for cooperative kernels.
* \tparam page_storage Whether to store indices or pointers of each active page
+* \tparam kv_layout The layout of last 3 dimensions in KV-Cache
* \tparam DTypeIn The data type of input tensor.
* \tparam DTypeOut The data type of output tensor.
* \tparam IdType The data type of index tensor.
@@ -250,14 +253,14 @@ class BatchPrefillHandler {
* \note This wrapper function should be only called after we call BeginForward function in the
* BatchDecodeHandler.
*/
-template <PageStorage page_storage, typename DTypeIn, typename DTypeOut, typename IdType>
-cudaError_t BatchDecodeWithPagedKVCacheWrapper(BatchDecodeHandler* handler, DTypeIn* q,
-paged_kv_t<page_storage, DTypeIn, IdType> paged_kv,
-DTypeOut* o, float* lse, uint32_t num_qo_heads,
-RotaryMode rotary_mode = RotaryMode::kNone,
-float rope_scale = 1.f, float rope_theta = 1e4,
-cudaStream_t stream = nullptr) {
-paged_kv_t<page_storage, DTypeIn, IdType> new_paged_kv = paged_kv;
+template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
+typename IdType>
+cudaError_t BatchDecodeWithPagedKVCacheWrapper(
+BatchDecodeHandler* handler, DTypeIn* q,
+paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
+uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f,
+float rope_theta = 1e4, cudaStream_t stream = nullptr) {
+paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> new_paged_kv = paged_kv;
kv_partition_info_t<IdType> kv_partition_info;
DTypeOut* tmp = handler->GetTempFloatBuffer<DTypeOut>();
if (handler->IsForwardStarted()) {
@@ -278,17 +281,17 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(BatchDecodeHandler* handler, DTyp
"BatchDecodeWithPagedKVCacheWrapper()";
throw std::runtime_error(err_msg.str());
}
-return BatchDecodeWithPagedKVCache<page_storage, DTypeIn, DTypeOut, IdType>(
+return BatchDecodeWithPagedKVCache<page_storage, kv_layout, DTypeIn, DTypeOut, IdType>(
q, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode, rope_scale,
rope_theta, stream);
}

-template <PageStorage page_storage, uint32_t GROUP_SIZE, uint32_t HEAD_DIM, RotaryMode ROTARY_MODE,
-bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn, typename DTypeOut,
-typename IdType>
+template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
+RotaryMode ROTARY_MODE, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn,
+typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
-paged_kv_t<page_storage, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
+paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
uint32_t num_qo_heads, float rope_scale = 1.f, float rope_theta = 1e4,
cudaStream_t stream = nullptr) {
float* tmp = nullptr;
@@ -312,13 +315,13 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
num_frags_x, NUM_FRAGS_X, {SWITCH_PAGE_SIZE(paged_kv.page_size, PAGE_SIZE, {
if constexpr (PAGE_SIZE == 0) {
return BatchPrefillWithPagedKVCacheFallbackDispatched<
-page_storage, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION,
-CAUSAL, DTypeIn, DTypeOut, IdType>(q, request_indices, tile_indices, qo_indptr,
-paged_kv, o, tmp, lse, num_qo_tiles, rope_scale,
-rope_theta, stream);
+page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
+ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
+q, request_indices, tile_indices, qo_indptr, paged_kv, o, tmp, lse, num_qo_tiles,
+rope_scale, rope_theta, stream);
} else {
return BatchPrefillWithPagedKVCacheDispatched<
-page_storage, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
+page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, paged_kv, o, tmp, lse, num_qo_tiles,
rope_scale, rope_theta, stream);
@@ -327,10 +330,11 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
return cudaSuccess;
}

-template <PageStorage page_storage, typename DTypeIn, typename DTypeOut, typename IdType>
+template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
+typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapper(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr,
-paged_kv_t<page_storage, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
+paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4,
cudaStream_t stream = nullptr) {
@@ -346,15 +350,15 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper(
{SWITCH_ALLOW_FP16_QK_REDUCTION(
allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {
return BatchPrefillWithPagedKVCacheWrapperDispatched<
-page_storage, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
+page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
handler, q, qo_indptr, paged_kv, o, lse, num_qo_heads,
rope_scale, rope_theta, stream);
})})})})});
return cudaSuccess;
}

-template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout LAYOUT, RotaryMode ROTARY_MODE,
+template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT, RotaryMode ROTARY_MODE,
bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
@@ -380,7 +384,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
}

SWITCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, {
-return BatchPrefillWithRaggedKVCacheDispatched<NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, LAYOUT,
+return BatchPrefillWithRaggedKVCacheDispatched<NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, KV_LAYOUT,
ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL,
DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, o, tmp, lse, batch_size,
@@ -397,7 +401,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
bool allow_fp16_qk_reduction = false, const float rope_scale = 1.f,
const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
-constexpr QKVLayout LAYOUT = QKVLayout::kNHD;
+constexpr QKVLayout KV_LAYOUT = QKVLayout::kNHD;
SWITCH_GQA_GROUP_SIZE(
num_qo_heads / num_kv_heads, GROUP_SIZE,
{SWITCH_HEAD_DIM(
@@ -408,7 +412,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
{SWITCH_ALLOW_FP16_QK_REDUCTION(
allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {
return BatchPrefillWithRaggedKVCacheWrapperDispatched<
-GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE,
+GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
handler, q, qo_indptr, k, v, kv_indptr, o, lse, batch_size,
num_kv_heads, rope_scale, rope_theta, stream);
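
Putting the refactored handler.cuh signatures together, a call site might look like the sketch below. The paged_kv_t construction is not shown in this diff, and PageStorage::kIndices, the half data types, and passing nullptr for lse are assumptions, not a verbatim excerpt of the library:

#include <cstdint>
#include <cuda_fp16.h>
#include <flashinfer/handler.cuh>

using namespace flashinfer;

// Hypothetical call site: the KV layout now rides along as a template
// parameter next to page_storage, while the query is implicitly NHD.
void decode_sketch(BatchDecodeHandler* handler, half* q,
                   paged_kv_t<PageStorage::kIndices, QKVLayout::kNHD, half, int32_t>& paged_kv,
                   half* o, uint32_t num_qo_heads) {
  cudaError_t status =
      BatchDecodeWithPagedKVCacheWrapper<PageStorage::kIndices, QKVLayout::kNHD,
                                         half, half, int32_t>(
          handler, q, paged_kv, o, /*lse=*/nullptr, num_qo_heads);
  (void)status;  // callers should check against cudaSuccess
}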
24 changes: 12 additions & 12 deletions include/flashinfer/layout.cuh
@@ -51,7 +51,7 @@ __host__ __device__ __forceinline__ uint32_t get_h_stride_impl(uint32_t seq_len)
return layout == QKVLayout::kNHD ? head_dim : seq_len * head_dim;
}

-template <QKVLayout layout, uint32_t group_size, uint32_t head_dim>
+template <QKVLayout kv_layout, uint32_t group_size, uint32_t head_dim>
struct tensor_info_t {
uint32_t qo_len;
uint32_t kv_len;
@@ -69,40 +69,40 @@ struct tensor_info_t {
__host__ __device__ __forceinline__ size_t get_qo_elem_offset(uint32_t qo_idx,
uint32_t qo_head_idx,
uint32_t feat_idx) const {
-return get_elem_offset_impl<layout, head_dim>(qo_idx, qo_head_idx, feat_idx, qo_len,
-get_num_qo_heads());
+return get_elem_offset_impl<QKVLayout::kNHD, head_dim>(qo_idx, qo_head_idx, feat_idx, qo_len,
+get_num_qo_heads());
}

__host__ __device__ __forceinline__ size_t get_kv_elem_offset(uint32_t kv_idx,
uint32_t kv_head_idx,
uint32_t feat_idx) const {
-return get_elem_offset_impl<layout, head_dim>(kv_idx, kv_head_idx, feat_idx, kv_len,
-num_kv_heads);
+return get_elem_offset_impl<kv_layout, head_dim>(kv_idx, kv_head_idx, feat_idx, kv_len,
+num_kv_heads);
}

__host__ __device__ __forceinline__ uint32_t get_qo_n_stride() const {
-return get_n_stride_impl<layout, head_dim>(get_num_qo_heads());
+return get_n_stride_impl<QKVLayout::kNHD, head_dim>(get_num_qo_heads());
}

__host__ __device__ __forceinline__ uint32_t get_kv_n_stride() const {
-return get_n_stride_impl<layout, head_dim>(num_kv_heads);
+return get_n_stride_impl<kv_layout, head_dim>(num_kv_heads);
}

__host__ __device__ __forceinline__ uint32_t get_qo_h_stride() const {
-return get_h_stride_impl<layout, head_dim>(qo_len);
+return get_h_stride_impl<QKVLayout::kNHD, head_dim>(qo_len);
}

__host__ __device__ __forceinline__ uint32_t get_kv_h_stride() const {
-return get_h_stride_impl<layout, head_dim>(kv_len);
+return get_h_stride_impl<kv_layout, head_dim>(kv_len);
}
};

/*!
* \brief Convert QKVLayout to string
-* \param qkv_layout The QKVLayout to convert
+* \param layout The QKVLayout to convert
*/
-inline std::string QKVLayoutToString(const QKVLayout& qkv_layout) {
-switch (qkv_layout) {
+inline std::string QKVLayoutToString(const QKVLayout& layout) {
+switch (layout) {
case QKVLayout::kNHD:
return "NHD";
case QKVLayout::kHND:
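
The tensor_info_t changes above hinge on two stride helpers. Restated as a standalone sketch (the same formulas as get_n_stride_impl/get_h_stride_impl in layout.cuh, with free-function names of our own):

#include <cstdint>

enum class QKVLayout { kNHD, kHND };

// Stride in elements between consecutive sequence positions (n) and
// consecutive heads (h) for each layout:
//   kNHD [N, H, D]: n_stride = H * D, h_stride = D
//   kHND [H, N, D]: n_stride = D,     h_stride = N * D
inline uint32_t n_stride(QKVLayout layout, uint32_t num_heads, uint32_t head_dim) {
  return layout == QKVLayout::kNHD ? num_heads * head_dim : head_dim;
}
inline uint32_t h_stride(QKVLayout layout, uint32_t seq_len, uint32_t head_dim) {
  return layout == QKVLayout::kNHD ? head_dim : seq_len * head_dim;
}

After this PR, get_qo_n_stride()/get_qo_h_stride() always evaluate these with QKVLayout::kNHD (rule 1 above), while the kv variants follow the kv_layout template parameter.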