Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions colossalai/kernel/triton/flash_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def _flash_decoding_fwd_kernel(
m = tl.max(S_ij, 0)
S_ij -= m
p_ij_hat = tl.exp(S_ij)
l = tl.sum(p_ij_hat, 0)
l_i = tl.sum(p_ij_hat, 0)
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
acc = acc / l
acc = acc / l_i

offsets_mid_o = (
cur_token_idx * stride_mid_ot
Expand All @@ -126,8 +126,8 @@ def _flash_decoding_fwd_kernel(
offsets_mid_o_lse = (
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
)
# logsumexp L^(j) = m^(j) + log(l^(j))
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
# logsumexp L^(j) = m^(j) + log(l_i^(j)), where l_i^(j) is the sum of exp(S_ij - m^(j)) over this kv block
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i))


# Triton 2.1.0
Expand Down Expand Up @@ -234,10 +234,10 @@ def _alibi_flash_decoding_fwd_kernel(
m = tl.max(S_ij, 0)
S_ij -= m
p_ij_hat = tl.exp(S_ij)
l = tl.sum(p_ij_hat, 0)
l_i = tl.sum(p_ij_hat, 0)
p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)
acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)
acc = acc / l
acc = acc / l_i

offsets_mid_o = (
cur_token_idx * stride_mid_ot
Expand All @@ -249,8 +249,8 @@ def _alibi_flash_decoding_fwd_kernel(
offsets_mid_o_lse = (
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
)
# logsumexp L^(j) = m^(j) + log(l^(j))
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
# logsumexp L^(j) = m^(j) + log(l_i^(j)), where l_i^(j) is the sum of exp(S_ij - m^(j)) over this kv block
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i))


# Triton 2.1.0
Expand Down Expand Up @@ -290,7 +290,7 @@ def _flash_decoding_fwd_reduce_kernel(
# BLOCK_KV == BLOCK_SIZE for now. We might want to decrease the number of blocks the kv sequence is split into.
kv_split_num = (cur_kv_seq_len + BLOCK_KV - 1) // BLOCK_KV
m_i = float("-inf") # max logic
l = 0.0 # sum exp
l_i = 0.0 # sum exp
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)

offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
Expand All @@ -304,10 +304,10 @@ def _flash_decoding_fwd_reduce_kernel(
lse -= m_ij
exp_logic = tl.exp(lse)
acc += exp_logic * mid_o_block
l = scale * l + exp_logic
l_i = scale * l_i + exp_logic
m_i = m_ij

acc = acc / l
acc = acc / l_i
offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
tl.store(O + offsets_O, acc.to(O.type.element_ty))
return
Expand Down