Commit 1769928
[Model] Update Paligemma multimodal processing with PromptUpdate (#14015)
Signed-off-by: Kyle Huang <kylhuang@nvidia.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
1 parent ed6ea06 commit 1769928

File tree (4 files changed, 146 additions and 86 deletions):

docs/source/models/supported_models.md
tests/models/decoder_only/vision_language/test_models.py
tests/models/multimodal/processing/test_common.py
vllm/model_executor/models/paligemma.py

docs/source/models/supported_models.md
Lines changed: 3 additions & 3 deletions

@@ -842,13 +842,13 @@ See [this page](#generative-models) for more information on how to use generative models.
   *
   * ✅︎
   * ✅︎
-- * `PaliGemmaForConditionalGeneration`\*
-  * PaliGemma, PaliGemma 2
+- * `PaliGemmaForConditionalGeneration`
+  * PaliGemma (see note), PaliGemma 2 (see note)
   * T + I<sup>E</sup>
   * `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc.
   *
   * ✅︎
-  *
+  * ✅︎
 - * `Phi3VForCausalLM`
   * Phi-3-Vision, Phi-3.5-Vision
   * T + I<sup>E+</sup>
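For reference, serving one of these checkpoints through vLLM's offline API looks roughly like the sketch below. This is an illustrative example rather than part of the commit; the image path and sampling settings are hypothetical, and the prompt is just the PaliGemma task prefix, since the processor inserts the `<image>` placeholders, `<bos>`, and the trailing newline itself.

```python
# Illustrative sketch (not part of this commit): offline inference with a
# PaliGemma checkpoint from the table above via vLLM's multimodal prompt dict.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="google/paligemma-3b-mix-224")

image = Image.open("example.jpg")  # hypothetical local image
outputs = llm.generate(
    {
        "prompt": "caption en",                # task prefix only
        "multi_modal_data": {"image": image},  # processor adds <image>/<bos>/newline
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```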

tests/models/decoder_only/vision_language/test_models.py
Lines changed: 2 additions & 3 deletions

@@ -116,9 +116,8 @@
             "pixel_values"
         ),
         vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output,
-        dtype=("half" if current_platform.is_cpu() or current_platform.is_rocm()
-               else ("half", "float")),
-        marks=[pytest.mark.core_model],
+        dtype="bfloat16",
+        marks=[pytest.mark.skip(reason="vLLM does not support PrefixLM attention mask")],  # noqa: E501
     ),
     # TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
     # once we upgraded to transformers>=4.49.0.
tests/models/multimodal/processing/test_common.py
Lines changed: 2 additions & 0 deletions

@@ -175,6 +175,8 @@ def _test_processing_correctness(
     "Qwen/Qwen2-Audio-7B-Instruct",
     "fixie-ai/ultravox-v0_4",
     "openai/whisper-large-v3",
+    "google/paligemma-3b-mix-224",
+    "google/paligemma2-3b-ft-docci-448",
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])

vllm/model_executor/models/paligemma.py
Lines changed: 139 additions & 80 deletions

@@ -5,22 +5,26 @@
 
 import torch
 from torch import nn
-from transformers import PaliGemmaConfig
+from transformers import BatchFeature, PaliGemmaConfig
 
 from vllm.config import VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputs, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptIndexTargets,
+                                        PromptInsertion, PromptReplacement,
+                                        PromptUpdateDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
 
-from .interfaces import SupportsMultiModal, SupportsPP, SupportsV0Only
-from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
-                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
+from .interfaces import SupportsMultiModal, SupportsPP
+from .siglip import SiglipVisionModel, get_max_siglip_image_tokens
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
 
@@ -46,97 +50,152 @@ class PaliGemmaImageEmbeddingInputs(TypedDict):
                            PaliGemmaImageEmbeddingInputs]
 
 
-def get_max_paligemma_image_tokens(ctx: InputContext):
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    vision_config = hf_config.vision_config
-
-    return get_max_siglip_image_tokens(vision_config)
-
-
-def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
-                             mm_counts: Mapping[str, int]):
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    vision_config = hf_config.vision_config
-    num_images = mm_counts["image"]
-
-    seq_data, ranges = dummy_seq_data_for_siglip(
-        vision_config,
-        seq_len,
-        num_images,
-        image_token_id=hf_config.image_token_index,
-    )
-
-    mm_data = dummy_image_for_siglip(vision_config, num_images)
-    return DummyData(seq_data, mm_data, ranges)
-
-
-def input_processor_for_paligemma(ctx: InputContext,
-                                  inputs: DecoderOnlyInputs):
+class PaliGemmaMultiModalProjector(nn.Module):
 
-    """
-    The correct prompt format needs to be:
-    '<image>' * image_feature_size + '<bos>' + prompt + '\n'
+    def __init__(self, vision_hidden_size: int, projection_dim: int):
+        super().__init__()
 
-    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
-    """ # noqa
+        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
 
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
+    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.linear(image_features)
+        return hidden_states
 
-    model_config = ctx.model_config
-    hf_config = ctx.get_hf_config(PaliGemmaConfig)
 
-    tokenizer = cached_tokenizer_from_config(model_config)
-    image_feature_size = hf_config.text_config.num_image_tokens
-    image_token_str = tokenizer.decode(hf_config.image_token_index)
-    bos_token = tokenizer.decode(hf_config.bos_token_id)
-    image_token_str_pad = image_token_str * image_feature_size
-    image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
+class PaliGemmaProcessingInfo(BaseProcessingInfo):
 
-    orig_prompt = inputs.get("prompt")
-    orig_prompt_ids = inputs.get("prompt_token_ids")
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(PaliGemmaConfig)
 
-    if orig_prompt is not None and image_token_str in orig_prompt:
-        logger.warning(
-            "The image token '%s' was detected in the prompt and "
-            "will be removed. Please follow the proper prompt format"
-            " documented on HuggingFace.", image_token_str)
-        orig_prompt = orig_prompt.replace(image_token_str, "")
-        orig_prompt_ids.remove(hf_config.image_token_index)
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": 1}
 
-    new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        return {"image": self.get_num_image_tokens()}
 
-    # The PaliGemma 2 tokenizer does not include a starting BOS token
-    if orig_prompt_ids[0] != hf_config.bos_token_id:
-        orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
+    def get_num_image_tokens(self) -> int:
+        hf_config = self.get_hf_config()
+        vision_config = hf_config.vision_config
+        return get_max_siglip_image_tokens(vision_config)
 
-    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  # newline
 
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+class PaliGemmaDummyInputsBuilder(
+        BaseDummyInputsBuilder[PaliGemmaProcessingInfo]):
 
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        hf_config = self.info.get_hf_config()
+        vision_config = hf_config.vision_config
+        max_image_size = vision_config.image_size
+
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="",
+            mm_data=mm_data,
+        )
 
-class PaliGemmaMultiModalProjector(nn.Module):
 
-    def __init__(self, vision_hidden_size: int, projection_dim: int):
-        super().__init__()
+class PaliGemmaMultiModalProcessor(
+        BaseMultiModalProcessor[PaliGemmaProcessingInfo]):
 
-        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        tokenizer = self.info.get_tokenizer()
+        if not mm_data:
+            prompt_ids = tokenizer.encode(prompt)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
 
-    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.linear(image_features)
-        return hidden_states
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(pixel_values=MultiModalFieldConfig.batched("image"))
 
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        hf_config = self.info.get_hf_config()
+        image_token_id = hf_config.image_token_index
+
+        tokenizer = self.info.get_tokenizer()
+        num_image_tokens = self.info.get_num_image_tokens()
+        image_tokens = [image_token_id] * num_image_tokens
+
+        bos_token_id = tokenizer.bos_token_id
+        assert isinstance(bos_token_id, int)
+
+        # PaliGemma 1 and 2 have different tokenizer.add_bos_token
+        # Insert <image>*n + <bos> after <bos> for PaliGemma 1
+        # Insert <image>*n + <bos> for PaliGemma 2
+        return [
+            PromptInsertion(
+                modality="image",
+                target=PromptIndexTargets.prefix(
                    [bos_token_id] if tokenizer.add_bos_token else []),
+                insertion=PromptUpdateDetails(
+                    full=image_tokens + [bos_token_id],
+                    features=image_tokens,
+                ),
+            )
+        ]
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputs:
+        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
+
+        tokenizer = self.info.get_tokenizer()
+        newline_prompt = "\n"
+        newline_token_id = tokenizer.encode(newline_prompt)[-1]  # 108
+        # Force to add newline at the end of prompt for PaliGemma's format
+        # This step can NOT be replaced by current PromptUpdate methods
+        if len(prompt_token_ids) and prompt_token_ids[-1] != newline_token_id:
+            prompt_token_ids.append(newline_token_id)
+            mm_inputs["prompt_token_ids"] = prompt_token_ids
+            mm_inputs["prompt"] += newline_prompt
+
+        return mm_inputs
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    PaliGemmaMultiModalProcessor,
+    info=PaliGemmaProcessingInfo,
+    dummy_inputs=PaliGemmaDummyInputsBuilder)
 class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                        SupportsPP, SupportsV0Only):
+                                        SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
            "q_proj",
