
[Model] Add base class for LoRA-supported models #5018

Merged · 13 commits · Jun 27, 2024
3 changes: 3 additions & 0 deletions docs/source/models/lora.rst
@@ -4,6 +4,9 @@ Using LoRA adapters
===================

This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.

LoRA adapters can be used with any vLLM model that implements :class:`~vllm.model_executor.models.interfaces.SupportsLoRA`.
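
For a concrete picture, a per-request LoRA call through the Python API looks roughly like the sketch below (the model name, adapter name, and local adapter path are placeholders; the flags follow the existing LoRA documentation):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder base model and adapter path, for illustration only.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(
    ["Write a SQL query that lists all users."],
    sampling_params,
    # Adapter name, unique integer id, and local path of the adapter.
    lora_request=LoRARequest("sql_adapter", 1, "/path/to/sql_lora"),
)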

Adapters can be efficiently served on a per-request basis with minimal overhead. First we download the adapter(s) and save
them locally with

3 changes: 2 additions & 1 deletion vllm/lora/lora.py
@@ -2,6 +2,7 @@
from typing import Sequence as GenericSequence

import torch
import torch.types

from vllm.utils import is_pin_memory_available

@@ -64,7 +65,7 @@ def create_dummy_lora_weights(
output_dim: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
device: torch.types.Device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank],
6 changes: 3 additions & 3 deletions vllm/lora/models.py
@@ -18,6 +18,7 @@
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.utils import LRUCache, is_pin_memory_available

logger = init_logger(__name__)
@@ -363,7 +364,7 @@ class LoRAModelManager:

def __init__(
self,
model: nn.Module,
model: SupportsLoRA,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
@@ -411,7 +412,7 @@ def __init__(
# embeddings_indices
self.indices_len: List[Optional[int]] = [None] * 4

self.model: nn.Module = model
self.model = model
if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules)
@@ -428,7 +429,6 @@ def __init__(
self._active_loras: Dict[int, None] = {}
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self

@property
def capacity(self) -> int:
20 changes: 13 additions & 7 deletions vllm/model_executor/model_loader/loader.py
@@ -32,7 +32,8 @@
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_tpu

@@ -64,26 +65,31 @@ def _get_quantization_config(


def _get_model_initialization_kwargs(
model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
vlm_config: Optional[VisionLanguageConfig],
) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}
if hasattr(model_class, "supported_lora_modules"):

if supports_lora(model_class):
# lora_config=None is used to disable LoRA
extra_kwargs["lora_config"] = lora_config
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
elif issubclass(model_class, VisionLanguageModelBase):
if vision_language_config is None:

if supports_vision(model_class):
if vlm_config is None:
raise ValueError("Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments.")

extra_kwargs["vision_language_config"] = vision_language_config
extra_kwargs["vlm_config"] = vlm_config

return extra_kwargs
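
To illustrate how the loader can consume these kwargs, here is a minimal sketch; the wrapper function name and its parameters are assumptions for illustration, not the exact vLLM loader code:

def _initialize_model_sketch(model_class, hf_config, cache_config,
                             quant_config, lora_config, vlm_config):
    # Collect the optional kwargs this model class actually accepts.
    extra_kwargs = _get_model_initialization_kwargs(
        model_class, lora_config, vlm_config)
    # LoRA-capable models receive lora_config; VLMs receive vlm_config.
    return model_class(config=hf_config,
                       cache_config=cache_config,
                       quant_config=quant_config,
                       **extra_kwargs)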


11 changes: 9 additions & 2 deletions vllm/model_executor/models/baichuan.py
@@ -45,6 +45,8 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

from .interfaces import SupportsLoRA


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
@@ -292,7 +294,9 @@ def forward(
return hidden_states


class BaiChuanBaseForCausalLM(nn.Module):
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True

packed_modules_mapping = {
"W_pack": ["W_pack"],
"gate_up_proj": [
@@ -312,14 +316,17 @@ class BaiChuanBaseForCausalLM(nn.Module):

def __init__(
self,
config,
config: PretrainedConfig,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()

self.config = config
self.lora_config = lora_config

self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, cache_config,
quant_config)
11 changes: 9 additions & 2 deletions vllm/model_executor/models/chatglm.py
@@ -28,6 +28,8 @@
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig

from .interfaces import SupportsLoRA


class GLMAttention(nn.Module):

@@ -322,7 +324,9 @@ def forward(
return hidden_states


class ChatGLMForCausalLM(nn.Module):
class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True

packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
@@ -345,7 +349,10 @@ def __init__(
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config: ChatGLMConfig = config

self.config = config
self.lora_config = lora_config

self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
4 changes: 2 additions & 2 deletions vllm/model_executor/models/decilm.py
@@ -26,7 +26,7 @@
from typing import Iterable, Optional, Tuple

import torch
from transformers import PretrainedConfig
from transformers import LlamaConfig

from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.quantization.base_config import (
@@ -55,7 +55,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):

def __init__(
self,
config: Optional[PretrainedConfig] = None,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
10 changes: 8 additions & 2 deletions vllm/model_executor/models/gemma.py
@@ -41,6 +41,8 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

from .interfaces import SupportsLoRA

logger = init_logger(__name__)


@@ -288,7 +290,9 @@ def forward(
return hidden_states


class GemmaForCausalLM(nn.Module):
class GemmaForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -319,9 +323,11 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config # Unused.
super().__init__()

self.config = config
self.lora_config = lora_config

self.quant_config = quant_config
self.model = GemmaModel(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
9 changes: 8 additions & 1 deletion vllm/model_executor/models/gpt_bigcode.py
@@ -41,6 +41,8 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

from .interfaces import SupportsLoRA


class GPTBigCodeAttention(nn.Module):

@@ -230,7 +232,9 @@ def forward(
return hidden_states


class GPTBigCodeForCausalLM(nn.Module):
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True

packed_modules_mapping = {"c_attn": ["c_attn"]}

supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
@@ -250,7 +254,10 @@ def __init__(
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()

self.config = config
self.lora_config = lora_config

self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config)
130 changes: 130 additions & 0 deletions vllm/model_executor/models/interfaces.py
@@ -0,0 +1,130 @@
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
Union, overload, runtime_checkable)

from typing_extensions import TypeGuard

from vllm.config import LoRAConfig, VisionLanguageConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


@runtime_checkable
class SupportsVision(Protocol):
"""The interface required for all vision language models (VLMs)."""

supports_vision: ClassVar[Literal[True]]

def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsVisionType(Protocol):
supports_vision: Literal[True]

def __call__(self, *, vlm_config: VisionLanguageConfig) -> None:
...


@overload
def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
...


@overload
def supports_vision(model: object) -> TypeGuard[SupportsVision]:
...


def supports_vision(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
if isinstance(model, type):
return isinstance(model, _SupportsVisionType)

return isinstance(model, SupportsVision)


@runtime_checkable
class SupportsLoRA(Protocol):
"""The interface required for all models that support LoRA."""

supports_lora: ClassVar[Literal[True]]

packed_modules_mapping: ClassVar[Dict[str, List[str]]]
supported_lora_modules: ClassVar[List[str]]
embedding_modules: ClassVar[Dict[str, str]]
embedding_padding_modules: ClassVar[List[str]]

# lora_config is None when LoRA is not enabled
def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
supports_lora: Literal[True]

packed_modules_mapping: Dict[str, List[str]]
supported_lora_modules: List[str]
embedding_modules: Dict[str, str]
embedding_padding_modules: List[str]

def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
...


@overload
def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
...


@overload
def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
...


def supports_lora(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
result = _supports_lora(model)

if not result:
lora_attrs = (
"packed_modules_mapping",
"supported_lora_modules",
"embedding_modules",
"embedding_padding_modules",
)
missing_attrs = tuple(attr for attr in lora_attrs
if not hasattr(model, attr))

if getattr(model, "supports_lora", False):
if missing_attrs:
logger.warning(
"The model (%s) sets `supports_lora=True`, "
"but is missing LoRA-specific attributes: %s",
model,
missing_attrs,
)
else:
if not missing_attrs:
logger.warning(
"The model (%s) contains all LoRA-specific attributes, "
"but does not set `supports_lora=True`.", model)

return result


def _supports_lora(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
if isinstance(model, type):
return isinstance(model, _SupportsLoRAType)

return isinstance(model, SupportsLoRA)
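
For reference, a minimal sketch of how a model class opts into this interface and how the runtime checks behave; ToyForCausalLM is a made-up class mirroring the pattern used by the models updated in this PR:

import torch.nn as nn

from vllm.model_executor.models.interfaces import SupportsLoRA, supports_lora


class ToyForCausalLM(nn.Module, SupportsLoRA):
    # Marker plus the LoRA-specific class attributes the protocol expects.
    supports_lora = True

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    supported_lora_modules = ["qkv_proj", "o_proj", "lm_head"]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, lora_config=None):
        super().__init__()
        # lora_config is None when LoRA is disabled.
        self.lora_config = lora_config


# The helper works on the class (loader-time check) and on an instance.
assert supports_lora(ToyForCausalLM)
assert supports_lora(ToyForCausalLM())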