From b91ce9a259d3af4bba14c05b968fdf24373545d6 Mon Sep 17 00:00:00 2001
From: AllentDan <41138331+AllentDan@users.noreply.github.com>
Date: Mon, 2 Dec 2024 15:26:29 +0800
Subject: [PATCH] Fix gemma2 accuracy through the correct softcapping logic
 (#2842)

* Fix gemma2 accuracy through the correct softcapping logic

* remove debugging codes
---
 lmdeploy/pytorch/kernels/cuda/flashattention.py | 17 ++++++++++-------
 lmdeploy/pytorch/kernels/cuda/pagedattention.py |  6 ++++--
 lmdeploy/pytorch/models/gemma.py                |  9 ++++++++-
 3 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/lmdeploy/pytorch/kernels/cuda/flashattention.py b/lmdeploy/pytorch/kernels/cuda/flashattention.py
index 7521a3e2bb..34a11ae030 100644
--- a/lmdeploy/pytorch/kernels/cuda/flashattention.py
+++ b/lmdeploy/pytorch/kernels/cuda/flashattention.py
@@ -49,7 +49,7 @@ def softcapping(qk, logit_softcapping: tl.constexpr):
 
 @triton.jit
 def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs,
-                       loop_start, loop_end, qk_scale, history_mask,
+                       loop_start, loop_end, sm_scale, history_mask,
                        kv_min_loc, causal_mask: tl.constexpr,
                        window_size: tl.constexpr,
                        logit_softcapping: tl.constexpr, BLOCK_N: tl.constexpr,
@@ -71,8 +71,9 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs,
             qk += tl.dot(q1, k1)
 
         if causal_mask:
-            qk *= qk_scale
+            qk *= sm_scale
             qk = softcapping(qk, logit_softcapping)
+            qk = qk * tl_log2(math.e)
             qk_mask = (history_mask[:, None]) >= (start_n + offs_n[None, :])
             if window_size > 0:
                 qk_mask = qk_mask and (
@@ -85,8 +86,9 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs,
             m_i_new = tl.maximum(m_i, tl.max(qk, 1))
             qk -= m_i_new[:, None]
         elif window_size > 0:
-            qk *= qk_scale
+            qk *= sm_scale
             qk = softcapping(qk, logit_softcapping)
+            qk = qk * tl_log2(math.e)
             qk_mask = ((start_n + offs_n[None, :]) >= kv_min_loc[:, None])
             qk = tl.where(
                 qk_mask,
@@ -96,11 +98,13 @@ def _prefill_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, q1, k1_ptrs,
             m_i_new = tl.maximum(m_i, tl.max(qk, 1))
             qk -= m_i_new[:, None]
         elif logit_softcapping > 0:
-            qk *= qk_scale
+            qk *= sm_scale
             qk = softcapping(qk, logit_softcapping)
+            qk = qk * tl_log2(math.e)
             m_i_new = tl.maximum(m_i, tl.max(qk, 1))
             qk -= m_i_new[:, None]
         else:
+            qk_scale = sm_scale * tl_log2(math.e)
             m_i_new = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
             qk = qk * qk_scale - m_i_new[:, None]
 
@@ -256,7 +260,6 @@ def _flash_prefill_fwd_kernel(
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
     acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
 
-    qk_scale = sm_scale * tl_log2(math.e)
     history_mask = history_len + start_m * BLOCK_M + tl.arange(0, BLOCK_M)
 
     loop_end = (history_len + start_m * BLOCK_M) // BLOCK_N * BLOCK_N
@@ -270,7 +273,7 @@
         k1_ptrs,
         loop_start,
         loop_end,
-        qk_scale,
+        sm_scale,
         history_mask,
         kv_min_loc,
         causal_mask=False,
@@ -291,7 +294,7 @@
         k1_ptrs,
         loop_start,
         loop_end,
-        qk_scale,
+        sm_scale,
         history_mask,
         kv_min_loc,
         causal_mask=True,
diff --git a/lmdeploy/pytorch/kernels/cuda/pagedattention.py b/lmdeploy/pytorch/kernels/cuda/pagedattention.py
index bbd6d3cf78..fe44ca4344 100644
--- a/lmdeploy/pytorch/kernels/cuda/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/cuda/pagedattention.py
@@ -205,11 +205,12 @@ def _fwd_grouped_split_kernel(
         qk += tl.dot(q, k)
         if BLOCK_DMODEL1 != 0:
             qk += tl.dot(q1, k1)
-        qk *= sm_scale * tl_log2(math.e)
+        qk *= sm_scale
         if logit_softcapping > 0.0:
             qk = qk / logit_softcapping
             qk = tanh(qk)
             qk = qk * logit_softcapping
+        qk = qk * tl_log2(math.e)
         # NOTE: inf - inf = nan, and nan will leads to error
         if start_n + BLOCK_N > history_len or window_size > 0:
             qk_mask = history_len >= (start_n + offs_n)
@@ -491,11 +492,12 @@ def _fwd_grouped_split_quant_kernel(
         qk += tl.dot(q, k)
         if BLOCK_DMODEL1 != 0:
             qk += tl.dot(q1, k1)
-        qk *= sm_scale * tl_log2(math.e)
+        qk *= sm_scale
         if logit_softcapping > 0.0:
             qk = qk / logit_softcapping
             qk = tanh(qk)
             qk = qk * logit_softcapping
+        qk = qk * tl_log2(math.e)
         # NOTE: inf - inf = nan, and nan will leads to error
         if start_n + BLOCK_N > history_len or window_size > 0:
             qk_mask = history_len >= (start_n + offs_n)
diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py
index 450767bda3..ca36f15651 100644
--- a/lmdeploy/pytorch/models/gemma.py
+++ b/lmdeploy/pytorch/models/gemma.py
@@ -383,6 +383,8 @@ def __init__(self,
                                             bias=False,
                                             dtype=dtype,
                                             device=device)
+        self.final_logit_softcapping = getattr(config,
+                                               'final_logit_softcapping', None)
 
     def forward(
         self,
@@ -405,7 +407,12 @@ def forward(
 
     def get_logits(self, hidden_states: torch.Tensor):
         """compute logits of the model output."""
-        return self.lm_head(hidden_states)
+        logits = self.lm_head(hidden_states)
+        if self.final_logit_softcapping is not None:
+            logits = logits / self.final_logit_softcapping
+            logits = torch.tanh(logits)
+            logits = logits * self.final_logit_softcapping
+        return logits
 
     def get_input_embeddings(self):
         """get input embeddings."""
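
For reference, the corrected ordering in the attention kernels can be written out in plain PyTorch. This is a minimal sketch of the intended math, not the Triton kernels themselves: `q`, `k`, `sm_scale` and `logit_softcapping` mirror the kernel arguments, and an eager exp2-based softmax stands in for the online softmax the kernels actually perform. The point of the patch is that softcapping must see the scores on the natural scale (`qk * sm_scale`); the `log2(e)` factor that lets the kernels use `exp2` is applied only afterwards.

```python
import math

import torch


def softcapped_attention_probs(q, k, sm_scale, logit_softcapping=0.0):
    """Reference for the corrected ordering: scale -> softcap -> log2 domain."""
    qk = q @ k.transpose(-1, -2)
    qk = qk * sm_scale
    if logit_softcapping > 0.0:
        # softcapping acts on the natural-scale scores
        qk = qk / logit_softcapping
        qk = torch.tanh(qk)
        qk = qk * logit_softcapping
    # convert to base 2 only after softcapping: exp(x) == exp2(x * log2(e))
    qk = qk * math.log2(math.e)
    m = qk.max(dim=-1, keepdim=True).values
    p = torch.exp2(qk - m)
    return p / p.sum(dim=-1, keepdim=True)
```

With `logit_softcapping == 0.0` this reduces to an ordinary scaled-dot-product softmax, which is why the unchanged `else` branch in `_prefill_fwd_inner` can still fold both factors into a single local `qk_scale = sm_scale * tl_log2(math.e)`.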
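
The gemma.py change applies the same tanh capping to the LM head output. Below is a standalone sketch of that transform; the `final_logit_softcapping` attribute name follows the Gemma2 config, while the tensor shapes and the 30.0 cap in the usage lines are illustrative only.

```python
from typing import Optional

import torch


def cap_final_logits(logits: torch.Tensor, cap: Optional[float]) -> torch.Tensor:
    """Scaled-tanh softcapping: squashes logits into the interval (-cap, cap)."""
    if cap is None:
        return logits
    return torch.tanh(logits / cap) * cap


# Illustrative usage: models without the config attribute pass cap=None and
# keep their logits untouched, matching the getattr(..., None) default above.
logits = 100.0 * torch.randn(2, 8, 1024)
assert cap_final_logits(logits, None) is logits
assert cap_final_logits(logits, 30.0).abs().max() <= 30.0
```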