
Commit 75e596e

JJJYmmm authored and liuye.hj committed
[feat] Support MRoPE + YaRN (vllm-project#25384)
Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com>
Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent cedceb7 commit 75e596e

File tree

2 files changed (+48, -5 lines)

vllm/model_executor/layers/rotary_embedding/__init__.py

Lines changed: 17 additions & 5 deletions
@@ -153,11 +153,23 @@ def get_rope(
                 if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                          "beta_slow")
             }
-            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
-                                                    original_max_position,
-                                                    base, is_neox_style,
-                                                    scaling_factor, dtype,
-                                                    **extra_kwargs)
+            if "mrope_section" in rope_scaling:
+                rotary_emb = MRotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    original_max_position,
+                    base,
+                    is_neox_style,
+                    dtype,
+                    mrope_section=rope_scaling["mrope_section"],
+                    mrope_interleaved=rope_scaling.get("mrope_interleaved",
+                                                       False),
+                    scaling_factor=scaling_factor,
+                    **extra_kwargs)
+            else:
+                rotary_emb = YaRNScalingRotaryEmbedding(
+                    head_size, rotary_dim, original_max_position, base,
+                    is_neox_style, scaling_factor, dtype, **extra_kwargs)
         elif scaling_type == "deepseek_yarn":
             scaling_factor = rope_scaling["factor"]
             original_max_position = rope_scaling[
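For context, here is a sketch of the kind of rope_scaling entry that would now take the MRoPE branch above. The key names follow the hunk ("factor", "original_max_position_embeddings", "mrope_section", and the YaRN extras); the concrete values and the "rope_type" key are illustrative assumptions, not taken from this commit.

```python
# Hypothetical HF-style rope_scaling config (illustrative values only).
rope_scaling = {
    "rope_type": "yarn",                        # assumed key; selects the YaRN branch
    "factor": 4.0,                              # -> scaling_factor
    "original_max_position_embeddings": 32768,  # -> original_max_position
    "mrope_section": [16, 24, 24],              # must sum to rotary_dim // 2
    "beta_fast": 32,                            # forwarded via extra_kwargs
    "beta_slow": 1,
}

# With "mrope_section" present, get_rope now constructs an MRotaryEmbedding
# and passes scaling_factor plus the YaRN kwargs through; without it, the
# existing YaRNScalingRotaryEmbedding path is unchanged.
```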

vllm/model_executor/layers/rotary_embedding/mrope.py

Lines changed: 31 additions & 0 deletions
@@ -12,6 +12,7 @@
 
 from .base import RotaryEmbedding
 from .common import apply_rotary_emb_dispatch
+from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
 
 
 @triton.jit
@@ -213,7 +214,27 @@ def __init__(
         dtype: torch.dtype,
         mrope_section: Optional[list[int]] = None,
         mrope_interleaved: bool = False,
+        # YaRN parameters.
+        *,
+        scaling_factor: Optional[float] = None,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
     ) -> None:
+
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        if self.scaling_factor is not None:
+            # Get n-d magnitude scaling corrected for interpolation
+            self.mscale = float(
+                yarn_get_mscale(self.scaling_factor) * attn_factor)
+        else:
+            self.mscale = 1.0
+
         # In Qwen2.5-VL, the maximum index value is related to the duration of
         # the input video. We enlarge max_position_embeddings to 4 times to get
         # a larger the cos and sin cache.
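As a rough aid to reading the constructor above: yarn_get_mscale is the standard YaRN magnitude correction imported from yarn_scaling_rope. A minimal sketch, reproduced from memory of the YaRN formulation rather than from this diff, of what the mscale computation amounts to:

```python
import math

# Standard YaRN magnitude correction (sketch; the real helper lives in
# vllm's yarn_scaling_rope module and may differ in detail).
def yarn_get_mscale(scale: float = 1.0) -> float:
    if scale <= 1.0:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

# With scaling_factor=4.0 and attn_factor=1, the constructor above sets
# mscale = yarn_get_mscale(4.0) * 1 ≈ 1.139; with scaling_factor=None it
# stays at 1.0, so plain MRoPE behaviour is untouched.
print(round(yarn_get_mscale(4.0), 3))  # 1.139
```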
@@ -226,6 +247,16 @@ def __init__(
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
+        if self.scaling_factor is None:
+            return super()._compute_inv_freq(base)
+        return YaRNScalingRotaryEmbedding._compute_inv_freq(self, base)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        if self.scaling_factor is None:
+            return super()._compute_cos_sin_cache()
+        return YaRNScalingRotaryEmbedding._compute_cos_sin_cache(self)
+
     def forward_native(
         self,
         positions: torch.Tensor,
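The two overrides above reuse YaRN's frequency and cache computation without making MRotaryEmbedding inherit from YaRNScalingRotaryEmbedding: they call the other class's method explicitly on self. A small, self-contained illustration of that pattern (toy classes, not vLLM code):

```python
class YarnLike:
    def compute(self) -> str:
        return f"yarn(scale={self.scaling_factor})"


class Base:
    def compute(self) -> str:
        return "base"


class MRopeLike(Base):
    def __init__(self, scaling_factor: float | None = None) -> None:
        self.scaling_factor = scaling_factor

    def compute(self) -> str:
        # Mirrors _compute_inv_freq above: fall back to the parent when no
        # YaRN scaling is configured, otherwise borrow the YaRN implementation
        # by calling it explicitly on self (no multiple inheritance needed).
        if self.scaling_factor is None:
            return super().compute()
        return YarnLike.compute(self)


print(MRopeLike().compute())     # -> base
print(MRopeLike(4.0).compute())  # -> yarn(scale=4.0)
```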
