
Commit 3fc41f4

DarkLight1337 authored and jimpang committed

[Model] Add base class for LoRA-supported models (vllm-project#5018)

1 parent adf1efe commit 3fc41f4

File tree

20 files changed: +270 -75 lines changed

docs/source/models/lora.rst

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,9 @@ Using LoRA adapters
 ===================
 
 This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.
+
+LoRA adapters can be used with any vLLM model that implements :class:`~vllm.model_executor.models.interfaces.SupportsLoRA`.
+
 Adapters can be efficiently served on a per request basis with minimal overhead. First we download the adapter(s) and save
 them locally with
 
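For context, a minimal usage sketch of the per-request LoRA flow that lora.rst describes, assuming the LLM/LoRARequest entrypoints documented there; the model name, adapter name, integer ID, and path are placeholders and should be checked against the vLLM version in use:

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # The base model class must implement SupportsLoRA for enable_lora=True to work.
    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

    sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

    # Placeholder adapter name, integer ID, and local path.
    outputs = llm.generate(
        ["Write a SQL query that lists all users"],
        sampling_params,
        lora_request=LoRARequest("sql_adapter", 1, "/path/to/sql_adapter"),
    )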

vllm/lora/lora.py

Lines changed: 2 additions & 1 deletion
@@ -2,6 +2,7 @@
 from typing import Sequence as GenericSequence
 
 import torch
+import torch.types
 
 from vllm.utils import is_pin_memory_available
 
@@ -64,7 +65,7 @@ def create_dummy_lora_weights(
             output_dim: int,
             rank: int,
             dtype: torch.dtype,
-            device: torch.device,
+            device: torch.types.Device,
             embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
         lora_a = torch.zeros([input_dim, rank],

vllm/lora/models.py

Lines changed: 3 additions & 3 deletions
@@ -18,6 +18,7 @@
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              parse_fine_tuned_lora_name, replace_submodule)
+from vllm.model_executor.models.interfaces import SupportsLoRA
 from vllm.utils import LRUCache, is_pin_memory_available
 
 logger = init_logger(__name__)
@@ -363,7 +364,7 @@ class LoRAModelManager:
 
     def __init__(
         self,
-        model: nn.Module,
+        model: SupportsLoRA,
         max_num_seqs: int,
         max_num_batched_tokens: int,
         vocab_size: int,
@@ -411,7 +412,7 @@ def __init__(
         # embeddings_indices
         self.indices_len: List[Optional[int]] = [None] * 4
 
-        self.model: nn.Module = model
+        self.model = model
         if hasattr(self.model, "supported_lora_modules"):
             self.supported_lora_modules = copy.deepcopy(
                 self.model.supported_lora_modules)
@@ -428,7 +429,6 @@ def __init__(
         self._active_loras: Dict[int, None] = {}
         self._last_mapping: Optional[LoRAMapping] = None
         self._create_lora_modules()
-        self.model.lora_manager = self
 
     @property
     def capacity(self) -> int:
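A hedged note on what the tighter `SupportsLoRA` annotation buys over `nn.Module`; the helper below is illustrative only and not part of this commit:

    from vllm.lora.models import LoRAModelManager


    def dump_lora_targets(manager: LoRAModelManager) -> None:
        # With the manager's model typed as SupportsLoRA, these attributes are
        # visible to static type checkers instead of relying solely on
        # hasattr() checks at runtime.
        print(manager.model.supported_lora_modules)
        print(manager.model.packed_modules_mapping)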

vllm/model_executor/model_loader/loader.py

Lines changed: 13 additions & 7 deletions
@@ -32,7 +32,8 @@
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
     get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
     pt_weights_iterator, safetensors_weights_iterator)
-from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
+from vllm.model_executor.models.interfaces import (supports_lora,
+                                                   supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import is_tpu
 
@@ -64,26 +65,31 @@ def _get_quantization_config(
 
 
 def _get_model_initialization_kwargs(
-        model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig]
+        model_class: Type[nn.Module],
+        lora_config: Optional[LoRAConfig],
+        vlm_config: Optional[VisionLanguageConfig],
 ) -> Dict[str, Any]:
     """Get extra kwargs for model initialization."""
     extra_kwargs: Dict[str, Any] = {}
-    if hasattr(model_class, "supported_lora_modules"):
+
+    if supports_lora(model_class):
+        # lora_config=None is used to disable LoRA
         extra_kwargs["lora_config"] = lora_config
     elif lora_config:
         raise ValueError(
             f"Model {model_class.__name__} does not support LoRA, "
             "but LoRA is enabled. Support for this model may "
             "be added in the future. If this is important to you, "
             "please open an issue on github.")
-    elif issubclass(model_class, VisionLanguageModelBase):
-        if vision_language_config is None:
+
+    if supports_vision(model_class):
+        if vlm_config is None:
             raise ValueError("Provide `image_input_type` and other vision "
                              "related configurations through LLM entrypoint "
                              "or engine arguments.")
 
-        extra_kwargs["vision_language_config"] = vision_language_config
+        extra_kwargs["vlm_config"] = vlm_config
+
     return extra_kwargs
 
 
vllm/model_executor/models/baichuan.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
from vllm.model_executor.sampling_metadata import SamplingMetadata
4646
from vllm.sequence import SamplerOutput
4747

48+
from .interfaces import SupportsLoRA
49+
4850

4951
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
5052
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
@@ -292,7 +294,9 @@ def forward(
292294
return hidden_states
293295

294296

295-
class BaiChuanBaseForCausalLM(nn.Module):
297+
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
298+
supports_lora = True
299+
296300
packed_modules_mapping = {
297301
"W_pack": ["W_pack"],
298302
"gate_up_proj": [
@@ -312,14 +316,17 @@ class BaiChuanBaseForCausalLM(nn.Module):
312316

313317
def __init__(
314318
self,
315-
config,
319+
config: PretrainedConfig,
316320
position_embedding: str,
317321
cache_config: Optional[CacheConfig] = None,
318322
quant_config: Optional[QuantizationConfig] = None,
319323
lora_config: Optional[LoRAConfig] = None,
320324
):
321325
super().__init__()
326+
322327
self.config = config
328+
self.lora_config = lora_config
329+
323330
self.quant_config = quant_config
324331
self.model = BaiChuanModel(config, position_embedding, cache_config,
325332
quant_config)
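The Baichuan change above is representative of the pattern repeated in the model files that follow. As a hedged, minimal sketch (the class name and module lists are made up, not part of this commit), a model opts in by mixing in SupportsLoRA, setting the flag, declaring the four LoRA attribute maps/lists, and accepting lora_config in its constructor:

    from typing import Any, Optional

    import torch.nn as nn

    from vllm.config import LoRAConfig
    from vllm.model_executor.models.interfaces import SupportsLoRA


    class ToyForCausalLM(nn.Module, SupportsLoRA):
        supports_lora = True

        # Illustrative values; real models list their own fused/target modules.
        packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
        supported_lora_modules = ["qkv_proj", "o_proj"]
        embedding_modules = {}
        embedding_padding_modules = []

        def __init__(self, config: Any, *,
                     lora_config: Optional[LoRAConfig] = None) -> None:
            super().__init__()
            self.config = config
            # lora_config is None when LoRA is disabled.
            self.lora_config = lora_config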

vllm/model_executor/models/chatglm.py

Lines changed: 9 additions & 2 deletions
@@ -28,6 +28,8 @@
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
+from .interfaces import SupportsLoRA
+
 
 class GLMAttention(nn.Module):
 
@@ -322,7 +324,9 @@ def forward(
         return hidden_states
 
 
-class ChatGLMForCausalLM(nn.Module):
+class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
         "dense_h_to_4h": ["dense_h_to_4h"]
@@ -345,7 +349,10 @@ def __init__(
         lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
-        self.config: ChatGLMConfig = config
+
+        self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.max_position_embeddings = getattr(config, "max_sequence_length",
                                                8192)

vllm/model_executor/models/decilm.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@
 from typing import Iterable, Optional, Tuple
 
 import torch
-from transformers import PretrainedConfig
+from transformers import LlamaConfig
 
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.model_executor.layers.quantization.base_config import (
@@ -55,7 +55,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
 
     def __init__(
         self,
-        config: Optional[PretrainedConfig] = None,
+        config: LlamaConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,

vllm/model_executor/models/gemma.py

Lines changed: 8 additions & 2 deletions
@@ -41,6 +41,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 
+from .interfaces import SupportsLoRA
+
 logger = init_logger(__name__)
 
 
@@ -288,7 +290,9 @@ def forward(
         return hidden_states
 
 
-class GemmaForCausalLM(nn.Module):
+class GemmaForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -319,9 +323,11 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
-        del lora_config  # Unused.
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.model = GemmaModel(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)

vllm/model_executor/models/gpt_bigcode.py

Lines changed: 8 additions & 1 deletion
@@ -41,6 +41,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 
+from .interfaces import SupportsLoRA
+
 
 class GPTBigCodeAttention(nn.Module):
 
@@ -230,7 +232,9 @@ def forward(
         return hidden_states
 
 
-class GPTBigCodeForCausalLM(nn.Module):
+class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {"c_attn": ["c_attn"]}
 
     supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
@@ -250,7 +254,10 @@ def __init__(
         lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
                                            lora_config)

vllm/model_executor/models/interfaces.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
+                    Union, overload, runtime_checkable)
+
+from typing_extensions import TypeGuard
+
+from vllm.config import LoRAConfig, VisionLanguageConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@runtime_checkable
+class SupportsVision(Protocol):
+    """The interface required for all vision language models (VLMs)."""
+
+    supports_vision: ClassVar[Literal[True]]
+
+    def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
+        ...
+
+
+# We can't use runtime_checkable with ClassVar for issubclass checks
+# so we need to treat the class as an instance and use isinstance instead
+@runtime_checkable
+class _SupportsVisionType(Protocol):
+    supports_vision: Literal[True]
+
+    def __call__(self, *, vlm_config: VisionLanguageConfig) -> None:
+        ...
+
+
+@overload
+def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
+    ...
+
+
+@overload
+def supports_vision(model: object) -> TypeGuard[SupportsVision]:
+    ...
+
+
+def supports_vision(
+    model: Union[Type[object], object],
+) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
+    if isinstance(model, type):
+        return isinstance(model, _SupportsVisionType)
+
+    return isinstance(model, SupportsVision)
+
+
+@runtime_checkable
+class SupportsLoRA(Protocol):
+    """The interface required for all models that support LoRA."""
+
+    supports_lora: ClassVar[Literal[True]]
+
+    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
+    supported_lora_modules: ClassVar[List[str]]
+    embedding_modules: ClassVar[Dict[str, str]]
+    embedding_padding_modules: ClassVar[List[str]]
+
+    # lora_config is None when LoRA is not enabled
+    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
+        ...
+
+
+# We can't use runtime_checkable with ClassVar for issubclass checks
+# so we need to treat the class as an instance and use isinstance instead
+@runtime_checkable
+class _SupportsLoRAType(Protocol):
+    supports_lora: Literal[True]
+
+    packed_modules_mapping: Dict[str, List[str]]
+    supported_lora_modules: List[str]
+    embedding_modules: Dict[str, str]
+    embedding_padding_modules: List[str]
+
+    def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
+        ...
+
+
+@overload
+def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
+    ...
+
+
+@overload
+def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
+    ...
+
+
+def supports_lora(
+    model: Union[Type[object], object],
+) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
+    result = _supports_lora(model)
+
+    if not result:
+        lora_attrs = (
+            "packed_modules_mapping",
+            "supported_lora_modules",
+            "embedding_modules",
+            "embedding_padding_modules",
+        )
+        missing_attrs = tuple(attr for attr in lora_attrs
+                              if not hasattr(model, attr))
+
+        if getattr(model, "supports_lora", False):
+            if missing_attrs:
+                logger.warning(
+                    "The model (%s) sets `supports_lora=True`, "
+                    "but is missing LoRA-specific attributes: %s",
+                    model,
+                    missing_attrs,
+                )
+        else:
+            if not missing_attrs:
+                logger.warning(
+                    "The model (%s) contains all LoRA-specific attributes, "
+                    "but does not set `supports_lora=True`.", model)
+
+    return result
+
+
+def _supports_lora(
+    model: Union[Type[object], object],
+) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
+    if isinstance(model, type):
+        return isinstance(model, _SupportsLoRAType)
+
+    return isinstance(model, SupportsLoRA)
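A hedged usage sketch of the two helpers defined above; IncompleteModel is an illustrative stand-in, not a real vLLM model:

    from vllm.model_executor.models.interfaces import (supports_lora,
                                                       supports_vision)


    class IncompleteModel:
        # Flag is set, but the four LoRA attribute declarations are missing.
        supports_lora = True


    # Passing a class checks the type itself (via the private *_Type protocols);
    # passing an instance checks the object against SupportsLoRA/SupportsVision.
    print(supports_lora(IncompleteModel))    # False; also logs a warning about
                                             # the missing LoRA-specific attributes
    print(supports_vision(IncompleteModel))  # False: no supports_vision flag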

0 commit comments