
Commit

Fix gemma2 accuracy through the correct softcapping logic (#2842)
* Fix gemma2 accuracy through the correct softcapping logic

* remove debugging code
AllentDan authored Dec 2, 2024
1 parent 776677a commit b91ce9a
Showing 3 changed files with 22 additions and 10 deletions.
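
The core change, visible in both attention kernels below, is the order of operations before the softmax: scores are scaled by sm_scale, then softcapped, and only then multiplied by log2(e) so the kernels can exponentiate with exp2. Previously log2(e) was folded into the scale before softcapping, which inflated the tanh argument and hurt gemma2 accuracy. A minimal PyTorch sketch of the intended ordering (illustrative only, not code from this commit; softcapped_scores and its arguments are invented names, and cap stands in for logit_softcapping):

import math

import torch


def softcapped_scores(q, k, sm_scale, cap=None):
    """Scale first, softcap the scaled scores, convert to base 2 last."""
    qk = q @ k.transpose(-1, -2)
    qk = qk * sm_scale
    if cap is not None and cap > 0:
        qk = torch.tanh(qk / cap) * cap
    return qk * math.log2(math.e)  # ready for an exp2-based softmax

With the old ordering, tanh saw scores inflated by log2(e) ≈ 1.44, so the capped logits no longer matched the reference softcapped attention.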
17 changes: 10 additions & 7 deletions lmdeploy/pytorch/kernels/cuda/flashattention.py
@@ -49,7 +49,7 @@ def softcapping(qk, logit_softcapping: tl.constexpr):
 
 @triton.jit
 def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs,
-        loop_start, loop_end, qk_scale, history_mask,
+        loop_start, loop_end, sm_scale, history_mask,
         kv_min_loc, causal_mask: tl.constexpr,
         window_size: tl.constexpr,
         logit_softcapping: tl.constexpr, BLOCK_N: tl.constexpr,
@@ -71,8 +71,9 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs,
             qk += tl.dot(q1, k1)
 
         if causal_mask:
-            qk *= qk_scale
+            qk *= sm_scale
             qk = softcapping(qk, logit_softcapping)
+            qk = qk * tl_log2(math.e)
             qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :])
             if window_size > 0:
                 qk_mask = qk_mask and (
@@ -85,8 +86,9 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs,
             m_i_new = tl.maximum(m_i, tl.max(qk, 1))
             qk -= m_i_new[:, None]
         elif window_size > 0:
-            qk *= qk_scale
+            qk *= sm_scale
             qk = softcapping(qk, logit_softcapping)
+            qk = qk * tl_log2(math.e)
             qk_mask = ((start_n + offs_n[None, :]) >= kv_min_loc[:, None])
             qk = tl.where(
                 qk_mask,
@@ -96,11 +98,13 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs,
             m_i_new = tl.maximum(m_i, tl.max(qk, 1))
             qk -= m_i_new[:, None]
         elif logit_softcapping > 0:
-            qk *= qk_scale
+            qk *= sm_scale
             qk = softcapping(qk, logit_softcapping)
+            qk = qk * tl_log2(math.e)
             m_i_new = tl.maximum(m_i, tl.max(qk, 1))
             qk -= m_i_new[:, None]
         else:
+            qk_scale = sm_scale * tl_log2(math.e)
             m_i_new = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
             qk = qk * qk_scale - m_i_new[:, None]
 
@@ -256,7 +260,6 @@ def _flash_prefill_fwd_kernel(
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
     acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
 
-    qk_scale = sm_scale * tl_log2(math.e)
     history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M)
 
     loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N
@@ -270,7 +273,7 @@ def _flash_prefill_fwd_kernel(
         k1_ptrs,
         loop_start,
         loop_end,
-        qk_scale,
+        sm_scale,
         history_mask,
         kv_min_loc,
         causal_mask=False,
@@ -291,7 +294,7 @@ def _flash_prefill_fwd_kernel(
         k1_ptrs,
         loop_start,
         loop_end,
-        qk_scale,
+        sm_scale,
         history_mask,
         kv_min_loc,
         causal_mask=True,
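
A short aside on why the tl_log2(math.e) factor now comes after softcapping (a worked identity, not part of the diff): the kernels exponentiate with exp2, and for the tanh cap used by gemma2,

\operatorname{softcap}(x) = c \, \tanh\!\left(\frac{x}{c}\right), \qquad
\exp\bigl(\operatorname{softcap}(s \cdot qk)\bigr) = 2^{\,\log_2(e)\,\operatorname{softcap}(s \cdot qk)}.

Folding log2(e) into the scale before the tanh instead produces 2^{softcap(s·qk·log2(e))}, which matches the intended weight only while tanh is still near-linear, and that is precisely not the regime softcapping is meant to handle.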
6 changes: 4 additions & 2 deletions lmdeploy/pytorch/kernels/cuda/pagedattention.py
@@ -205,11 +205,12 @@ def _fwd_grouped_split_kernel(
         qk += tl.dot(q, k)
         if BLOCK_DMODEL1 != 0:
             qk += tl.dot(q1, k1)
-        qk *= sm_scale * tl_log2(math.e)
+        qk *= sm_scale
         if logit_softcapping > 0.0:
             qk = qk / logit_softcapping
             qk = tanh(qk)
             qk = qk * logit_softcapping
+        qk = qk * tl_log2(math.e)
         # NOTE: inf - inf = nan, and nan will leads to error
         if start_n + BLOCK_N > history_len or window_size > 0:
             qk_mask = history_len >= (start_n + offs_n)
@@ -491,11 +492,12 @@ def _fwd_grouped_split_quant_kernel(
         qk += tl.dot(q, k)
         if BLOCK_DMODEL1 != 0:
             qk += tl.dot(q1, k1)
-        qk *= sm_scale * tl_log2(math.e)
+        qk *= sm_scale
         if logit_softcapping > 0.0:
             qk = qk / logit_softcapping
             qk = tanh(qk)
             qk = qk * logit_softcapping
+        qk = qk * tl_log2(math.e)
         # NOTE: inf - inf = nan, and nan will leads to error
         if start_n + BLOCK_N > history_len or window_size > 0:
             qk_mask = history_len >= (start_n + offs_n)
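
The decode kernels above get the same reordering. A quick numeric check of the effective log-weight each ordering feeds into the softmax (an illustrative standalone script, not part of the repository; cap = 50.0 is an assumed softcapping value and s stands for a score already multiplied by sm_scale):

import math

CAP = 50.0  # assumed stand-in for logit_softcapping


def new_log_weight(s, cap=CAP):
    # new ordering: scale -> softcap -> * log2(e) -> exp2;
    # exp2(x) == exp(x * ln 2), so this is the natural-log weight
    return math.log(2) * (math.log2(math.e) * cap * math.tanh(s / cap))


def old_log_weight(s, cap=CAP):
    # old ordering: log2(e) was folded into the scale before softcapping
    return math.log(2) * (cap * math.tanh(s * math.log2(math.e) / cap))


for s in (1.0, 25.0, 60.0):
    print(s, round(new_log_weight(s), 3), round(old_log_weight(s), 3))

Small scores agree (the log2(e) and ln(2) factors cancel while tanh is near-linear), but once tanh saturates the old path tops out at ln(2) * 50 ≈ 34.7 instead of 50, which is where gemma2 accuracy degraded.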
9 changes: 8 additions & 1 deletion lmdeploy/pytorch/models/gemma.py
@@ -383,6 +383,8 @@ def __init__(self,
             bias=False,
             dtype=dtype,
             device=device)
+        self.final_logit_softcapping = getattr(config,
+            'final_logit_softcapping', None)
 
     def forward(
         self,
@@ -405,7 +407,12 @@ def forward(
 
     def get_logits(self, hidden_states: torch.Tensor):
         """compute logits of the model output."""
-        return self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states)
+        if self.final_logit_softcapping is not None:
+            logits = logits / self.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.final_logit_softcapping
+        return logits
 
     def get_input_embeddings(self):
         """get input embeddings."""
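
For reference, the new get_logits path is equivalent to the small helper below (a sketch with an invented name, softcap_logits; the cap value is whatever final_logit_softcapping the checkpoint config defines, and configs without the field fall back to None via getattr, leaving non-gemma2 models unchanged):

from typing import Optional

import torch


def softcap_logits(logits: torch.Tensor, cap: Optional[float]) -> torch.Tensor:
    """Bound every logit to (-cap, cap) with a smooth tanh, as gemma2 expects."""
    if cap is None:
        return logits
    return torch.tanh(logits / cap) * cap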
