[VLM] Add TP support for Phi-4-MM #14453

Merged · 8 commits · Mar 8, 2025
1 change: 1 addition & 0 deletions examples/offline_inference/audio_language.py
@@ -77,6 +77,7 @@ def run_phi4mm(questions: str, audio_count: int):
enable_lora=True,
max_lora_rank=320,
lora_extra_vocab_size=0,
limit_mm_per_prompt={"audio": audio_count},
)
lora_request = LoRARequest("speech", 1, speech_lora_path)
# To maintain code compatibility in this script, we add LoRA here.
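For reference, a minimal sketch of how the updated example can be driven now that the tensor-parallel restriction on Phi-4-MM is lifted. The model id, `speech_lora_path`, the prompt template, and the dummy waveform below are placeholders and assumptions, not part of this PR; the LoRA and `limit_mm_per_prompt` arguments mirror the example above.

import numpy as np

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

audio_count = 1
model_path = "microsoft/Phi-4-multimodal-instruct"  # assumed HF repo id
speech_lora_path = f"{model_path}/speech-lora"      # placeholder LoRA path

llm = LLM(
    model=model_path,
    trust_remote_code=True,
    tensor_parallel_size=2,  # TP now works for Phi-4-MM with this PR
    enable_lora=True,
    max_lora_rank=320,
    lora_extra_vocab_size=0,
    limit_mm_per_prompt={"audio": audio_count},  # kwarg added above
)
lora_request = LoRARequest("speech", 1, speech_lora_path)

# One second of silence stands in for a real recording.
prompt = {
    "prompt": "<|user|><|audio_1|>Transcribe the audio clip.<|end|><|assistant|>",
    "multi_modal_data": {"audio": [(np.zeros(16000, dtype=np.float32), 16000)]},
}
outputs = llm.generate(prompt,
                       SamplingParams(max_tokens=64),
                       lora_request=lora_request)
print(outputs[0].outputs[0].text)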
73 changes: 24 additions & 49 deletions vllm/model_executor/models/phi4mm.py
@@ -15,7 +15,7 @@
from transformers.utils import logging

from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.inputs.data import TokenInputs, token_inputs
@@ -34,7 +34,7 @@

from .interfaces import SupportsLoRA, SupportsMultiModal
from .phi4mm_audio import AudioEmbedding
from .utils import maybe_prefix
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
from .vision_siglip_navit import get_siglip_vision_model

# <|endoftext10|> (see vocab.json in hf model)
@@ -352,12 +352,6 @@ def __init__(self,
# n_embed or hidden_size
hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size
if hasattr(config, 'embd_pdrop') or hasattr(config, 'embed_pdrop'):
embd_drop = config.embd_pdrop if hasattr(
config, 'embd_pdrop') else config.embed_pdrop
self.drop = nn.Dropout(embd_drop)
else:
self.drop = None

# layer_idx to output the img features
if isinstance(config.img_processor, dict):
@@ -1431,6 +1425,20 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
],
}

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
"base_layer.": "",
},
orig_to_new_prefix={
"model.embed_tokens_extend.audio_embed.audio_projection.vision.":
"embed_tokens_extend.audio_projection_for_vision.",
"model.embed_tokens_extend.audio_embed.audio_projection.speech.":
"embed_tokens_extend.audio_projection.",
"model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
"model.embed_tokens_extend.image_embed.": "vision_encoder.",
},
)

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -1445,8 +1453,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.lora_config = lora_config

# Tensor/Pipeline parallel not supported for now.
assert get_tensor_model_parallel_world_size(
) == 1, "tensor parallel is not supported"
assert get_pp_group(
).world_size == 1, "pipeline parallel is not supported"

@@ -1686,44 +1692,6 @@ def merge_image_features_to_inputs_embeds(
)
return merged_embeds

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> None:
weights = {name: weight for name, weight in weights}
adjusted_weights = {}

for name, weight in weights.items():
# NOTE vision-speech tasks use a separate projection layer
audio_proj_4v = \
"model.embed_tokens_extend.audio_embed.audio_projection.vision"
if name.startswith(audio_proj_4v):
name = name.replace(
audio_proj_4v,
"embed_tokens_extend.audio_projection_for_vision")

name = (name.replace(
"model.embed_tokens_extend.audio_embed."\
"audio_projection.speech.",
"embed_tokens_extend.audio_projection.",
).replace(
"model.embed_tokens_extend.audio_embed.",
"embed_tokens_extend.",
).replace("model.embed_tokens_extend.image_embed.",
"vision_encoder."))
# NOTE: this is deal with LoRA injection, where `base_layer`
# remains as the original layer in the model
if name.endswith(".base_layer.weight"):
name = name.replace(".base_layer.weight", ".weight")
adjusted_weights[name] = weight

missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
strict=False)
logger.debug("*** missing keys:")
for key in missing_keys:
logger.debug(key)
logger.debug("**** unexpected keys:")
for key in unexpected_keys:
logger.debug(key)

def forward(
self,
input_ids: torch.Tensor,
@@ -1796,6 +1764,13 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> None:
weights = ((name, data) for name, data in weights
if "lora" not in name)
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
@@ -1804,4 +1779,4 @@ def get_mm_mapping(self) -> MultiModelKeys:
language_model="model.",
connector=["audio_projection_for_vision", "audio_projection"],
tower_model=["vision_encoder", "embed_tokens_extend"],
)
)
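The hand-rolled renaming in the removed load_weights is now expressed declaratively: the new load_weights filters out names containing "lora" and lets AutoWeightsLoader apply hf_to_vllm_mapper to every checkpoint key. Below is a rough standalone illustration of the renames that mapper encodes; it is not vLLM internals, and the example key in the assert is made up.

PREFIX_MAP = {
    "model.embed_tokens_extend.audio_embed.audio_projection.vision.":
        "embed_tokens_extend.audio_projection_for_vision.",
    "model.embed_tokens_extend.audio_embed.audio_projection.speech.":
        "embed_tokens_extend.audio_projection.",
    "model.embed_tokens_extend.audio_embed.": "embed_tokens_extend.",
    "model.embed_tokens_extend.image_embed.": "vision_encoder.",
}

def remap(name: str) -> str:
    # Drop LoRA's wrapper so "base_layer.weight" loads into "weight".
    name = name.replace("base_layer.", "")
    # Check the more specific prefixes first, as ordered above.
    for old, new in PREFIX_MAP.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

# Made-up checkpoint key, shown only to illustrate the mapping direction.
assert remap(
    "model.embed_tokens_extend.audio_embed.audio_projection.speech."
    "linear.base_layer.weight"
) == "embed_tokens_extend.audio_projection.linear.weight"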