@@ -509,15 +509,12 @@ def __init__(
509509 ):
510510 super ().__init__ ()
511511
512- if rotary_dim != head_size :
513- raise ValueError (
514- f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
515- rotary_dim != head_size ({ rotary_dim } !={ head_size } )." )
516512 if is_neox_style is False :
517513 raise ValueError (
518514 "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
519515 )
520516
517+ self .rotary_dim = rotary_dim
521518 self .head_size = head_size
522519 self .max_position_embeddings = max_position_embeddings
523520 self .original_max_position_embeddings = original_max_position_embeddings
@@ -557,7 +554,7 @@ def __init__(
# _compute_inv_freq: build the rescaled inverse-frequency tensor.
#
# `rescale_factors` is converted to a float32 tensor, then the result is
#     1.0 / (rescale_factors * base ** (arange(0, dim, 2) / dim))
# where `base` is read from `self.base`.
#
# Change shown in this hunk: `dim` was `self.head_size` (the "-" line) and is
# now `self.rotary_dim` (the "+" line). The same diff stores `self.rotary_dim`
# in `__init__` and removes the `rotary_dim != head_size` ValueError there, so
# the frequencies now span only the rotary slice of each head.
#
# NOTE(review): the elementwise product requires `rescale_factors` to
# broadcast against the `arange(0, rotary_dim, 2)` tensor -- presumably one
# factor per frequency (rotary_dim // 2 entries); confirm against the model
# config that supplies the factors.
557554 def _compute_inv_freq (self , rescale_factors : List [float ]) -> torch .Tensor :
558555 rescale_factors = torch .tensor (rescale_factors , dtype = torch .float32 )
559556 inv_freq = 1.0 / (rescale_factors * (self .base ** (torch .arange (
560- 0 , self .head_size , 2 , dtype = torch .float ) / self .head_size )))
557+ 0 , self .rotary_dim , 2 , dtype = torch .float ) / self .rotary_dim )))
561558 return inv_freq
562559
563560 def _compute_cos_sin_cache (
@@ -596,8 +593,15 @@ def forward(
596593 cos = cos .repeat (1 , 2 ).unsqueeze (- 2 )
597594 sin = sin .repeat (1 , 2 ).unsqueeze (- 2 )
598595
599- query = query * cos + _rotate_neox (query ) * sin
600- key = key * cos + _rotate_neox (key ) * sin
596+ query_rot = query [..., :self .rotary_dim ]
597+ query_pass = query [..., self .rotary_dim :]
598+ query_rot = query_rot * cos + _rotate_neox (query_rot ) * sin
599+ query = torch .cat ((query_rot , query_pass ), dim = - 1 )
600+
601+ key_rot = key [..., :self .rotary_dim ]
602+ key_pass = key [..., self .rotary_dim :]
603+ key_rot = key_rot * cos + _rotate_neox (key_rot ) * sin
604+ key = torch .cat ((key_rot , key_pass ), dim = - 1 )
601605
602606 return query .flatten (- 2 ), key .flatten (- 2 )
603607
0 commit comments