Commit d1e3c63

bzgoogle authored and Lumosis committed
fix scaling factor for DeepSeek (#521)
Signed-off-by: bzgoogle <beinuoz@google.com>
1 parent 6e6322e commit d1e3c63

File tree

1 file changed (+10, -2 lines)

tpu_commons/models/jax/common/attention/deepseek_v3_attention.py

Lines changed: 10 additions & 2 deletions
@@ -1,3 +1,4 @@
+import math
 from dataclasses import dataclass
 from typing import Any, Tuple
 
@@ -63,6 +64,7 @@ class MLA(nnx.Module):
     attention_chunk_size: int | None = None
     rope_input_ordering: str = "split"
     quant: Any | None = None
+    rope_mscale_all_dim: float = 1.0
 
     def __post_init__(self):
         self.N = self.num_attention_heads
@@ -73,6 +75,13 @@ def __post_init__(self):
 
         assert self.N == self.K, "N and K must be equal for MLA"
 
+        if self.rope_scaling["factor"] <= 1.0:
+            yarn_mscale = 1.0
+        else:
+            yarn_mscale = 0.1 * self.rope_mscale_all_dim * math.log(
+                self.rope_scaling["factor"]) + 1.0
+        self.scale = self.qk_head_dim**-0.5 * yarn_mscale**2
+
         self.rope = DeepseekScalingRotaryEmbedding(
             self.qk_rope_head_dim,
             self.rope_theta,
@@ -180,7 +189,6 @@ def __call__(self,
             # Concatenate the nope and rope queries.
             q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
             # Multiple the query by scaling factor
-            q_TNH = q_TNH * self.qk_head_dim**-0.5
             q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
 
         with jax.named_scope("kv_proj"):
@@ -293,7 +301,7 @@ def attention(
         def _ragged_paged_attention(*args):
             return ragged_paged_attention(
                 *args,
-                sm_scale=q_TNH.shape[-1]**-0.5,
+                sm_scale=self.scale,
             )
 
         output_TNH, kv_cache = jax.jit(
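
For context, the change precomputes `self.scale` in `__post_init__` so that the YaRN "mscale" correction is folded into the usual 1/sqrt(d) softmax scaling, and the attention kernel now receives that value as `sm_scale` instead of re-deriving `q_TNH.shape[-1]**-0.5`. Below is a minimal standalone sketch of the same computation; the helper name and the example config values (head dim 192, rope scaling factor 40) are illustrative assumptions, not taken from this repository.

import math

def yarn_softmax_scale(qk_head_dim: int,
                       rope_scaling_factor: float,
                       rope_mscale_all_dim: float = 1.0) -> float:
    # Hypothetical helper mirroring the logic added in __post_init__ above.
    # With no long-context extension (factor <= 1) the mscale term is 1.0
    # and this reduces to the plain 1/sqrt(d) scaling.
    if rope_scaling_factor <= 1.0:
        yarn_mscale = 1.0
    else:
        yarn_mscale = 0.1 * rope_mscale_all_dim * math.log(rope_scaling_factor) + 1.0
    return qk_head_dim**-0.5 * yarn_mscale**2

# Illustrative numbers only:
# yarn_mscale = 0.1 * 1.0 * ln(40) + 1.0 ≈ 1.369
# scale ≈ 192**-0.5 * 1.369**2 ≈ 0.135, versus the previous unscaled 192**-0.5 ≈ 0.072.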
