
Support Longchat #555

Merged
46 commits merged on Sep 27, 2023
Commits (46)
9aee8c6
add support for different ropes
LiuXiaoxuanPKU Jul 24, 2023
80d6105
Merge branch 'main' into longchat
LiuXiaoxuanPKU Jul 24, 2023
5646fe1
format
LiuXiaoxuanPKU Jul 24, 2023
d8ff362
Merge branch 'vllm-project:main' into longchat
LiuXiaoxuanPKU Jul 25, 2023
ba1803d
Merge branch 'main' into longchat
LiuXiaoxuanPKU Jul 26, 2023
b08b7e9
statically allocate cache for cos_sin_cache
LiuXiaoxuanPKU Jul 26, 2023
31acb7f
Merge branch 'vllm-project:main' into longchat
LiuXiaoxuanPKU Jul 26, 2023
9acb8a0
Merge branch 'longchat' of github.com:LiuXiaoxuanPKU/vllm into longchat
LiuXiaoxuanPKU Jul 26, 2023
57ce875
format
LiuXiaoxuanPKU Jul 26, 2023
b61c8dd
minor bug fix and format
LiuXiaoxuanPKU Jul 26, 2023
fe402a6
add rope scaling as a cli arg so openai server can load rope scaled m…
winglian Aug 5, 2023
bb8e153
set rope-scaling arg as json.loads so it can load from cli
winglian Aug 5, 2023
58e7121
merge with main
LiuXiaoxuanPKU Aug 7, 2023
b9012fb
Merge pull request #1 from winglian/longchat-args
LiuXiaoxuanPKU Aug 7, 2023
14c65cc
merge with main
LiuXiaoxuanPKU Aug 7, 2023
fdc5ca3
fix style and add test
LiuXiaoxuanPKU Aug 10, 2023
e00f112
Merge branch 'main' into longchat
LiuXiaoxuanPKU Aug 10, 2023
659f7c9
more style
LiuXiaoxuanPKU Aug 10, 2023
7148513
add more tests
LiuXiaoxuanPKU Aug 10, 2023
7590773
Merge branch 'main' into longchat
LiuXiaoxuanPKU Sep 1, 2023
67058dc
Merge branch 'main' into longchat
LiuXiaoxuanPKU Sep 13, 2023
6c80e0a
merge
LiuXiaoxuanPKU Sep 14, 2023
762832b
merge and pass all tests
LiuXiaoxuanPKU Sep 14, 2023
a841dc3
format
LiuXiaoxuanPKU Sep 14, 2023
ae2368f
format
LiuXiaoxuanPKU Sep 14, 2023
6b5803f
Merge branch 'main' into longchat
LiuXiaoxuanPKU Sep 24, 2023
9657e2c
format
LiuXiaoxuanPKU Sep 24, 2023
d06e0b4
fix test
LiuXiaoxuanPKU Sep 24, 2023
647b801
merge
LiuXiaoxuanPKU Sep 24, 2023
b8058ea
change config check position
LiuXiaoxuanPKU Sep 24, 2023
b7ed435
remove ntk
LiuXiaoxuanPKU Sep 24, 2023
520a8e0
pytest
LiuXiaoxuanPKU Sep 24, 2023
0e9585e
add type
LiuXiaoxuanPKU Sep 24, 2023
b62beb8
Merge branch 'main' into longchat
WoosukKwon Sep 27, 2023
6fc06a6
Roll back arg_utils
WoosukKwon Sep 27, 2023
3e7c318
Minor
WoosukKwon Sep 27, 2023
058d1fa
Roll back
WoosukKwon Sep 27, 2023
9cdae35
Consider RoPE in determining max len
WoosukKwon Sep 27, 2023
24470e0
Fix LLaMA
WoosukKwon Sep 27, 2023
46b548f
Refactor rotary_embedding
WoosukKwon Sep 27, 2023
a955d44
Refactor attention with rope
WoosukKwon Sep 27, 2023
8f614c3
Minor fix
WoosukKwon Sep 27, 2023
e631736
Minor
WoosukKwon Sep 27, 2023
4cdbc0a
Remove rope_scaling from docstring
WoosukKwon Sep 27, 2023
41a3b4b
Temporarily remove tests
WoosukKwon Sep 27, 2023
198ee45
Minor
WoosukKwon Sep 27, 2023
11 changes: 11 additions & 0 deletions vllm/config.py
@@ -351,6 +351,17 @@ def _get_and_verify_max_len(
if max_len_key is not None:
derived_max_model_len = min(derived_max_model_len, max_len_key)

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
if derived_max_model_len == float("inf"):
raise ValueError(
"When using rope_scaling, the model's config.json must "
"contain one of the following keys to determine the original "
f"maximum length of the model: {possible_keys}")
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
derived_max_model_len *= scaling_factor

if max_model_len is None:
max_model_len = derived_max_model_len
elif max_model_len > derived_max_model_len:
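For readers skimming the diff: the new block multiplies the maximum length derived from config.json by the RoPE scaling factor, so a 2048-token base model with `factor: 8.0` is served with a 16384-token context. Below is a minimal standalone sketch of that derivation; the helper name and the default key list are illustrative, not vLLM's API.

```python
# Illustrative sketch of the derivation in the hunk above; `hf_config` is any
# object exposing the HF config attributes, and the helper name is made up.
def derive_max_model_len(hf_config, possible_keys=("max_position_embeddings",)):
    derived_max_model_len = float("inf")
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None:
            derived_max_model_len = min(derived_max_model_len, max_len)

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        if derived_max_model_len == float("inf"):
            # Cannot scale an unknown base length.
            raise ValueError("rope_scaling requires an original max length key")
        derived_max_model_len *= rope_scaling["factor"]

    return derived_max_model_len


class _Cfg:  # e.g. a LongChat-16k style config
    max_position_embeddings = 2048
    rope_scaling = {"type": "linear", "factor": 8.0}


assert derive_max_model_len(_Cfg()) == 16384.0
```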
63 changes: 25 additions & 38 deletions vllm/model_executor/layers/attention.py
@@ -1,5 +1,5 @@
"""Multi-head attention."""
from typing import List, Optional
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
@@ -9,8 +9,10 @@

from vllm import attention_ops
from vllm import cache_ops
from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import (
DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
RotaryEmbedding)

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]

@@ -247,7 +249,7 @@ def forward(


class PagedAttentionWithRoPE(PagedAttention):
"""PagedAttention with rotary embedding."""
"""PagedAttention with rotary positional embedding."""

def __init__(
self,
@@ -259,34 +261,26 @@ def __init__(
base: int = 10000,
num_kv_heads: Optional[int] = None,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads)
self.is_neox_style = is_neox_style

# Create the cos and sin cache.
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
t = torch.arange(max_position, dtype=torch.float, device="cuda")
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)

# FIXME(woosuk): This assumes that we configure the default dtype when
# initializing the model.
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, rotary_dim]
self.register_buffer("cos_sin_cache", cache, persistent=False)
if rope_scaling is None:
self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style)
else:
scaling_type = rope_scaling["type"]
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LinearScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor)
elif scaling_type == "dynamic":
self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

def forward(
self,
@@ -303,7 +297,7 @@ def forward(

Args:
positions: shape = [num_tokens]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
@@ -319,14 +313,7 @@

# Apply rotary embedding to the query and key before passing them
# to the attention op.
pos_encoding_ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
query, key = self.rotary_emb(positions, query, key)
return super().forward(
query,
key,
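The refactor does not change what gets computed: `self.rotary_emb(positions, query, key)` still invokes the same in-place `pos_encoding_ops.rotary_embedding` kernel, now wrapped in a module that owns the cos/sin cache. As a readability aid only, here is a rough pure-PyTorch sketch of the NeoX-style rotation that call performs, written from the HF rotate_half formulation and untested against the kernel; the real op is fused and in-place.

```python
import torch


def rotary_neox_ref(positions: torch.Tensor, x: torch.Tensor, head_size: int,
                    rotary_dim: int, cos_sin_cache: torch.Tensor) -> torch.Tensor:
    """Reference-only sketch; x has shape [num_tokens, num_heads * head_size]."""
    num_tokens = x.shape[0]
    x = x.view(num_tokens, -1, head_size)
    rot, rest = x[..., :rotary_dim], x[..., rotary_dim:]
    # cos_sin_cache[p] stores cat(cos, sin), each of length rotary_dim // 2.
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads
    x1, x2 = rot.chunk(2, dim=-1)
    rot = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rot, rest), dim=-1).view(num_tokens, -1)
```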
169 changes: 169 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -0,0 +1,169 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
from typing import Tuple, Union

import torch
import torch.nn as nn

from vllm import pos_encoding_ops


class RotaryEmbedding(nn.Module):
    """Original rotary positional embedding."""

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style

        cache = self._compute_cos_sin_cache()
        cache = cache.to(torch.get_default_dtype())
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
                                 self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings,
                         dtype=torch.float,
                         device="cuda")

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # pos_encoding_ops.rotary_embedding() is an in-place operation that
        # updates the query and key tensors.
        pos_encoding_ops.rotary_embedding(positions, query, key,
                                          self.head_size, self.cos_sin_cache,
                                          self.is_neox_style)
        return query, key


class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling.

    Credits to the Reddit user /u/kaiokendev
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.base)
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        t = torch.arange(max_len, dtype=torch.float, device="cuda")
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling.

    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
    ) -> None:
        self.scaling_factor = scaling_factor
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        # NOTE(woosuk): self.max_position_embeddings is the original
        # maximum length before applying the rope scaling.
        # Thus, the maximum length after applying the rope scaling is
        # self.max_position_embeddings * self.scaling_factor.
        max_len = self.max_position_embeddings * self.scaling_factor
        base = self.base * (
            (self.scaling_factor * max_len / self.max_position_embeddings) -
            (self.scaling_factor - 1))**(self.rotary_dim /
                                         (self.rotary_dim - 2))
        inv_freq = self._compute_inv_freq(base)
        t = torch.arange(max_len, dtype=torch.float, device="cuda")

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache
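A quick construction sketch of the three variants follows. It requires vLLM built with its CUDA ops and a visible GPU, since the cache is computed with `device="cuda"`; the sizes and the factor 8.0 are illustrative, matching a LongChat-style linear scale of a 2048-token model.

```python
from vllm.model_executor.layers.rotary_embedding import (
    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
    RotaryEmbedding)

common = dict(head_size=128, rotary_dim=128, max_position_embeddings=2048,
              base=10000, is_neox_style=True)

rope = RotaryEmbedding(**common)
rope_linear = LinearScalingRotaryEmbedding(**common, scaling_factor=8.0)
rope_dynamic = DynamicNTKScalingRotaryEmbedding(**common, scaling_factor=8.0)

# The cache holds cat(cos, sin) per position; scaled variants cover 8x positions.
print(rope.cos_sin_cache.shape)          # torch.Size([2048, 128])
print(rope_linear.cos_sin_cache.shape)   # torch.Size([16384, 128])
print(rope_dynamic.cos_sin_cache.shape)  # torch.Size([16384, 128])
```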
8 changes: 6 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -25,7 +25,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
@@ -92,6 +92,7 @@ def __init__(
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
@@ -135,7 +136,8 @@ def __init__(
base=self.rope_theta,
max_position=self.max_position_embeddings,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
rope_scaling=rope_scaling)

def forward(
self,
@@ -165,13 +167,15 @@ def __init__(
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
)
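With this wiring, RoPE scaling is picked up straight from the checkpoint's config.json (the `rope_scaling` dict), and `_get_and_verify_max_len` enlarges the default context accordingly; no extra engine flag is needed. A hedged end-to-end sketch is below; the model name is illustrative, and any LLaMA-family checkpoint whose config.json carries e.g. `"rope_scaling": {"type": "linear", "factor": 8.0}` should behave the same way.

```python
from vllm import LLM, SamplingParams

# LongChat-16k style checkpoint: base max_position_embeddings of 2048 with a
# linear rope_scaling factor of 8, so the derived max_model_len is 16384.
llm = LLM(model="lmsys/longchat-13b-16k")

outputs = llm.generate(
    ["Summarize the plot of a very long novel in one paragraph: ..."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```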