From de60a3fb93957dce6b242299b5d163f02ef7f383 Mon Sep 17 00:00:00 2001 From: avideci <61653911+avideci@users.noreply.github.com> Date: Tue, 19 Dec 2023 12:29:33 +0200 Subject: [PATCH] Added DeciLM-7b and DeciLM-7b-instruct (#2062) --- README.md | 1 + docs/source/models/supported_models.rst | 3 + tests/models/test_models.py | 1 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/decilm.py | 123 ++++++++++++++++++++++++ 5 files changed, 129 insertions(+) create mode 100644 vllm/model_executor/models/decilm.py diff --git a/README.md b/README.md index c6e6a3c7379db..13c654a218e89 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) +- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) - GPT-2 (`gpt2`, `gpt2-xl`, etc.) - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c95b158e871fe..71808436dc114 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -23,6 +23,9 @@ Alongside each architecture, we include some popular models that use it. * - :code:`ChatGLMModel` - ChatGLM - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. + * - :code:`DeciLMForCausalLM` + - DeciLM + - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. * - :code:`BloomForCausalLM` - BLOOM, BLOOMZ, BLOOMChat - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. diff --git a/tests/models/test_models.py b/tests/models/test_models.py index e65c424c601a2..518eae201ed32 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -8,6 +8,7 @@ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", "mistralai/Mistral-7B-v0.1", + "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2", "bigcode/tiny_starcoder_py", diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index ab9a1636ad13f..f60ea640359b3 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -17,6 +17,7 @@ "BloomForCausalLM": ("bloom", "BloomForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py new file mode 100644 index 0000000000000..984be0cccd16d --- /dev/null +++ b/vllm/model_executor/models/decilm.py @@ -0,0 +1,123 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 DeciAI Research Team. All rights reserved. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on MistralAI GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
+# It has been modified from its original forms to accommodate minor
+# architectural differences compared to GPT-NeoX and OPT used by the Meta AI
+# team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only DeciLM model compatible with HuggingFace weights."""
+
+from typing import Optional
+
+import torch
+from transformers import PretrainedConfig
+
+from vllm.model_executor.layers.linear import LinearMethodBase
+from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.model_executor.weight_utils import (default_weight_loader,
+                                              hf_model_weights_iterator)
+
+
+class DeciLMForCausalLM(LlamaForCausalLM):
+    """
+    Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
+    Based on the llama executor.
+
+    The main difference is that DeciLM uses Variable Grouped Query Attention.
+    The constant number of GQA heads in the decoder is overridden with a
+    per-layer value.
+
+    In the HuggingFace implementation, the varying per-layer
+    "config.num_key_value_heads_per_layer[i]" is used instead of a single
+    "config.num_key_value_heads".
+
+    Currently, PagedAttention does not work well with variable GQA, so we
+    normalize the weights upon loading and use uniform GQA with the max value
+    instead.
+    """
+
+    def __init__(
+        self,
+        config: Optional[PretrainedConfig] = None,
+        linear_method: Optional[LinearMethodBase] = None,
+    ) -> None:
+        # Normalize the config to uniform GQA with the max number of KV heads.
+        config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
+        delattr(config, "num_key_value_heads_per_layer")
+        super().__init__(config=config, linear_method=linear_method)
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, load_format, revision):
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if "k_proj" in name or "v_proj" in name:
+                # Expand layers with fewer KV heads up to the uniform maximum.
+                loaded_weight = self._degroup_weight(loaded_weight)
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor: + hidden_size = self.config.hidden_size + head_size = self.config.hidden_size // self.config.num_attention_heads + target_num_kv_heads = self.config.num_key_value_heads + num_kv_heads = loaded_weight.shape[0] // head_size + n_repeats = target_num_kv_heads / num_kv_heads + assert n_repeats == int(n_repeats) + + n_repeats = int(n_repeats) + loaded_weight = loaded_weight.view(num_kv_heads, head_size, + hidden_size) + loaded_weight = torch.repeat_interleave(loaded_weight, + repeats=n_repeats, + dim=0) + loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size, + hidden_size) + + return loaded_weight
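
Not part of the patch: a minimal standalone sketch of what `_degroup_weight` does
during weight loading, assuming made-up sizes (hidden_size=8, head_size=2, a layer
that stores 2 KV heads, and a per-model maximum of 4 KV heads). Each stored KV head
of a `k_proj`/`v_proj` weight is repeated head-wise so every layer ends up with the
same, maximum number of KV heads.

import torch

# Hypothetical sizes, chosen only for illustration.
hidden_size = 8
head_size = 2
num_kv_heads = 2          # KV heads stored in this layer's checkpoint
target_num_kv_heads = 4   # max(config.num_key_value_heads_per_layer)

# A k_proj/v_proj weight has shape (num_kv_heads * head_size, hidden_size).
loaded_weight = torch.arange(
    num_kv_heads * head_size * hidden_size,
    dtype=torch.float32).view(num_kv_heads * head_size, hidden_size)

# The same three steps as DeciLMForCausalLM._degroup_weight: split into
# heads, repeat each head, and flatten back.
n_repeats = target_num_kv_heads // num_kv_heads
degrouped = loaded_weight.view(num_kv_heads, head_size, hidden_size)
degrouped = torch.repeat_interleave(degrouped, repeats=n_repeats, dim=0)
degrouped = degrouped.reshape(target_num_kv_heads * head_size, hidden_size)

print(loaded_weight.shape)  # torch.Size([4, 8])
print(degrouped.shape)      # torch.Size([8, 8])

# Head 0's rows now appear twice before head 1's rows, so uniform GQA over
# 4 KV heads reproduces the original variable-GQA computation.
assert torch.equal(degrouped[:head_size], degrouped[head_size:2 * head_size])

Repeating KV heads this way spends some extra weight and KV-cache memory on the
layers that originally had fewer heads, but it lets the existing uniform-GQA
PagedAttention path serve DeciLM without kernel changes, which appears to be the
trade-off this patch accepts.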