From d9e9332b5d930746a71e1a9810a6ee56a547e113 Mon Sep 17 00:00:00 2001
From: raywanb <112235519+raywanb@users.noreply.github.com>
Date: Thu, 23 May 2024 04:58:59 +0800
Subject: [PATCH] [Model] LoRA gptbigcode implementation (#3949)

---
 csrc/punica/bgmv/bgmv_config.h            |  4 +++
 tests/lora/test_punica.py                 |  2 ++
 vllm/lora/models.py                       |  2 ++
 vllm/model_executor/models/gpt_bigcode.py | 31 +++++++++++++++++++----
 4 files changed, 34 insertions(+), 5 deletions(-)

diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
index 98ac8de779e13..4b376261d30d2 100644
--- a/csrc/punica/bgmv/bgmv_config.h
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 2752) \
   f(in_T, out_T, W_T, narrow, 2816) \
   f(in_T, out_T, W_T, narrow, 3072) \
+  f(in_T, out_T, W_T, narrow, 3328) \
   f(in_T, out_T, W_T, narrow, 3456) \
   f(in_T, out_T, W_T, narrow, 3584) \
   f(in_T, out_T, W_T, narrow, 4096) \
@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 5504) \
   f(in_T, out_T, W_T, narrow, 5632) \
   f(in_T, out_T, W_T, narrow, 6144) \
+  f(in_T, out_T, W_T, narrow, 6400) \
   f(in_T, out_T, W_T, narrow, 6848) \
   f(in_T, out_T, W_T, narrow, 6912) \
   f(in_T, out_T, W_T, narrow, 7168) \
@@ -97,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, 2752, narrow) \
   f(in_T, out_T, W_T, 2816, narrow) \
   f(in_T, out_T, W_T, 3072, narrow) \
+  f(in_T, out_T, W_T, 3328, narrow) \
   f(in_T, out_T, W_T, 3456, narrow) \
   f(in_T, out_T, W_T, 3584, narrow) \
   f(in_T, out_T, W_T, 4096, narrow) \
@@ -105,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, 5504, narrow) \
   f(in_T, out_T, W_T, 5632, narrow) \
   f(in_T, out_T, W_T, 6144, narrow) \
+  f(in_T, out_T, W_T, 6400, narrow) \
   f(in_T, out_T, W_T, 6848, narrow) \
   f(in_T, out_T, W_T, 6912, narrow) \
   f(in_T, out_T, W_T, 7168, narrow) \
diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py
index 193e3906997c4..f021c003b1322 100644
--- a/tests/lora/test_punica.py
+++ b/tests/lora/test_punica.py
@@ -58,6 +58,7 @@ def _lora_ref_impl(
     2560,
     2752,
     3072,
+    3328,
     3456,
     3584,
     4096,
@@ -66,6 +67,7 @@ def _lora_ref_impl(
     5504,
     5632,
     6144,
+    6400,
     6848,
     6912,
     7168,
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index a2092d31ea9aa..3e82856866d85 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -310,7 +310,9 @@ def from_local_checkpoint(
             if part_name not in expected_lora_modules:
                 unexpected_modules.append(module)
         # loaded lora's target modules must be a subset of expected_lora_modules
+
         if unexpected_modules:
+            print(unexpected_modules, "modules")
             raise ValueError(
                 f"While loading {lora_dir}, expected"
                 f" target modules in {expected_lora_modules}"
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index f488ef40039c0..69b75763e9a3d 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -25,7 +25,7 @@
 from transformers import GPTBigCodeConfig
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -191,14 +191,19 @@ def __init__(
         config: GPTBigCodeConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert not config.add_cross_attention
 
         self.embed_dim = config.hidden_size
-
-        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.wte = VocabParallelEmbedding(self.vocab_size,
+                                          self.embed_dim,
+                                          org_num_embeddings=config.vocab_size)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList([
             GPTBigCodeBlock(config, cache_config, quant_config)
@@ -226,19 +231,35 @@ def forward(
 
 
 class GPTBigCodeForCausalLM(nn.Module):
+    packed_modules_mapping = {"c_attn": ["c_attn"]}
+
+    supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
+
+    embedding_modules = {
+        "wte": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+
+    embedding_padding_modules = []
 
     def __init__(
         self,
        
 config: GPTBigCodeConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(config, cache_config, quant_config)
+        self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
+                                           lora_config)
         self.lm_head_weight = self.transformer.wte.weight
-        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
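
Note (not part of the patch): a minimal usage sketch of how the new GPT-BigCode LoRA support could be exercised through vLLM's public LoRA API as it existed around this release. The model name and adapter path are placeholders, and the sampling settings are arbitrary.

# Hypothetical example: serve a GPTBigCodeForCausalLM checkpoint with a LoRA
# adapter attached per request. Paths and names below are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# enable_lora activates the punica kernels extended by this patch; max_loras
# and max_lora_rank bound the adapters held in GPU memory at once.
llm = LLM(
    model="bigcode/starcoderbase-1b",  # any GPT-BigCode architecture model
    enable_lora=True,
    max_loras=1,
    max_lora_rank=16,
)

# Attach a locally stored adapter to this generate call. The adapter's
# target modules must be a subset of supported_lora_modules
# (c_fc, c_proj, wte, lm_head, c_attn), or loading raises ValueError.
outputs = llm.generate(
    ["def fibonacci(n):"],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("my_adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)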