[Model] Update Paligemma multimodal processing with PromptUpdate #14015

Merged: 15 commits, Mar 6, 2025
Changes from all commits
docs/source/models/supported_models.md: 3 additions & 3 deletions
@@ -842,13 +842,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
-- * `PaliGemmaForConditionalGeneration`\*
-  * PaliGemma, PaliGemma 2
+- * `PaliGemmaForConditionalGeneration`
+  * PaliGemma (see note), PaliGemma 2 (see note)
* T + I<sup>E</sup>
* `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc.
*
* ✅︎
*
* ✅︎
- * `Phi3VForCausalLM`
* Phi-3-Vision, Phi-3.5-Vision
* T + I<sup>E+</sup>
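As a usage note for the table entry above, the following is a minimal offline-inference sketch, assuming access to the gated `google/paligemma-3b-mix-224` checkpoint and a local `image.jpg` (both placeholders, not part of this PR):

```python
from PIL import Image

from vllm import LLM

# Hypothetical local image file; any RGB image works.
image = Image.open("image.jpg").convert("RGB")

llm = LLM(model="google/paligemma-3b-mix-224")

# PaliGemma takes a bare task prompt; the processor inserts the image
# tokens, <bos> and trailing newline handled by this PR.
outputs = llm.generate({
    "prompt": "caption en",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)
```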
tests/models/decoder_only/vision_language/test_models.py: 2 additions & 3 deletions
@@ -116,9 +116,8 @@
"pixel_values"
),
vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
dtype=("half" if current_platform.is_cpu() or current_platform.is_rocm()
else ("half", "float")),
marks=[pytest.mark.core_model],
dtype="bfloat16",
marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")], # noqa: E501
),
# TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
# once we upgraded to transformers>=4.49.0.
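The skip reason refers to PaliGemma's PrefixLM-style attention: the image tokens and the text prefix attend to each other bidirectionally, while generated tokens remain causal. A rough illustrative sketch of such a mask (not vLLM code; the lengths are made up):

```python
import torch


def prefix_lm_mask(prefix_len: int, total_len: int) -> torch.Tensor:
    """Boolean attention mask where True means 'may attend'."""
    # Start from a causal (lower-triangular) mask...
    mask = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
    # ...then let the prefix (image tokens + prompt) attend bidirectionally.
    mask[:prefix_len, :prefix_len] = True
    return mask


# e.g. a 6-token prefix followed by 4 generated tokens
print(prefix_lm_mask(prefix_len=6, total_len=10))
```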
tests/models/multimodal/processing/test_common.py: 2 additions & 0 deletions
@@ -175,6 +175,8 @@ def _test_processing_correctness(
"Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_4",
"openai/whisper-large-v3",
"google/paligemma-3b-mix-224",
"google/paligemma2-3b-ft-docci-448",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
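These two checkpoints are compared against the HuggingFace processor as the reference. A hedged way to inspect that reference output directly, assuming the checkpoint can be downloaded:

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-224")
image = Image.new("RGB", (224, 224))  # placeholder image

out = processor(text="caption en", images=image, return_tensors="pt")
# For the 224px checkpoints this yields 256 <image> tokens, then <bos>,
# then the prompt tokens, then the trailing newline token.
print(out["input_ids"].shape)
print(processor.tokenizer.convert_ids_to_tokens(out["input_ids"][0][:3]))
```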
vllm/model_executor/models/paligemma.py: 139 additions & 80 deletions
@@ -5,22 +5,26 @@

import torch
from torch import nn
-from transformers import PaliGemmaConfig
+from transformers import BatchFeature, PaliGemmaConfig

from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputs, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptIndexTargets,
+                                        PromptInsertion, PromptReplacement,
+                                        PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config

-from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only
-from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
-                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
+from .interfaces import SupportsMultiModal, SupportsPP
+from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)

@@ -46,97 +50,152 @@ class PaliGemmaImageEmbeddingInputs(TypedDict):
PaliGemmaImageEmbeddingInputs]


-def get_max_paligemma_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    vision_config = hf_config.vision_config
-
-    return get_max_siglip_image_tokens(vision_config)
-
-
-def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
-                             mm_counts: Mapping[str, int]):
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    vision_config = hf_config.vision_config
-    num_images = mm_counts["image"]
-
-    seq_data, ranges = dummy_seq_data_for_siglip(
-        vision_config,
-        seq_len,
-        num_images,
-        image_token_id=hf_config.image_token_index,
-    )
-
-    mm_data = dummy_image_for_siglip(vision_config, num_images)
-    return DummyData(seq_data, mm_data, ranges)
-
-
-def input_processor_for_paligemma(ctx: InputContext,
-                                  inputs: DecoderOnlyInputs):
-
-    """
-    The correct prompt format needs to be:
-    '<image>' * image_feature_size + '<bos>' + prompt + '\n'
-
-    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
-    """ # noqa
-
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-
-    tokenizer = cached_tokenizer_from_config(model_config)
-    image_feature_size = hf_config.text_config.num_image_tokens
-    image_token_str = tokenizer.decode(hf_config.image_token_index)
-    bos_token = tokenizer.decode(hf_config.bos_token_id)
-    image_token_str_pad = image_token_str * image_feature_size
-    image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
-
-    orig_prompt = inputs.get("prompt")
-    orig_prompt_ids = inputs.get("prompt_token_ids")
-
-    if orig_prompt is not None and image_token_str in orig_prompt:
-        logger.warning(
-            "The image token '%s' was detected in the prompt and "
-            "will be removed. Please follow the proper prompt format"
-            " documented on HuggingFace.", image_token_str)
-        orig_prompt = orig_prompt.replace(image_token_str, "")
-        orig_prompt_ids.remove(hf_config.image_token_index)
-
-    new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
-
-    # The PaliGemma 2 tokenizer does not include a starting BOS token
-    if orig_prompt_ids[0] != hf_config.bos_token_id:
-        orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
-
-    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
-
-
-class PaliGemmaMultiModalProjector(nn.Module):
-
-    def __init__(self, vision_hidden_size: int, projection_dim: int):
-        super().__init__()
-
-        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
-
-    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.linear(image_features)
-        return hidden_states
-
-
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
+class PaliGemmaMultiModalProjector(nn.Module):
+
+    def __init__(self, vision_hidden_size: int, projection_dim: int):
+        super().__init__()
+
+        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
+
+    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.linear(image_features)
+        return hidden_states
+
+
+class PaliGemmaProcessingInfo(BaseProcessingInfo):
+
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(PaliGemmaConfig)
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
+
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        return {"image": self.get_num_image_tokens()}
+
+    def get_num_image_tokens(self) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        return get_max_siglip_image_tokens(vision_config)
+
+
+class PaliGemmaDummyInputsBuilder(
+        BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        hf_config = self.info.get_hf_config()
+        vision_config = hf_config.vision_config
+        max_image_size = vision_config.image_size
+
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
+
+
+class PaliGemmaMultiModalProcessor(
+        BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        tokenizer = self.info.get_tokenizer()
+        if not mm_data:
+            prompt_ids = tokenizer.encode(prompt)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(pixel_values=MultiModalFieldConfig.batched("image"))
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        hf_config = self.info.get_hf_config()
+        image_token_id = hf_config.image_token_index
+
+        tokenizer = self.info.get_tokenizer()
+        num_image_tokens = self.info.get_num_image_tokens()
+        image_tokens = [image_token_id] * num_image_tokens
+
+        bos_token_id = tokenizer.bos_token_id
+        assert isinstance(bos_token_id, int)
+
+        # Paligemma 1 and 2 have different tokenizer.add_bos_token
+        # Insert <image>*n + <bos> after <bos> for Paligemma 1
+        # Insert <image>*n + <bos> for Paligemma 2
+        return [
+            PromptInsertion(
+                modality="image",
+                target=PromptIndexTargets.prefix(
+                    [bos_token_id] if tokenizer.add_bos_token else []),
+                insertion=PromptUpdateDetails(
+                    full=image_tokens + [bos_token_id],
+                    features=image_tokens,
+                ),
+            )
+        ]
+
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputs:
+        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
+
+        tokenizer = self.info.get_tokenizer()
+        newline_prompt = "\n"
+        newline_token_id = tokenizer.encode(newline_prompt)[-1]  # 108
+        # Force to add newline at the end of prompt for paligemma's format
+        # This step can NOT be replacemented by current PromptUpdate methods
+        if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
+            prompt_token_ids.append(newline_token_id)
+            mm_inputs["prompt_token_ids"] = prompt_token_ids
+            mm_inputs["prompt"] += newline_prompt
+
+        return mm_inputs
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    PaliGemmaMultiModalProcessor,
+    info=PaliGemmaProcessingInfo,
+    dummy_inputs=PaliGemmaDummyInputsBuilder)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                        SupportsPP, SupportsV0Only):
+                                        SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
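Taken together, the new processor reproduces the reference layout of image tokens, `<bos>`, the prompt, and a trailing newline. The helper below is a hypothetical stand-alone illustration (not part of the PR) of what `PromptInsertion` with `PromptIndexTargets.prefix` plus the `apply()` override yield; the token ids are placeholders:

```python
# Hypothetical illustration only; the token ids below are placeholders.
IMAGE, BOS, NEWLINE = 257152, 2, 108


def paligemma_prompt_layout(prompt_ids: list[int], num_image_tokens: int,
                            add_bos_token: bool) -> list[int]:
    image_tokens = [IMAGE] * num_image_tokens
    if add_bos_token:
        # PaliGemma 1: the tokenizer already emits a leading <bos>;
        # the PromptInsertion target is the position right after it.
        ids = prompt_ids[:1] + image_tokens + [BOS] + prompt_ids[1:]
    else:
        # PaliGemma 2: no leading <bos>, so insert at index 0.
        ids = image_tokens + [BOS] + prompt_ids
    # apply() appends the newline token (id 108) if it is missing.
    if not ids or ids[-1] != NEWLINE:
        ids = ids + [NEWLINE]
    return ids


# For the 224px checkpoints, (224 / 14) ** 2 = 256 image tokens.
print(len(paligemma_prompt_layout([BOS, 9876, 5432], 256, add_bos_token=True)))
```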