Skip to content

Commit

Permalink
Fix RoPE
Browse files — browse the repository at this point in the history
  • Loading branch information
WoosukKwon committed Aug 19, 2024
1 parent c0cc352 commit 37ea0c1
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import torch.nn as nn

from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -89,8 +88,6 @@ def __init__(
cache = cache.to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)

self.use_native2 = current_platform.is_tpu() and is_neox_style

def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
Expand Down Expand Up @@ -229,9 +226,10 @@ def forward_cpu(
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# forward_native() is too complex to be optimized by torch.compile.
# Fall back to the custom C++ kernel.
return self.forward_cuda(positions, query, key, offsets)
if self.is_neox_style:
return self.forward_native2(positions, query, key, offsets)
else:
return self.forward_native(positions, query, key, offsets)

def forward_xpu(
self,
Expand Down Expand Up @@ -263,9 +261,10 @@ def forward_tpu(
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
forward_fn = (self.forward_native2
if self.use_native2 else self.forward_native)
return forward_fn(positions, query, key, offsets)
if self.is_neox_style:
return self.forward_native2(positions, query, key, offsets)
else:
return self.forward_native(positions, query, key, offsets)

def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
Expand Down

0 comments on commit 37ea0c1

Please sign in to comment.