Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions tpu_commons/models/jax/common/attention/deepseek_v3_attention.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from dataclasses import dataclass
from typing import Any, Tuple

Expand Down Expand Up @@ -63,6 +64,7 @@ class MLA(nnx.Module):
attention_chunk_size: int | None = None
rope_input_ordering: str = "split"
quant: Any | None = None
rope_mscale_all_dim: float = 1.0

def __post_init__(self):
self.N = self.num_attention_heads
Expand All @@ -73,6 +75,13 @@ def __post_init__(self):

assert self.N == self.K, "N and K must be equal for MLA"

if self.rope_scaling["factor"] <= 1.0:
yarn_mscale = 1.0
else:
yarn_mscale = 0.1 * self.rope_mscale_all_dim * math.log(
self.rope_scaling["factor"]) + 1.0
self.scale = self.qk_head_dim**-0.5 * yarn_mscale**2

self.rope = DeepseekScalingRotaryEmbedding(
self.qk_rope_head_dim,
self.rope_theta,
Expand Down Expand Up @@ -180,7 +189,6 @@ def __call__(self,
# Concatenate the nope and rope queries.
q_TNH = jnp.concatenate([q_nope_TNH, q_rope_TNH], axis=-1)
# Multiply the query by the scaling factor
q_TNH = q_TNH * self.qk_head_dim**-0.5
q_TNH = nnx.with_sharding_constraint(q_TNH, self.query_tnh)

with jax.named_scope("kv_proj"):
Expand Down Expand Up @@ -293,7 +301,7 @@ def attention(
def _ragged_paged_attention(*args):
return ragged_paged_attention(
*args,
sm_scale=q_TNH.shape[-1]**-0.5,
sm_scale=self.scale,
)

output_TNH, kv_cache = jax.jit(
Expand Down