[Model] LoRA gptbigcode implementation (vllm-project#3949)
raywanb authored May 22, 2024
1 parent 0aade5a commit 760a464
Showing 4 changed files with 34 additions and 5 deletions.
4 changes: 4 additions & 0 deletions csrc/punica/bgmv/bgmv_config.h
@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
 f(in_T, out_T, W_T, narrow, 2752) \
 f(in_T, out_T, W_T, narrow, 2816) \
 f(in_T, out_T, W_T, narrow, 3072) \
+f(in_T, out_T, W_T, narrow, 3328) \
 f(in_T, out_T, W_T, narrow, 3456) \
 f(in_T, out_T, W_T, narrow, 3584) \
 f(in_T, out_T, W_T, narrow, 4096) \
@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
 f(in_T, out_T, W_T, narrow, 5504) \
 f(in_T, out_T, W_T, narrow, 5632) \
 f(in_T, out_T, W_T, narrow, 6144) \
+f(in_T, out_T, W_T, narrow, 6400) \
 f(in_T, out_T, W_T, narrow, 6848) \
 f(in_T, out_T, W_T, narrow, 6912) \
 f(in_T, out_T, W_T, narrow, 7168) \
@@ -97,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
 f(in_T, out_T, W_T, 2752, narrow) \
 f(in_T, out_T, W_T, 2816, narrow) \
 f(in_T, out_T, W_T, 3072, narrow) \
+f(in_T, out_T, W_T, 3328, narrow) \
 f(in_T, out_T, W_T, 3456, narrow) \
 f(in_T, out_T, W_T, 3584, narrow) \
 f(in_T, out_T, W_T, 4096, narrow) \
@@ -105,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
 f(in_T, out_T, W_T, 5504, narrow) \
 f(in_T, out_T, W_T, 5632, narrow) \
 f(in_T, out_T, W_T, 6144, narrow) \
+f(in_T, out_T, W_T, 6400, narrow) \
 f(in_T, out_T, W_T, 6848, narrow) \
 f(in_T, out_T, W_T, 6912, narrow) \
 f(in_T, out_T, W_T, 7168, narrow) \
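Note on this header: the punica BGMV kernels are instantiated only for the feature sizes enumerated above, so LoRA support for a new architecture requires registering every width its LoRA-wrapped layers can produce. The sketch below is illustrative only (it is not vLLM's actual shape derivation, and the helper name and config values are assumptions): with multi-query attention, the fused QKV projection width is hidden_size + 2 * head_dim, which gives 6400 for a StarCoder-style config and, by the same formula, 3328 at hidden size 3072.

def candidate_bgmv_sizes(hidden_size: int, num_heads: int) -> list:
    # Rough guess at the widths a GPTBigCode-style model exposes to BGMV.
    head_dim = hidden_size // num_heads
    return sorted({
        hidden_size,                 # c_proj / wte / lm_head width
        hidden_size + 2 * head_dim,  # fused QKV (c_attn) width under multi-query attention
        4 * hidden_size,             # c_fc width (GPTBigCode defaults to a 4x MLP)
    })

# Assumed StarCoder-like config: hidden_size=6144, 48 heads -> head_dim 128,
# so the c_attn width is 6144 + 2 * 128 = 6400, one of the sizes added here.
print(candidate_bgmv_sizes(6144, 48))  # [6144, 6400, 24576]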
2 changes: 2 additions & 0 deletions tests/lora/test_punica.py
@@ -58,6 +58,7 @@ def _lora_ref_impl(
     2560,
     2752,
     3072,
+    3328,
     3456,
     3584,
     4096,
@@ -66,6 +67,7 @@ def _lora_ref_impl(
     5504,
     5632,
     6144,
+    6400,
     6848,
     6912,
     7168,
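For context, this test sweeps the sizes above through the BGMV kernels and checks them against a plain PyTorch reference. Conceptually, the reference adds each token's LoRA delta for whichever adapter that token is mapped to; the sketch below is a simplified stand-in with assumed shapes, not the actual _lora_ref_impl signature.

import torch

def lora_ref(y: torch.Tensor, x: torch.Tensor, wa: torch.Tensor,
             wb: torch.Tensor, indices: torch.Tensor, scale: float) -> torch.Tensor:
    # y: (num_tokens, h_out), x: (num_tokens, h_in)
    # wa: (num_loras, h_in, r), wb: (num_loras, r, h_out)
    for i, idx in enumerate(indices.tolist()):
        if idx < 0:
            continue  # token not mapped to any LoRA adapter
        y[i] += (x[i] @ wa[idx] @ wb[idx]) * scale
    return y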
2 changes: 2 additions & 0 deletions vllm/lora/models.py
@@ -310,7 +310,9 @@ def from_local_checkpoint(
                 if part_name not in expected_lora_modules:
                     unexpected_modules.append(module)
             # loaded lora's target modules must be a subset of expected_lora_modules
+
             if unexpected_modules:
+                print(unexpected_modules, "modules")
                 raise ValueError(
                     f"While loading {lora_dir}, expected"
                     f" target modules in {expected_lora_modules}"
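This change tightens adapter validation: the target modules named in a LoRA checkpoint's config must be a subset of the model's supported_lora_modules (for GPTBigCode, the list added in gpt_bigcode.py below). A standalone sketch of that check, using example module names rather than a real adapter config:

# Example module names; a real adapter config is read from the checkpoint.
expected_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
target_modules = ["transformer.h.0.attn.c_attn", "transformer.h.0.mlp.c_fc"]

unexpected_modules = [m for m in target_modules
                      if m.split(".")[-1] not in expected_lora_modules]
if unexpected_modules:
    raise ValueError(f"unexpected LoRA target modules: {unexpected_modules}")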
31 changes: 26 additions & 5 deletions vllm/model_executor/models/gpt_bigcode.py
@@ -25,7 +25,7 @@
 from transformers import GPTBigCodeConfig
 
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -191,14 +191,19 @@ def __init__(
         config: GPTBigCodeConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert not config.add_cross_attention
 
         self.embed_dim = config.hidden_size
-
-        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.wte = VocabParallelEmbedding(self.vocab_size,
+                                          self.embed_dim,
+                                          org_num_embeddings=config.vocab_size)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList([
             GPTBigCodeBlock(config, cache_config, quant_config)
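A quick worked example of the lora_vocab arithmetic added above, with assumed values (vLLM's actual defaults may differ):

lora_extra_vocab_size = 256   # assumed extra embedding rows reserved per adapter
max_loras = 4                 # assumed number of adapters resident at once
config_vocab_size = 49152     # e.g. a StarCoder-style tokenizer

lora_vocab = lora_extra_vocab_size * (max_loras or 1)
vocab_size = config_vocab_size + lora_vocab   # 49152 + 1024 = 50176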
@@ -226,19 +231,35 @@ def forward(
 
 
 class GPTBigCodeForCausalLM(nn.Module):
+    packed_modules_mapping = {"c_attn": ["c_attn"]}
+
+    supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
+
+    embedding_modules = {
+        "wte": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+
+    embedding_padding_modules = []
 
     def __init__(
         self,
         config: GPTBigCodeConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(config, cache_config, quant_config)
+        self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
+                                           lora_config)
         self.lm_head_weight = self.transformer.wte.weight
-        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
         self.sampler = Sampler()
 
     def forward(
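With these pieces in place, a GPTBigCode checkpoint can be served with LoRA adapters through vLLM's usual entry points. A minimal usage sketch; the model id and adapter path are placeholders, and exact flag names can vary between vLLM versions.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder model id and adapter path.
llm = LLM(model="bigcode/starcoder", enable_lora=True, max_loras=2)

lora = LoRARequest("example-adapter", 1, "/path/to/gptbigcode-lora")
outputs = llm.generate(["def fibonacci(n):"],
                       SamplingParams(temperature=0.0, max_tokens=64),
                       lora_request=lora)
print(outputs[0].outputs[0].text)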
