@@ -1,3 +1,4 @@
+import math
 from dataclasses import dataclass
 from typing import Any, Tuple
 
@@ -63,6 +64,7 @@ class MLA(nnx.Module):
     attention_chunk_size: int | None = None
     rope_input_ordering: str = "split"
     quant: Any | None = None
+    rope_mscale_all_dim: float = 1.0
 
     def __post_init__(self):
         self.N = self.num_attention_heads
@@ -73,6 +75,13 @@ def __post_init__(self):
 
         assert self.N == self.K, "N and K must be equal for MLA"
 
+        if self.rope_scaling["factor"] <= 1.0:
+            yarn_mscale = 1.0
+        else:
+            yarn_mscale = 0.1 * self.rope_mscale_all_dim * math.log(
+                self.rope_scaling["factor"]) + 1.0
+        self.scale = self.qk_head_dim ** -0.5 * yarn_mscale ** 2
+
         self.rope = DeepseekScalingRotaryEmbedding(
             self.qk_rope_head_dim,
             self.rope_theta,
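The hunk above computes the attention softmax scale once in __post_init__: the YaRN mscale correction (0.1 * mscale_all_dim * ln(factor) + 1.0, squared) is folded into the usual 1/sqrt(qk_head_dim) factor. Below is a minimal standalone sketch of that computation; the example values qk_head_dim=192 and factor=40.0 are illustrative, not taken from this commit.

import math

def mla_softmax_scale(qk_head_dim: int, rope_factor: float,
                      rope_mscale_all_dim: float = 1.0) -> float:
    # YaRN mscale: 0.1 * mscale_all_dim * ln(factor) + 1.0, with no correction
    # when the context is not actually extended (factor <= 1).
    if rope_factor <= 1.0:
        yarn_mscale = 1.0
    else:
        yarn_mscale = 0.1 * rope_mscale_all_dim * math.log(rope_factor) + 1.0
    # Base 1/sqrt(d) scale times the squared mscale, mirroring self.scale above.
    return qk_head_dim ** -0.5 * yarn_mscale ** 2

# factor=40 gives yarn_mscale ~= 1.369, so the scale rises from ~0.0722 to ~0.135.
print(mla_softmax_scale(192, 40.0))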
@@ -180,7 +189,6 @@ def __call__(self,
             # Concatenate the nope and rope queries.
             q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
             # Multiple the query by scaling factor
-            q_TNH = q_TNH * self.qk_head_dim ** -0.5
             q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)
 
         with jax.named_scope("kv_proj"):
@@ -293,7 +301,7 @@ def attention(
         def _ragged_paged_attention(*args):
             return ragged_paged_attention(
                 *args,
-                sm_scale=q_TNH.shape[-1] ** -0.5,
+                sm_scale=self.scale,
             )
 
         output_TNH, kv_cache = jax.jit(
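With the scale folded into the kernel's sm_scale argument, the explicit query rescale removed above is no longer needed: scaling the queries before the dot product and scaling the attention logits yield the same softmax. A quick sanity check in plain jax.numpy; the shapes and the scale value are arbitrary examples, not tied to this model.

import jax
import jax.numpy as jnp

q = jax.random.normal(jax.random.PRNGKey(0), (4, 8))  # [query tokens, head_dim]
k = jax.random.normal(jax.random.PRNGKey(1), (6, 8))  # [kv tokens, head_dim]
scale = 0.135  # stand-in for self.scale

# Scale the queries first (the old __call__ behaviour) ...
pre_scaled = jax.nn.softmax((q * scale) @ k.T, axis=-1)
# ... versus scaling the logits, the conventional meaning of sm_scale.
logit_scaled = jax.nn.softmax((q @ k.T) * scale, axis=-1)

assert jnp.allclose(pre_scaled, logit_scaled, atol=1e-6)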