Skip to content

Commit

Permalink
feat: customize logits_soft_cap value (#339)
Browse files Browse the repository at this point in the history
This PR adds support for a user-specified logits soft cap value, since
different models use different caps (e.g. Grok-1 uses 30 and
Gemma-2 uses 50).
  • Loading branch information
yzh119 authored Jun 28, 2024
1 parent 3afb6d3 commit a2498f5
Show file tree
Hide file tree
Showing 23 changed files with 377 additions and 282 deletions.
51 changes: 32 additions & 19 deletions include/flashinfer/attention/decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage
const vec_t<float, vec_size>& freq, uint32_t kv_idx_base,
uint32_t iter_base, uint32_t iter_bound,
const int32_t q_offset, float alibi_slope, float* s,
state_t<vec_size>& st) {
state_t<vec_size>& st, const float logits_soft_cap) {
uint32_t tx = threadIdx.x, tz = threadIdx.z;
float m_prev = st.m;
#pragma unroll
Expand All @@ -100,7 +100,7 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage
s[j] += math::shfl_xor_sync(s[j], offset);
}
s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4;
s[j] = apply_logits_post_hook<logits_post_hook>(s[j]);
s[j] = apply_logits_post_hook<logits_post_hook>(s[j], logits_soft_cap);
if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) {
s[j] += alibi_slope * float(int(kv_idx_base + tz * tile_size + j) - q_offset);
}
Expand Down Expand Up @@ -215,11 +215,13 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
DTypeKV* __restrict__ v, DTypeOut* __restrict__ o,
float* __restrict__ lse,
tensor_info_t<kv_layout, bdx * vec_size> info,
float sm_scale, float rope_rcp_scale,
float rope_rcp_theta, uint32_t kv_chunk_size) {
float logits_soft_cap, float sm_scale,
float rope_rcp_scale, float rope_rcp_theta,
uint32_t kv_chunk_size) {
auto block = cg::this_thread_block();
auto grid = cg::this_grid();
sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);
sm_scale *=
(logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap));

constexpr uint32_t head_dim = bdx * vec_size;
uint32_t kv_head_idx = blockIdx.y;
Expand Down Expand Up @@ -305,7 +307,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
compute_qk<logits_post_hook, pos_encoding_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size,
seq_len - 1, alibi_slope, s, st_local);
seq_len - 1, alibi_slope, s, st_local, logits_soft_cap);
block.sync();
// load k
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
Expand Down Expand Up @@ -364,10 +366,11 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(DTypeQ* __restrict__ q, DType
DTypeOut* __restrict__ o,
float* __restrict__ lse,
tensor_info_t<kv_layout, bdx * vec_size> info,
float sm_scale, float rope_rcp_scale,
float rope_rcp_theta) {
float logits_soft_cap, float sm_scale,
float rope_rcp_scale, float rope_rcp_theta) {
auto block = cg::this_thread_block();
sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);
sm_scale *=
(logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap));

constexpr uint32_t head_dim = bdx * vec_size;
uint32_t kv_head_idx = blockIdx.y;
Expand Down Expand Up @@ -442,7 +445,8 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(DTypeQ* __restrict__ q, DType
block.sync();
compute_qk<logits_post_hook, pos_encoding_mode, vec_size, bdx, bdy>(
k_smem + (stage_idx * bdz + tz) * bdy * head_dim, stage_idx, q_vec, freq,
consumer_kv_idx_base, iter * bdy * bdz, seq_len, seq_len - 1, alibi_slope, s, st_local);
consumer_kv_idx_base, iter * bdy * bdz, seq_len, seq_len - 1, alibi_slope, s, st_local,
logits_soft_cap);
block.sync();
// load k
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
Expand Down Expand Up @@ -523,10 +527,11 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
float* __restrict__ lse, bool* __restrict__ block_valid_mask, float sm_scale,
float rope_rcp_scale, float rope_rcp_theta) {
float* __restrict__ lse, bool* __restrict__ block_valid_mask, float logits_soft_cap,
float sm_scale, float rope_rcp_scale, float rope_rcp_theta) {
auto block = cg::this_thread_block();
sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : 1.f / 30.f);
sm_scale *=
(logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap));

constexpr uint32_t head_dim = bdx * vec_size;
const uint32_t batch_idx = blockIdx.x;
Expand Down Expand Up @@ -654,7 +659,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
freq,
(paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, q_offset_val, alibi_slope, s, st);
iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, q_offset_val, alibi_slope, s, st,
logits_soft_cap);
block.sync();

#pragma unroll
Expand Down Expand Up @@ -760,7 +766,8 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t seq_len,
float sm_scale, float rope_scale, float rope_theta,
float logits_soft_cap, float sm_scale,
float rope_scale, float rope_theta,
cudaStream_t stream) {
const float rope_rcp_scale = 1.f / rope_scale;
const float rope_rcp_theta = 1.f / rope_theta;
Expand Down Expand Up @@ -796,6 +803,7 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
(void*)&o,
(void*)&lse,
(void*)&info,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta,
Expand Down Expand Up @@ -835,6 +843,7 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
(void*)&tmp,
(void*)&tmp_lse,
(void*)&info,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta,
Expand All @@ -854,7 +863,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
DTypeQ* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads,
float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) {
float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
cudaStream_t stream) {
const float rope_rcp_scale = 1.f / rope_scale;
const float rope_rcp_theta = 1.f / rope_theta;
const uint32_t num_kv_heads = paged_kv.num_heads;
Expand Down Expand Up @@ -890,6 +900,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
(void*)&o,
(void*)&lse,
(void*)&block_valid_mask,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
Expand All @@ -910,6 +921,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
(void*)&tmp_v,
(void*)&tmp_s,
(void*)&block_valid_mask,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
Expand Down Expand Up @@ -949,9 +961,9 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
uint32_t num_kv_heads, float sm_scale,
float rope_scale, float rope_theta,
cudaStream_t stream) {
uint32_t num_kv_heads, float logits_soft_cap,
float sm_scale, float rope_scale,
float rope_theta, cudaStream_t stream) {
const float rope_rcp_scale = 1.f / rope_scale;
const float rope_rcp_theta = 1.f / rope_theta;

Expand Down Expand Up @@ -981,6 +993,7 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeK
(void*)&o,
(void*)&lse,
(void*)&info,
(void*)&logits_soft_cap,
(void*)&sm_scale,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
Expand Down
4 changes: 2 additions & 2 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
float* __restrict__ lse, bool* __restrict__ block_valid_mask, float sm_scale,
float rope_rcp_scale, float rope_rcp_theta);
float* __restrict__ lse, bool* __restrict__ block_valid_mask, float logits_soft_cap,
float sm_scale, float rope_rcp_scale, float rope_rcp_theta);

/*!
* \brief Compute the maximum number of pages per batch and the new batch size
Expand Down
28 changes: 16 additions & 12 deletions include/flashinfer/attention/logits_post_hook.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,43 +22,47 @@ namespace flashinfer {

enum class LogitsPostHook {
kNone = 0U,
kCap30 = 1U,
kSoftCap = 1U,
};

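/*!
* \brief Logits soft-cap function: returns (soft_cap * log2e) * tanh(x).
* Generalizes Grok's logits cap (which fixed the cap at 30) to a caller-supplied soft_cap.
* \ref
* https://github.com/xai-org/grok-1/blob/7050ed204b8206bb8645c7b7bbef7252f79561b0/model.py#L864-L865
*/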
__forceinline__ __device__ float logits_cap_30(float x) {
return (30 * math::log2e) * math::tanh(x);
__forceinline__ __device__ float logits_soft_cap_impl(float x, const float soft_cap) {
return (soft_cap * math::log2e) * math::tanh(x);
}

__forceinline__ __device__ half2 logits_cap_30(half2 x) {
return __hmul2(__float2half2_rn(30 * math::log2e), math::tanh(x));
__forceinline__ __device__ half2 logits_soft_cap_impl(half2 x, const float soft_cap) {
return __hmul2(__float2half2_rn(soft_cap * math::log2e), math::tanh(x));
}

template <LogitsPostHook mode, typename T>
__forceinline__ __device__ T apply_logits_post_hook(T x);
__forceinline__ __device__ T apply_logits_post_hook(T x, const float soft_cap);

template <>
__forceinline__ __device__ float apply_logits_post_hook<LogitsPostHook::kNone, float>(float x) {
__forceinline__ __device__ float apply_logits_post_hook<LogitsPostHook::kNone, float>(
float x, const float soft_cap) {
return x;
}

template <>
__forceinline__ __device__ float apply_logits_post_hook<LogitsPostHook::kCap30, float>(float x) {
return logits_cap_30(x);
__forceinline__ __device__ float apply_logits_post_hook<LogitsPostHook::kSoftCap, float>(
float x, const float soft_cap) {
return logits_soft_cap_impl(x, soft_cap);
}

template <>
__forceinline__ __device__ half2 apply_logits_post_hook<LogitsPostHook::kNone, half2>(half2 x) {
__forceinline__ __device__ half2
apply_logits_post_hook<LogitsPostHook::kNone, half2>(half2 x, const float soft_cap) {
return x;
}

template <>
__forceinline__ __device__ half2 apply_logits_post_hook<LogitsPostHook::kCap30, half2>(half2 x) {
return logits_cap_30(x);
__forceinline__ __device__ half2
apply_logits_post_hook<LogitsPostHook::kSoftCap, half2>(half2 x, const float soft_cap) {
return logits_soft_cap_impl(x, soft_cap);
}

} // namespace flashinfer
Expand Down
Loading

0 comments on commit a2498f5

Please sign in to comment.