Commit

[Performance] Fix the fp8 decode kernel performance degradation issue (#84)

The fp8 decode kernel's performance degraded because of changes in a previous commit
(2a3d6d0). This PR fixes the degradation.
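
The fix makes the per-bdx KV tile size depend on the element width: with GROUP_SIZE == 1, 1-byte (fp8) inputs now get a tile of 2 where 16-bit inputs keep 8 in the single-decode kernels (2 vs. 4 in the batch paged-KV kernels). Below is a minimal standalone sketch of the new heuristic, using placeholder 1-byte/2-byte aliases instead of the real fp8/fp16 types; it is an illustration only, while the actual kernels evaluate this inside their dispatch templates in decode.cuh.

#include <cstdint>
#include <cstdio>

// Placeholder element types for illustration only; the real kernels are
// instantiated with fp8 and half/bf16 types.
using elem1B = uint8_t;   // stands in for a 1-byte fp8 type
using elem2B = uint16_t;  // stands in for a 2-byte fp16/bf16 type

// Mirrors the new single-decode expression from this commit: with
// GROUP_SIZE == 1, 1-byte elements get a smaller per-bdx tile.
template <uint32_t GROUP_SIZE, typename DTypeIn>
constexpr uint32_t single_decode_tile_size_per_bdx() {
  return GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 8U) : 1U;
}

// The batch paged-KV kernels use the same shape with 4U in place of 8U.
template <uint32_t GROUP_SIZE, typename DTypeIn>
constexpr uint32_t batch_decode_tile_size_per_bdx() {
  return GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 4U) : 1U;
}

int main() {
  printf("single-decode, 2-byte: %u\n", single_decode_tile_size_per_bdx<1, elem2B>());  // 8
  printf("single-decode, 1-byte: %u\n", single_decode_tile_size_per_bdx<1, elem1B>());  // 2
  printf("batch-decode,  2-byte: %u\n", batch_decode_tile_size_per_bdx<1, elem2B>());   // 4
  printf("batch-decode,  1-byte: %u\n", batch_decode_tile_size_per_bdx<1, elem1B>());   // 2
  return 0;
}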
yzh119 authored Jan 21, 2024
1 parent 69e1d03 commit b498701
Showing 1 changed file with 7 additions and 4 deletions.
include/flashinfer/decode.cuh (7 additions, 4 deletions)
@@ -750,7 +750,8 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t&
   constexpr uint32_t num_threads =
       std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeIn)), bdx * bdy);
   constexpr uint32_t bdz = num_threads / (bdx * bdy);
-  constexpr uint32_t tile_size_per_bdx = 8U / GROUP_SIZE;
+  constexpr uint32_t tile_size_per_bdx =
+      GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 8U) : 1U;
   const uint32_t smem_size = 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz *
                                  head_dim * sizeof(DTypeIn) +
                              2U * bdy * bdz * sizeof(float);
@@ -832,7 +833,8 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
       std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeIn)), bdx * bdy);
   constexpr uint32_t bdz = num_threads / (bdx * bdy);
   tensor_info_t<QKV_LAYOUT, GROUP_SIZE, HEAD_DIM> info(1, seq_len, num_kv_heads);
-  constexpr uint32_t tile_size_per_bdx = 8U / GROUP_SIZE;
+  constexpr uint32_t tile_size_per_bdx =
+      GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 8U) : 1U;
   const uint32_t smem_size = 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz *
                                  head_dim * sizeof(DTypeIn) +
                              2U * bdy * bdz * sizeof(float);
@@ -1064,7 +1066,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimation(
   constexpr uint32_t bdy = GROUP_SIZE;
   constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
   constexpr uint32_t bdz = num_threads / (bdx * bdy);
-  constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? 4U : 1U;
+  constexpr uint32_t tile_size_per_bdx =
+      GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 4U) : 1U;
   const uint32_t smem_size =
       2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * sizeof(DTypeIn) +
       std::max(tile_size_per_bdx * num_threads * sizeof(DTypeIn*),
@@ -1134,7 +1137,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
   constexpr uint32_t bdy = GROUP_SIZE;
   constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
   constexpr uint32_t bdz = num_threads / (bdx * bdy);
-  constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? 4U : 1U;
+  constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 4U) : 1U;
   const uint32_t smem_size =
       2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) +
       std::max(tile_size_per_bdx * num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float));
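
For a sense of scale, here is a rough, self-contained calculation of the dynamic shared-memory request using the smem_size expression from the two batch-decode hunks above. The launch constants (head_dim, num_stages_smem, bdx) are assumptions chosen for illustration, not values taken from this commit.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main() {
  // Assumed configuration for a GROUP_SIZE == 1 batch-decode launch; these
  // constants are illustrative, not read from the diff.
  constexpr uint32_t head_dim = 128;
  constexpr uint32_t num_stages_smem = 2;
  constexpr uint32_t bdx = 32;
  constexpr uint32_t bdy = 1;  // GROUP_SIZE == 1
  constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
  constexpr uint32_t bdz = num_threads / (bdx * bdy);

  for (uint32_t elem_size : {1U /* fp8 */, 2U /* fp16/bf16 */}) {
    // New batch-decode heuristic: a tile of 2 for 1-byte elements, 4 otherwise.
    const uint32_t tile_size_per_bdx = (elem_size == 1U) ? 2U : 4U;
    const size_t smem_size =
        2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * elem_size +
        std::max<size_t>(tile_size_per_bdx * num_threads * sizeof(void*),  // sizeof(DTypeIn*) in the kernel
                         2 * bdy * bdz * sizeof(float));
    printf("elem_size=%uB  tile_size_per_bdx=%u  smem_size=%zu bytes\n",
           elem_size, tile_size_per_bdx, smem_size);
  }
  return 0;
}

With these assumed constants the request comes out to 20480 bytes for 2-byte elements and 6144 bytes for fp8, so the fp8 path no longer inherits the larger tile that the previous, width-agnostic heuristic gave it.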
