Skip to content

YaRN support implementation #1264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions csrc/activation_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ __global__ void silu_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = silu(x) * y;
Expand All @@ -30,7 +30,7 @@ void silu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
int num_tokens = input.numel() / input.size(-1);
int64_t num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2;

dim3 grid(num_tokens);
Expand All @@ -55,8 +55,8 @@ __global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
Expand All @@ -67,7 +67,7 @@ __global__ void activation_kernel(
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int num_tokens = input.numel() / d; \
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
Expand Down
2 changes: 1 addition & 1 deletion csrc/pos_encoding_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ void rotary_embedding(
int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
int num_tokens = query.numel() / query.size(-1);
int64_t num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size;
Expand Down
3 changes: 3 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ def _get_and_verify_max_len(
if rope_scaling is not None:
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
derived_max_model_len = rope_scaling[
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor

if max_model_len is None:
Expand Down
15 changes: 14 additions & 1 deletion vllm/model_executor/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import (
DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
RotaryEmbedding)
RotaryEmbedding, YaRNScalingRotaryEmbedding)

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
Expand Down Expand Up @@ -334,6 +334,19 @@ def __init__(
self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor)
elif scaling_type == "yarn":
original_max_position = rope_scaling[
"original_max_position_embeddings"]
assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("extrapolation_factor", "attn_factor",
"beta_fast", "beta_slow")
}
self.rotary_emb = YaRNScalingRotaryEmbedding(
head_size, rotary_dim, original_max_position, base,
is_neox_style, scaling_factor, **extra_kwargs)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

Expand Down
104 changes: 104 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
import math
from typing import Tuple, Union

import torch
Expand Down Expand Up @@ -167,3 +168,106 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache


# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> float:
return (dim * math.log(max_position_embeddings /
(num_rotations * 2 * math.pi))) / (2 *
math.log(base))


# Find dim range bounds based on rotations
def _yarn_find_correction_range(low_rot: int,
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> int:
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
_yarn_find_correction_dim(high_rot, dim, base,
max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case


def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
dtype: torch.dtype,
device: torch.device) -> torch.Tensor:
if low == high:
high += 0.001 # Prevent singularity

linear_func = (torch.arange(dim, dtype=dtype, device=device) -
low) / (high - low)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func


def _yarn_get_mscale(scale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0


class YaRNScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method.

Credits to Peng et al. github.com/jquesnelle/yarn
"""

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: float = 32,
beta_slow: float = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(
_yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2, dtype=torch.float,
device="cuda")) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq

def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
cache = torch.cat((cos, sin), dim=-1)
return cache