Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model][LoRA]LoRA support added for MiniCPMV2.5 #7199

Merged
merged 25 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add unit test for minicpmv25
  • Loading branch information
jeejeelee committed Sep 27, 2024
commit 99dacdf223d3148ebe4169984012fffc02dd953e
5 changes: 5 additions & 0 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ def baichuan_zero_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")


@pytest.fixture(scope="session")
def minicpmv_lora_files():
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
Expand Down
99 changes: 99 additions & 0 deletions tests/lora/test_minicpmv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import List

import pytest

import vllm
from vllm.lora.request import LoRARequest
from vllm.assets.image import ImageAsset


MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

PROMPT_TEMPLATE = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
"(<image>./</image>)\nWhat is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)

IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]


# After fine-tuning with LoRA, all generated content should start begin `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=256,
stop_token_ids=[128001, 128009], # eos_id, eot_id
)

inputs = [
{
"prompt": PROMPT_TEMPLATE,
"multi_modal_data": {"image": asset.pil_image},
}
for asset in IMAGE_ASSETS
]

outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id
else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True,
)

output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert output1[i] == EXPECTED_OUTPUT[i]
output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
for i in range(len(EXPECTED_OUTPUT)):
assert output2[i] == EXPECTED_OUTPUT[i]


# @pytest.mark.skip("Requires multiple GPUs")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use multi_gpu_test decorator for this. I think LoRA tests currently run on a single GPU though, so you might have to explicitly add this to the distributed tests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After adding this decorator, this test cannot pass locally.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What error do you get?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it the CUDA re-initialization error? Maybe need to move this out to a separate file like how it is with models tests...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it the CUDA re-initialization error? Maybe need to move this out to a separate file like how it is with models tests...

Yep. I will handling it asap

@pytest.mark.parametrize("fully_sharded", [True, False])
@pytest.mark.parametrize("tp", [2, 4])
def test_minicpmv_tensor_parallel(minicpmv_lora_files, fully_sharded, tp):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=tp,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
)
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)

for i in range(len(EXPECTED_OUTPUT)):
assert output_tp[i] == EXPECTED_OUTPUT[i]

11 changes: 8 additions & 3 deletions vllm/model_executor/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,9 +904,14 @@ def init_llm(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> nn.Module:
return Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config)
# return Qwen2Model(config,
# cache_config=cache_config,
# quant_config=quant_config)

return LLMWrapper(Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config),
name="model")

def init_vision_module(self) -> nn.Module:
# A custom version of SiglipVisionTransformer, won't work with TP
Expand Down
4 changes: 3 additions & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,9 @@ def load_model(self) -> None:
self.model_memory_usage / float(2**30))

if self.lora_config:
assert supports_lora(self.model), "Model does not support LoRA"
assert supports_lora(
self.model
), f"{self.model.__class__.__name__} does not support LoRA yet."
if supports_multimodal(self.model):
logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
Expand Down
Loading