
Commit f6ab5fc

jeejeelee authored and lulmer committed
[Misc] Add Phi4-MM example (vllm-project#14343)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent df06bda commit f6ab5fc

4 files changed: +131 −7 lines changed


examples/offline_inference/audio_language.py

Lines changed: 38 additions & 0 deletions
@@ -6,10 +6,14 @@
 For most models, the prompt format should follow corresponding examples
 on HuggingFace model repository.
 """
+import os
+
+from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
 from vllm.assets.audio import AudioAsset
+from vllm.lora.request import LoRARequest
 from vllm.utils import FlexibleArgumentParser
 
 audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]

@@ -51,6 +55,39 @@ def run_minicpmo(question: str, audio_count: int):
     return llm, prompt, stop_token_ids
 
 
+# Phi-4-multimodal-instruct
+def run_phi4mm(questions: str, audio_count: int):
+    """
+    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
+    show how to process audio inputs.
+    """
+    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
+    # Since the vision-lora and speech-lora co-exist with the base model,
+    # we have to manually specify the path of the lora weights.
+    speech_lora_path = os.path.join(model_path, "speech-lora")
+    placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])
+
+    prompts = f"<|user|>{placeholders}{questions}<|end|><|assistant|>"
+
+    llm = LLM(
+        model=model_path,
+        trust_remote_code=True,
+        max_model_len=4096,
+        max_num_seqs=2,
+        enable_lora=True,
+        max_lora_rank=320,
+        lora_extra_vocab_size=0,
+    )
+    lora_request = LoRARequest("speech", 1, speech_lora_path)
+    # To maintain code compatibility in this script, we add LoRA here.
+    llm.llm_engine.add_lora(lora_request=lora_request)
+    # You can also add LoRA using:
+    # llm.generate(prompts, lora_request=lora_request,...)
+
+    stop_token_ids = None
+    return llm, prompts, stop_token_ids
+
+
 # Qwen2-Audio
 def run_qwen2_audio(question: str, audio_count: int):
     model_name = "Qwen/Qwen2-Audio-7B-Instruct"

@@ -113,6 +150,7 @@ def run_whisper(question: str, audio_count: int):
 
 model_example_map = {
     "minicpmo": run_minicpmo,
+    "phi4_mm": run_phi4mm,
     "qwen2_audio": run_qwen2_audio,
     "ultravox": run_ultravox,
     "whisper": run_whisper,

examples/offline_inference/vision_language.py

Lines changed: 38 additions & 0 deletions
@@ -6,13 +6,16 @@
 For most models, the prompt format should follow corresponding examples
 on HuggingFace model repository.
 """
+import os
 import random
 
+from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
+from vllm.lora.request import LoRARequest
 from vllm.utils import FlexibleArgumentParser
 
 # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on

@@ -519,6 +522,40 @@ def run_phi3v(questions: list[str], modality: str):
     return llm, prompts, stop_token_ids
 
 
+# Phi-4-multimodal-instruct
+def run_phi4mm(questions: list[str], modality: str):
+    """
+    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
+    show how to process image inputs.
+    """
+    assert modality == "image"
+    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
+    # Since the vision-lora and speech-lora co-exist with the base model,
+    # we have to manually specify the path of the lora weights.
+    vision_lora_path = os.path.join(model_path, "vision-lora")
+    prompts = [
+        f"<|user|><|image_1|>{question}<|end|><|assistant|>"
+        for question in questions
+    ]
+    llm = LLM(
+        model=model_path,
+        trust_remote_code=True,
+        max_model_len=4096,
+        max_num_seqs=2,
+        enable_lora=True,
+        max_lora_rank=320,
+        lora_extra_vocab_size=0,
+    )
+    lora_request = LoRARequest("vision", 1, vision_lora_path)
+    # To maintain code compatibility in this script, we add LoRA here.
+    llm.llm_engine.add_lora(lora_request=lora_request)
+    # You can also add LoRA using:
+    # llm.generate(prompts, lora_request=lora_request,...)
+
+    stop_token_ids = None
+    return llm, prompts, stop_token_ids
+
+
 # Pixtral HF-format
 def run_pixtral_hf(questions: list[str], modality: str):
     assert modality == "image"

@@ -644,6 +681,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str):
     "paligemma": run_paligemma,
     "paligemma2": run_paligemma2,
     "phi3_v": run_phi3v,
+    "phi4_mm": run_phi4mm,
     "pixtral_hf": run_pixtral_hf,
     "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,

examples/offline_inference/vision_language_multi_image.py

Lines changed: 44 additions & 0 deletions
@@ -4,13 +4,16 @@
 multi-image input on vision language models for text generation,
 using the chat template defined by the model.
 """
+import os
 from argparse import Namespace
 from typing import NamedTuple, Optional
 
+from huggingface_hub import snapshot_download
 from PIL.Image import Image
 from transformers import AutoProcessor, AutoTokenizer
 
 from vllm import LLM, SamplingParams
+from vllm.lora.request import LoRARequest
 from vllm.multimodal.utils import fetch_image
 from vllm.utils import FlexibleArgumentParser
 

@@ -294,6 +297,46 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
     )
 
 
+def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
+    """
+    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
+    show how to process multi-image inputs.
+    """
+
+    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
+    # Since the vision-lora and speech-lora co-exist with the base model,
+    # we have to manually specify the path of the lora weights.
+    vision_lora_path = os.path.join(model_path, "vision-lora")
+    llm = LLM(
+        model=model_path,
+        trust_remote_code=True,
+        max_model_len=10000,
+        max_num_seqs=2,
+        limit_mm_per_prompt={"image": len(image_urls)},
+        enable_lora=True,
+        max_lora_rank=320,
+        lora_extra_vocab_size=0,
+    )
+    lora_request = LoRARequest("vision", 1, vision_lora_path)
+    # To maintain code compatibility in this script, we add LoRA here.
+    llm.llm_engine.add_lora(lora_request=lora_request)
+    # You can also add LoRA using:
+    # llm.generate(prompts, lora_request=lora_request,...)
+
+    placeholders = "".join(f"<|image_{i}|>"
+                           for i, _ in enumerate(image_urls, start=1))
+    prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
+    stop_token_ids = None
+
+    return ModelRequestData(
+        llm=llm,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+        image_data=[fetch_image(url) for url in image_urls],
+        chat_template=None,
+    )
+
+
 def load_qwen_vl_chat(question: str,
                       image_urls: list[str]) -> ModelRequestData:
     model_name = "Qwen/Qwen-VL-Chat"

@@ -459,6 +502,7 @@ def load_qwen2_5_vl(question, image_urls: list[str]) -> ModelRequestData:
     "mllama": load_mllama,
     "NVLM_D": load_nvlm_d,
     "phi3_v": load_phi3v,
+    "phi4_mm": load_phi4mm,
     "pixtral_hf": load_pixtral_hf,
     "qwen_vl_chat": load_qwen_vl_chat,
     "qwen2_vl": load_qwen2_vl,

vllm/model_executor/models/phi4mm.py

Lines changed: 11 additions & 7 deletions
@@ -25,6 +25,7 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
 from vllm.model_executor.models.llama import LlamaModel
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalInputs, NestedTensors

@@ -1421,7 +1422,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
     """
     Implements the Phi-4-multimodal-instruct model in VLLM.
     """
-    # LoRA specific attributes
     packed_modules_mapping = {
         "qkv_proj": [
             "qkv_proj",

@@ -1430,12 +1430,6 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
             "gate_up_proj",
         ],
     }
-    supported_lora_modules = [
-        "qkv_proj", "o_proj", "gate_up_proj", "down_proj"
-    ]
-    # Phi4MMForCausalLM does not apply LoRA to the embedding layer.
-    embedding_modules = {}
-    embedding_padding_modules = []
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

@@ -1801,3 +1795,13 @@ def sample(
     ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="model.",
+            connector=["audio_projection_for_vision", "audio_projection"],
+            tower_model=["vision_encoder", "embed_tokens_extend"],
+        )
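
For context (an illustration, not part of this commit's diff): get_mm_mapping appears intended to tell vLLM's LoRA machinery which module prefixes belong to the language model, the multimodal connectors, and the encoder towers, so that LoRA can be restricted to the language-model weights. The mapping can be inspected standalone:

    from vllm.model_executor.models.module_mapping import MultiModelKeys

    # Same grouping as the new get_mm_mapping method above.
    mm_keys = MultiModelKeys.from_string_field(
        language_model="model.",
        connector=["audio_projection_for_vision", "audio_projection"],
        tower_model=["vision_encoder", "embed_tokens_extend"],
    )
    print(mm_keys.language_model, mm_keys.connector, mm_keys.tower_model)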
