[Model][LoRA]LoRA support added for MiniCPMV2.5 #7199

Merged
25 commits merged on Sep 29, 2024
Changes from 1 commit
Modify code
jeejeelee committed Sep 27, 2024
commit bf4ee9d6d13c9fa876ec39e5b09d35bde712660b
4 changes: 3 additions & 1 deletion tests/lora/test_minicpmv.py
@@ -6,6 +6,8 @@
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test

MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

PROMPT_TEMPLATE = (
@@ -73,7 +75,7 @@ def test_minicpmv_lora(minicpmv_lora_files):
assert output2[i] == EXPECTED_OUTPUT[i]


# @pytest.mark.skip("Requires multiple GPUs")
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("fully_sharded", [True, False])
@pytest.mark.parametrize("tp", [2, 4])
def test_minicpmv_tensor_parallel(minicpmv_lora_files, fully_sharded, tp):
6 changes: 3 additions & 3 deletions vllm/lora/models.py
@@ -443,7 +443,7 @@ def _create_lora_modules(self):
continue
# A temporary approach for multimodal models to support LoRA
# TODO: Remove this restriction
if self._filter_unsupported_module(module_name):
if self._filter_unsupported_mm_module(module_name):
logger.warning(
"Regarding multimodal models, vLLM currently only supports "
"adding LoRA to language model, %s will be ignored.",
@@ -501,7 +501,7 @@ def create_dummy_lora(
if (not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA)
or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
or self._filter_unsupported_module(module_name)):
or self._filter_unsupported_mm_module(module_name)):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
@@ -562,7 +562,7 @@ def _match_target_modules(self, module_name: str):
module_name) or target_module == module_name
for target_module in self.supported_lora_modules)

def _filter_unsupported_module(self, module_name: str) -> bool:
def _filter_unsupported_mm_module(self, module_name: str) -> bool:
"""
Regarding multimodal models, vLLM currently only supports adding LoRA to
language model. LoRA for other modules, such as the vision tower, will
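The rename from _filter_unsupported_module to _filter_unsupported_mm_module makes the intent explicit: the helper only filters modules of multimodal models that lie outside the language model (for example the vision tower), so their LoRA weights are skipped and a warning is logged. A minimal standalone sketch of that idea, using a hypothetical prefix list rather than vLLM's actual module-mapping lookup:

def _filter_unsupported_mm_module(module_name: str) -> bool:
    """Return True when `module_name` belongs to a multimodal model but
    not to its language model, i.e. its LoRA weights should be ignored."""
    # Hypothetical prefixes for illustration only; the real helper derives
    # the language-model modules from the model's module mapping.
    non_language_model_prefixes = ("vpm.", "resampler.")
    return module_name.startswith(non_language_model_prefixes)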
22 changes: 1 addition & 21 deletions vllm/model_executor/models/minicpmv.py
@@ -52,6 +52,7 @@
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.utils import LLMWrapper
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
@@ -390,24 +391,6 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
return MultiModalInputs(batch_data)


class LLMWrapper(nn.Module):
"""
To align with the key names of LoRA trained with PEFT, we need to add an
additional layer to the llm's implementation.
"""

def __init__(self, llm: nn.Module, name: str) -> None:
super().__init__()
self.model_name = name
setattr(self, name, llm)

def forward(self, *args, **kwargs) -> Any:
return getattr(self, self.model_name)(*args, **kwargs)

def embed_tokens(self, *args, **kwargs) -> Any:
return getattr(self, self.model_name).embed_tokens(*args, **kwargs)


class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
@@ -904,9 +887,6 @@ def init_llm(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> nn.Module:
# return Qwen2Model(config,
# cache_config=cache_config,
# quant_config=quant_config)

return LLMWrapper(Qwen2Model(config,
cache_config=cache_config,
22 changes: 20 additions & 2 deletions vllm/model_executor/models/utils.py
@@ -1,7 +1,7 @@
import itertools
from collections import UserDict
from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
Union, overload)
from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol,
Tuple, Union, overload)

import torch
import torch.nn as nn
@@ -329,3 +329,21 @@ def make_empty_intermediate_tensors(
})

return make_empty_intermediate_tensors


class LLMWrapper(nn.Module):
"""
To align with the key names of LoRA trained with PEFT, we need to add an
additional layer to the llm's implementation.
"""

def __init__(self, llm: nn.Module, name: str) -> None:
super().__init__()
self.model_name = name
setattr(self, name, llm)

def forward(self, *args, **kwargs) -> Any:
return getattr(self, self.model_name)(*args, **kwargs)

def embed_tokens(self, *args, **kwargs) -> Any:
return getattr(self, self.model_name).embed_tokens(*args, **kwargs)
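As a quick illustration of why LLMWrapper exists: nesting the language model under an extra named attribute adds a prefix to every parameter key, so the state-dict names line up with what PEFT produces for the LoRA adapter. A minimal sketch using a toy stand-in module (ToyLM is hypothetical; the real caller wraps Qwen2Model):

import torch.nn as nn

from vllm.model_executor.models.utils import LLMWrapper  # path as added in this commit

class ToyLM(nn.Module):
    """Toy stand-in for the real language model."""

    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(8, 4)

    def forward(self, input_ids):
        return self.embed_tokens(input_ids)

wrapped = LLMWrapper(ToyLM(), name="llm")

# Parameter keys now carry the extra "llm." prefix, e.g.
# "llm.embed_tokens.weight", matching the key layout expected by
# LoRA adapters trained with PEFT.
print([name for name, _ in wrapped.named_parameters()])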