Implementation of Positional Interpolation (PI) Feature #690
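This PR wires linear positional interpolation (PI) into vLLM's RoPE path: position indices are divided by a scaling factor read from the model config's rope_scaling entry, so a model trained on a short context can serve a proportionally longer window without extrapolating rotary angles beyond the trained range. A minimal, self-contained sketch of the idea with toy numbers (not vLLM code):

import torch

# A model trained with max_position_embeddings = 2048 covers 8192-token
# sequences by compressing position indices by 1/factor, so every rotary
# angle stays inside the range seen during training.
trained_max_len = 2048
factor = 4.0

positions = torch.arange(int(trained_max_len * factor))    # 0 .. 8191
interpolated = positions.float() / factor                   # 0.0 .. 2047.75
assert interpolated.max().item() < trained_max_len          # no extrapolation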

Closed · wants to merge 8 commits
Changes from all commits
31 changes: 18 additions & 13 deletions vllm/config.py
@@ -98,17 +98,15 @@ def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
# Note: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
new_decoder_arch_falcon = (
self.hf_config.model_type == "falcon"
and getattr(self.hf_config, "new_decoder_architecture", False))
new_decoder_arch_falcon = self.hf_config.model_type == "falcon" and getattr(
self.hf_config, "new_decoder_architecture", False)
if not new_decoder_arch_falcon and getattr(self.hf_config,
"multi_query", False):
# Multi-query attention, only one KV head.
return 1
# For Falcon:
if getattr(self.hf_config, "n_head_kv", None) is not None:
return (self.hf_config.n_head_kv //
parallel_config.tensor_parallel_size)
return self.hf_config.n_head_kv // parallel_config.tensor_parallel_size
# For LLaMA-2:
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
return (self.hf_config.num_key_value_heads //
@@ -118,22 +116,29 @@ def get_num_heads(self, parallel_config: "ParallelConfig") -> int:

def get_max_model_len(self) -> int:
max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
length_keys = [
"max_seq_len",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
for key in possible_keys:
position_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
]
rope_scaling = getattr(self.hf_config, "rope_scaling", None) or {}
rope_scaling_factor = rope_scaling.get("factor", 1.0)
for key in length_keys:
max_len_key = getattr(self.hf_config, key, None)
if max_len_key is not None:
max_model_len = min(max_model_len, max_len_key)
for key in position_keys:
max_len_key = getattr(self.hf_config, key, None)
if max_len_key is not None:
max_model_len = min(max_model_len,
max_len_key * rope_scaling_factor)
return max_model_len

def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
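To illustrate the get_max_model_len change above: position-embedding limits are stretched by the RoPE scaling factor, while explicit sequence-length limits are not. A hedged, standalone sketch of the same arithmetic using a plain dict in place of the HF config object (values are illustrative):

hf_config = {
    "max_position_embeddings": 2048,      # trained context length
    "rope_scaling": {"factor": 4.0},      # only "factor" is read here
}

rope_factor = (hf_config.get("rope_scaling") or {}).get("factor", 1.0)
max_model_len = hf_config["max_position_embeddings"] * rope_factor
print(max_model_len)  # 8192.0 -- the engine now accepts 4x longer sequences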
53 changes: 33 additions & 20 deletions vllm/model_executor/layers/attention.py
@@ -4,8 +4,10 @@
import torch
import torch.nn as nn
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias,
)

from vllm import attention_ops
from vllm import cache_ops
@@ -52,11 +54,13 @@ class PagedAttention(nn.Module):
5. Output a flattened 1D tensor.
"""

def __init__(self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None) -> None:
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
@@ -68,7 +72,8 @@ def __init__(self,
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.head_mapping = torch.repeat_interleave(
torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
self.num_queries_per_kv)
self.num_queries_per_kv,
)

if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f"head_size ({self.head_size}) is not supported. "
@@ -215,8 +220,7 @@ def forward(
# When key_cache and value_cache are not provided, the new key
# and value vectors will not be cached.
num_valid_tokens = input_metadata.num_valid_tokens
if (num_valid_tokens > 0 and key_cache is not None
and value_cache is not None):
if num_valid_tokens > 0 and key_cache is not None and value_cache is not None:
# The stride is 3 because the key and value are sliced from qkv.
cache_ops.reshape_and_cache(
key[:num_valid_tokens],
@@ -235,8 +239,11 @@ def forward(
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(
output[num_prompt_tokens:num_valid_tokens],
query[num_prompt_tokens:num_valid_tokens], key_cache,
value_cache, input_metadata)
query[num_prompt_tokens:num_valid_tokens],
key_cache,
value_cache,
input_metadata,
)

# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
@@ -255,12 +262,16 @@ def __init__(
max_position: int = 8192,
base: int = 10000,
num_kv_heads: Optional[int] = None,
rope_scaling_factor: float = 1,
) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads)

self.rope_scaling_factor = rope_scaling_factor
max_position = max_position * self.rope_scaling_factor

# Create the cos and sin cache.
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
t = torch.arange(max_position).float()
t = torch.arange(max_position).float() / self.rope_scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
@@ -285,7 +296,7 @@ def forward(
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
""" PagedAttention forward pass with rotary embedding.
"""PagedAttention forward pass with rotary embedding.

Args:
positions: shape = [num_tokens]
@@ -326,12 +337,14 @@ def forward(
class PagedAttentionWithALiBi(PagedAttention):
"""PagedAttention with ALiBi attention bias."""

def __init__(self,
num_heads: int,
head_size: int,
scale: float,
slopes: List[float],
num_kv_heads: Optional[int] = None) -> None:
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
slopes: List[float],
num_kv_heads: Optional[int] = None,
) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads)
assert len(slopes) == num_heads

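The core of the attention change above is how PagedAttentionWithRoPE builds its cos/sin cache. A standalone sketch of that construction (illustrative function name, not the vLLM API):

import torch

def build_interpolated_rope_cache(rotary_dim: int,
                                  max_position: int = 8192,
                                  base: int = 10000,
                                  rope_scaling_factor: float = 1.0) -> torch.Tensor:
    # The cache covers the *extended* window ...
    max_position = int(max_position * rope_scaling_factor)
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    # ... while every position index is shrunk by 1/rope_scaling_factor, so the
    # largest rotary angle matches the one at the original max_position.
    t = torch.arange(max_position).float() / rope_scaling_factor
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_position, rotary_dim]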
107 changes: 71 additions & 36 deletions vllm/model_executor/models/llama.py
@@ -36,12 +36,19 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator,
load_tensor_parallel_weights,
)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear,
)
from vllm.sequence import SequenceOutputs

KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -56,16 +63,20 @@ def __init__(
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
self.gate_up_proj = ColumnParallelLinear(
hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@@ -85,6 +96,8 @@ def __init__(
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 8192,
rope_scaling_factor: float = 1,
):
super().__init__()
self.hidden_size = hidden_size
@@ -115,11 +128,15 @@ def __init__(
input_is_parallel=True,
perform_initialization=False,
)
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim,
max_position=max_position,
num_kv_heads=self.num_kv_heads,
rope_scaling_factor=rope_scaling_factor,
)

def forward(
self,
@@ -143,10 +160,17 @@ class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
if config.rope_scaling is None:
rope_scaling_factor = 1
else:
rope_scaling_factor = config.rope_scaling["factor"]
max_position = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
max_position=max_position,
rope_scaling_factor=rope_scaling_factor,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
@@ -235,11 +259,13 @@ def __init__(self, config):
self.config = config
self.model = LlamaModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.lm_head = ColumnParallelLinear(
config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.sampler = Sampler(config.vocab_size)

def forward(
@@ -257,18 +283,23 @@ def forward(
return next_tokens

_column_parallel_weights = [
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight",
"gate_proj.weight", "up_proj.weight"
"embed_tokens.weight",
"lm_head.weight",
"qkv_proj.weight",
"gate_proj.weight",
"up_proj.weight",
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False,
):
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
q_proj_shard_size = self.config.hidden_size // tp_size
kv_proj_shard_size = (self.config.hidden_size //
self.config.num_attention_heads *
self.config.num_key_value_heads // tp_size)
@@ -289,7 +320,7 @@ def load_weights(self,
if "embed_tokens" in name or "lm_head" in name:
param = state_dict[name]
# Consider padding in the vocab size.
padded_vocab_size = (param.shape[0] * tp_size)
padded_vocab_size = param.shape[0] * tp_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows,
loaded_weight.shape[1])
@@ -333,7 +364,11 @@ def load_weights(self,
continue

param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
load_tensor_parallel_weights(
param,
loaded_weight,
name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank,
)
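With the llama.py changes above, the scaling factor is picked up from the checkpoint's config at model construction time. A hypothetical way to enable interpolation for an existing LLaMA checkpoint (the path and the exact rope_scaling schema are assumptions; only the "factor" key is read by this PR):

import json
from pathlib import Path

config_path = Path("/path/to/llama-checkpoint/config.json")  # hypothetical path
config = json.loads(config_path.read_text())
config["rope_scaling"] = {"factor": 4.0}  # e.g. 2048-token base model -> 8192-token window
config_path.write_text(json.dumps(config, indent=2))
# Serving then proceeds as usual: get_max_model_len() reports the extended
# window and the RoPE cache interpolates positions by the same factor.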