forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for aquila (vllm-project#663)
* add aquila Signed-off-by: ftgreat <ftgreat@163.com> * fix some bug Signed-off-by: shunxing1234 <xw747777271@gmail.com> * delete pdb Signed-off-by: shunxing1234 <xw747777271@gmail.com> * fix bugs Signed-off-by: shunxing1234 <xw747777271@gmail.com> * fix bugs Signed-off-by: shunxing1234 <xw747777271@gmail.com> * delete whitespace Signed-off-by: shunxing1234 <xw747777271@gmail.com> * format * fix order --------- Signed-off-by: ftgreat <ftgreat@163.com> Signed-off-by: shunxing1234 <xw747777271@gmail.com> Co-authored-by: ftgreat <ftgreat@163.com>
- Loading branch information
1 parent
4f85847
commit ad5f2fe
Showing
6 changed files
with
431 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,362 @@ | ||
# coding=utf-8 | ||
# Adapted from | ||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py | ||
# Copyright 2023 The vLLM team. | ||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. | ||
# | ||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX | ||
# and OPT implementations in this library. It has been modified from its | ||
# original forms to accommodate minor architectural differences compared | ||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Inference-only LLaMA model compatible with HuggingFace weights. | ||
The input of the model is flattened to a 1D tensor of tokens. The model uses | ||
InputMetadata to extract the original 2D shape of the input. | ||
""" | ||
from typing import Dict, List, Optional, Tuple | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from vllm.model_executor.input_metadata import InputMetadata | ||
from vllm.model_executor.layers.activation import SiluAndMul | ||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE | ||
from vllm.model_executor.layers.sampler import Sampler | ||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator, | ||
load_tensor_parallel_weights) | ||
from vllm.model_executor.parallel_utils.parallel_state import ( | ||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) | ||
from vllm.model_executor.parallel_utils.tensor_parallel import ( | ||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) | ||
from vllm.sequence import SequenceOutputs | ||
from vllm.transformers_utils.configs.aquila import AquilaConfig | ||
|
||
KVCache = Tuple[torch.Tensor, torch.Tensor] | ||
|
||
|
||
class AquilaMLP(nn.Module): | ||
|
||
def __init__( | ||
self, | ||
hidden_size: int, | ||
intermediate_size: int, | ||
hidden_act: str, | ||
): | ||
super().__init__() | ||
self.gate_up_proj = ColumnParallelLinear(hidden_size, | ||
2 * intermediate_size, | ||
bias=False, | ||
gather_output=False, | ||
perform_initialization=False) | ||
self.down_proj = RowParallelLinear(intermediate_size, | ||
hidden_size, | ||
bias=False, | ||
input_is_parallel=True, | ||
perform_initialization=False) | ||
if hidden_act != "silu": | ||
raise ValueError(f"Unsupported activation: {hidden_act}. " | ||
"Only silu is supported for now.") | ||
self.act_fn = SiluAndMul() | ||
|
||
def forward(self, x): | ||
gate_up, _ = self.gate_up_proj(x) | ||
x = self.act_fn(gate_up) | ||
x, _ = self.down_proj(x) | ||
return x | ||
|
||
|
||
class AquilaRMSNorm(nn.Module): | ||
|
||
def __init__(self, hidden_size, eps=1e-6): | ||
""" | ||
AquilaRMSNorm is equivalent to T5LayerNorm | ||
""" | ||
super().__init__() | ||
self.weight = nn.Parameter(torch.ones(hidden_size)) | ||
self.variance_epsilon = eps | ||
|
||
def forward(self, hidden_states): | ||
input_dtype = hidden_states.dtype | ||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, | ||
keepdim=True) | ||
hidden_states = hidden_states * torch.rsqrt(variance + | ||
self.variance_epsilon) | ||
|
||
return (self.weight * hidden_states).to(input_dtype) | ||
|
||
|
||
class AquilaAttention(nn.Module): | ||
|
||
def __init__( | ||
self, | ||
hidden_size: int, | ||
num_heads: int, | ||
num_kv_heads: int, | ||
): | ||
super().__init__() | ||
self.hidden_size = hidden_size | ||
tp_size = get_tensor_model_parallel_world_size() | ||
self.total_num_heads = num_heads | ||
assert self.total_num_heads % tp_size == 0 | ||
self.num_heads = self.total_num_heads // tp_size | ||
self.total_num_kv_heads = num_kv_heads | ||
assert self.total_num_kv_heads % tp_size == 0 | ||
self.num_kv_heads = self.total_num_kv_heads // tp_size | ||
self.head_dim = hidden_size // self.total_num_heads | ||
self.q_size = self.num_heads * self.head_dim | ||
self.kv_size = self.num_kv_heads * self.head_dim | ||
self.scaling = self.head_dim**-0.5 | ||
|
||
self.qkv_proj = ColumnParallelLinear( | ||
hidden_size, | ||
(self.total_num_heads + 2 * self.total_num_kv_heads) * | ||
self.head_dim, | ||
bias=False, | ||
gather_output=False, | ||
perform_initialization=False, | ||
) | ||
self.o_proj = RowParallelLinear( | ||
self.total_num_heads * self.head_dim, | ||
hidden_size, | ||
bias=False, | ||
input_is_parallel=True, | ||
perform_initialization=False, | ||
) | ||
self.attn = PagedAttentionWithRoPE( | ||
self.num_heads, | ||
self.head_dim, | ||
self.scaling, | ||
rotary_dim=self.head_dim, | ||
) | ||
|
||
def forward( | ||
self, | ||
positions: torch.Tensor, | ||
hidden_states: torch.Tensor, | ||
kv_cache: KVCache, | ||
input_metadata: InputMetadata, | ||
cache_event: Optional[torch.cuda.Event], | ||
) -> torch.Tensor: | ||
qkv, _ = self.qkv_proj(hidden_states) | ||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) | ||
k_cache, v_cache = kv_cache | ||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache, | ||
input_metadata, cache_event) | ||
output, _ = self.o_proj(attn_output) | ||
return output | ||
|
||
|
||
class AquilaDecoderLayer(nn.Module): | ||
|
||
def __init__(self, config: AquilaConfig): | ||
super().__init__() | ||
self.hidden_size = config.hidden_size | ||
self.self_attn = AquilaAttention( | ||
hidden_size=self.hidden_size, | ||
num_heads=config.num_attention_heads, | ||
num_kv_heads=config.num_attention_heads, | ||
) | ||
self.mlp = AquilaMLP( | ||
hidden_size=self.hidden_size, | ||
intermediate_size=config.intermediate_size, | ||
hidden_act=config.hidden_act, | ||
) | ||
self.input_layernorm = AquilaRMSNorm(config.hidden_size, | ||
eps=config.rms_norm_eps) | ||
self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size, | ||
eps=config.rms_norm_eps) | ||
|
||
def forward( | ||
self, | ||
positions: torch.Tensor, | ||
hidden_states: torch.Tensor, | ||
kv_cache: KVCache, | ||
input_metadata: InputMetadata, | ||
cache_event: Optional[torch.cuda.Event], | ||
) -> torch.Tensor: | ||
# Self Attention | ||
residual = hidden_states | ||
hidden_states = self.input_layernorm(hidden_states) | ||
hidden_states = self.self_attn( | ||
positions=positions, | ||
hidden_states=hidden_states, | ||
kv_cache=kv_cache, | ||
input_metadata=input_metadata, | ||
cache_event=cache_event, | ||
) | ||
hidden_states = residual + hidden_states | ||
|
||
# Fully Connected | ||
residual = hidden_states | ||
hidden_states = self.post_attention_layernorm(hidden_states) | ||
hidden_states = self.mlp(hidden_states) | ||
hidden_states = residual + hidden_states | ||
return hidden_states | ||
|
||
|
||
class AquilaModel(nn.Module): | ||
|
||
def __init__(self, config: AquilaConfig): | ||
super().__init__() | ||
self.config = config | ||
self.padding_idx = config.pad_token_id | ||
self.vocab_size = config.vocab_size | ||
|
||
#vocab_size = ((config.vocab_size + 63) // 64) * 64 | ||
self.embed_tokens = VocabParallelEmbedding( | ||
config.vocab_size, | ||
config.hidden_size, | ||
perform_initialization=False) | ||
self.layers = nn.ModuleList([ | ||
AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers) | ||
]) | ||
self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||
|
||
def forward( | ||
self, | ||
input_ids: torch.Tensor, | ||
positions: torch.Tensor, | ||
kv_caches: List[KVCache], | ||
input_metadata: InputMetadata, | ||
cache_events: Optional[List[torch.cuda.Event]], | ||
) -> torch.Tensor: | ||
hidden_states = self.embed_tokens(input_ids) | ||
for i in range(len(self.layers)): | ||
if cache_events is None: | ||
cache_event = None | ||
else: | ||
cache_event = cache_events[i] | ||
layer = self.layers[i] | ||
hidden_states = layer( | ||
positions, | ||
hidden_states, | ||
kv_caches[i], | ||
input_metadata, | ||
cache_event, | ||
) | ||
hidden_states = self.norm(hidden_states) | ||
|
||
return hidden_states | ||
|
||
|
||
class AquilaForCausalLM(nn.Module): | ||
|
||
def __init__(self, config): | ||
super().__init__() | ||
self.config = config | ||
self.model = AquilaModel(config) | ||
vocab_size = ((config.vocab_size + 63) // 64) * 64 | ||
self.lm_head = ColumnParallelLinear(config.hidden_size, | ||
vocab_size, | ||
bias=False, | ||
gather_output=False, | ||
perform_initialization=False) | ||
self.sampler = Sampler(config.vocab_size) | ||
|
||
def forward( | ||
self, | ||
input_ids: torch.Tensor, | ||
positions: torch.Tensor, | ||
kv_caches: List[KVCache], | ||
input_metadata: InputMetadata, | ||
cache_events: Optional[List[torch.cuda.Event]], | ||
) -> Dict[int, SequenceOutputs]: | ||
hidden_states = self.model(input_ids, positions, kv_caches, | ||
input_metadata, cache_events) | ||
next_tokens = self.sampler(self.lm_head.weight, hidden_states, | ||
input_metadata) | ||
return next_tokens | ||
|
||
_column_parallel_weights = [ | ||
"embed_tokens.weight", "lm_head.weight", "qkv_proj.weight", | ||
"gate_proj.weight", "up_proj.weight" | ||
] | ||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"] | ||
|
||
def load_weights(self, | ||
model_name_or_path: str, | ||
cache_dir: Optional[str] = None, | ||
use_np_cache: bool = False): | ||
tp_size = get_tensor_model_parallel_world_size() | ||
tensor_model_parallel_rank = get_tensor_model_parallel_rank() | ||
q_proj_shard_size = (self.config.hidden_size // tp_size) | ||
kv_proj_shard_size = (self.config.hidden_size // | ||
self.config.num_attention_heads * | ||
self.config.num_attention_heads // tp_size) | ||
attention_weight_specs = [ | ||
# (weight_name, shard_size, offset) | ||
("q_proj", q_proj_shard_size, 0), | ||
("k_proj", kv_proj_shard_size, q_proj_shard_size), | ||
("v_proj", kv_proj_shard_size, | ||
q_proj_shard_size + kv_proj_shard_size), | ||
] | ||
state_dict = self.state_dict() | ||
|
||
for name, loaded_weight in hf_model_weights_iterator( | ||
model_name_or_path, cache_dir, use_np_cache): | ||
if "rotary_emb.inv_freq" in name: | ||
continue | ||
|
||
if "embed_tokens" in name or "lm_head" in name: | ||
param = state_dict[name] | ||
# Consider padding in the vocab size. | ||
padded_vocab_size = (param.shape[0] * tp_size) | ||
num_extra_rows = padded_vocab_size - self.config.vocab_size | ||
extra_rows = torch.empty(num_extra_rows, | ||
loaded_weight.shape[1]) | ||
extra_rows = extra_rows.to(loaded_weight) | ||
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) | ||
|
||
is_attention_weight = False | ||
for weight_name, shard_size, offset in attention_weight_specs: | ||
if weight_name not in name: | ||
continue | ||
param = state_dict[name.replace(weight_name, "qkv_proj")] | ||
|
||
loaded_weight = loaded_weight[ | ||
shard_size * tensor_model_parallel_rank:shard_size * | ||
(tensor_model_parallel_rank + 1)] | ||
param_slice = param.data[offset:offset + shard_size] | ||
assert param_slice.shape == loaded_weight.shape | ||
|
||
param_slice.copy_(loaded_weight) | ||
is_attention_weight = True | ||
break | ||
if is_attention_weight: | ||
continue | ||
|
||
is_gate_up_weight = False | ||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): | ||
if weight_name not in name: | ||
continue | ||
param = state_dict[name.replace(weight_name, "gate_up_proj")] | ||
shard_size = param.shape[0] // 2 | ||
loaded_weight = loaded_weight[ | ||
shard_size * tensor_model_parallel_rank:shard_size * | ||
(tensor_model_parallel_rank + 1)] | ||
param_slice = param.data[shard_size * stride_id:shard_size * | ||
(stride_id + 1)] | ||
assert param_slice.shape == loaded_weight.shape | ||
param_slice.copy_(loaded_weight) | ||
is_gate_up_weight = True | ||
break | ||
if is_gate_up_weight: | ||
continue | ||
|
||
param = state_dict[name] | ||
load_tensor_parallel_weights(param, loaded_weight, name, | ||
self._column_parallel_weights, | ||
self._row_parallel_weights, | ||
tensor_model_parallel_rank) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.