@@ -12,6 +12,7 @@
 
 from .base import RotaryEmbedding
 from .common import apply_rotary_emb_dispatch
+from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
 
 
 @triton.jit
@@ -213,7 +214,27 @@ def __init__(
         dtype: torch.dtype,
         mrope_section: Optional[list[int]] = None,
         mrope_interleaved: bool = False,
+        # YaRN parameters.
+        *,
+        scaling_factor: Optional[float] = None,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
     ) -> None:
+
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        if self.scaling_factor is not None:
+            # Get n-d magnitude scaling corrected for interpolation
+            self.mscale = float(
+                yarn_get_mscale(self.scaling_factor) * attn_factor)
+        else:
+            self.mscale = 1.0
+
         # In Qwen2.5-VL, the maximum index value is related to the duration of
         # the input video. We enlarge max_position_embeddings to 4 times to get
         # a larger cos and sin cache.
@@ -226,6 +247,16 @@ def __init__(
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
+        if self.scaling_factor is None:
+            return super()._compute_inv_freq(base)
+        return YaRNScalingRotaryEmbedding._compute_inv_freq(self, base)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        if self.scaling_factor is None:
+            return super()._compute_cos_sin_cache()
+        return YaRNScalingRotaryEmbedding._compute_cos_sin_cache(self)
+
     def forward_native(
         self,
         positions: torch.Tensor,
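
The change above rests on two pieces: an attention-magnitude correction (mscale) derived from the scaling factor, and borrowing YaRNScalingRotaryEmbedding's _compute_inv_freq / _compute_cos_sin_cache as unbound methods whenever scaling_factor is set, falling back to the base-class implementations otherwise. The standalone sketch below illustrates that dispatch pattern with toy stand-in classes (Base, YaRNLike, and MRopeLike are illustrative names, not vLLM's) and uses the YaRN magnitude correction as it is commonly implemented (0.1 * ln(s) + 1.0 for s > 1); it is not the actual vLLM code.

import math
from typing import Optional


def yarn_get_mscale(scale: float = 1.0) -> float:
    # Common YaRN attention-magnitude correction; identity for scale <= 1.
    if scale <= 1.0:
        return 1.0
    return 0.1 * math.log(scale) + 1.0


class Base:
    # Plain RoPE inverse frequencies for a toy rotary dim of 8.
    def _compute_inv_freq(self, base: float) -> list[float]:
        return [base ** (-2 * i / 8) for i in range(4)]


class YaRNLike(Base):
    # Stand-in for YaRN's rescaled frequencies; only needs self.scaling_factor.
    def _compute_inv_freq(self, base: float) -> list[float]:
        return [base ** (-2 * i / 8) / self.scaling_factor for i in range(4)]


class MRopeLike(Base):
    def __init__(self, scaling_factor: Optional[float] = None,
                 attn_factor: float = 1.0) -> None:
        self.scaling_factor = scaling_factor
        if scaling_factor is not None:
            self.mscale = float(yarn_get_mscale(scaling_factor) * attn_factor)
        else:
            self.mscale = 1.0

    def _compute_inv_freq(self, base: float) -> list[float]:
        if self.scaling_factor is None:
            return super()._compute_inv_freq(base)
        # Call the sibling class's method unbound, passing self explicitly,
        # mirroring YaRNScalingRotaryEmbedding._compute_inv_freq(self, base).
        return YaRNLike._compute_inv_freq(self, base)


print(MRopeLike()._compute_inv_freq(10000.0))                    # unscaled frequencies
print(MRopeLike(scaling_factor=4.0)._compute_inv_freq(10000.0))  # YaRN-scaled
print(MRopeLike(scaling_factor=4.0).mscale)                      # ~1.1386

Because the unbound call only touches attributes that already exist on self, the same instance can switch between the plain and YaRN-scaled frequency computations without multiple inheritance, which is the same trick the diff applies to the real classes.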