Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Permalink
[Bugfix] Fix Phi-3 Long RoPE scaling implementation (vllm-project#5628)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShukantPal authored and robertgshaw2-neuralmagic committed Jun 23, 2024
1 parent a8b75a4 commit a0d8ed2
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,8 +507,8 @@ def __init__(
dtype: torch.dtype,
short_factor: List[float],
long_factor: List[float],
short_mscale: float = 1.1,
long_mscale: float = 1.225,
short_mscale: float = 1.0,
long_mscale: float = 1.0,
):
super().__init__()

Expand All @@ -530,6 +530,16 @@ def __init__(
self.short_mscale = short_mscale
self.long_mscale = long_mscale

scale = (self.max_position_embeddings /
self.original_max_position_embeddings)

if scale <= 1.0:
self.scaling_factor = 1.0
else:
self.scaling_factor = math.sqrt(
1 + math.log(scale) /
math.log(self.original_max_position_embeddings))

short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale)
short_cache = short_cache.to(dtype)
Expand Down Expand Up @@ -565,8 +575,8 @@ def _compute_cos_sin_cache(
inv_freq = self._compute_inv_freq(rescale_factors)
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * mscale
sin = freqs.sin() * mscale
cos = freqs.cos() * mscale * self.scaling_factor
sin = freqs.sin() * mscale * self.scaling_factor
cache = torch.cat((cos, sin), dim=-1)
return cache

Expand Down

0 comments on commit a0d8ed2

Please sign in to comment.