5
5
6
6
import torch
7
7
from torch import nn
8
- from transformers import PaliGemmaConfig
8
+ from transformers import BatchFeature , PaliGemmaConfig
9
9
10
10
from vllm .config import VllmConfig
11
- from vllm .inputs import (INPUT_REGISTRY , DecoderOnlyInputs , DummyData ,
12
- InputContext , token_inputs )
13
11
from vllm .logger import init_logger
14
12
from vllm .model_executor .layers .sampler import SamplerOutput
15
13
from vllm .model_executor .sampling_metadata import SamplingMetadata
16
14
from vllm .multimodal import MULTIMODAL_REGISTRY
17
- from vllm .multimodal .inputs import NestedTensors
15
+ from vllm .multimodal .inputs import (MultiModalDataDict , MultiModalFieldConfig ,
16
+ MultiModalInputs , MultiModalKwargs ,
17
+ NestedTensors )
18
+ from vllm .multimodal .parse import MultiModalDataItems
19
+ from vllm .multimodal .processing import (BaseMultiModalProcessor ,
20
+ BaseProcessingInfo , PromptIndexTargets ,
21
+ PromptInsertion , PromptReplacement ,
22
+ PromptUpdateDetails )
23
+ from vllm .multimodal .profiling import BaseDummyInputsBuilder , ProcessorInputs
18
24
from vllm .sequence import IntermediateTensors
19
- from vllm .transformers_utils .tokenizer import cached_tokenizer_from_config
20
25
21
- from .interfaces import SupportsMultiModal , SupportsPP , SupportsV0Only
22
- from .siglip import (SiglipVisionModel , dummy_image_for_siglip ,
23
- dummy_seq_data_for_siglip , get_max_siglip_image_tokens )
26
+ from .interfaces import SupportsMultiModal , SupportsPP
27
+ from .siglip import SiglipVisionModel , get_max_siglip_image_tokens
24
28
from .utils import (AutoWeightsLoader , init_vllm_registered_model ,
25
29
maybe_prefix , merge_multimodal_embeddings )
26
30
@@ -46,97 +50,152 @@ class PaliGemmaImageEmbeddingInputs(TypedDict):
46
50
PaliGemmaImageEmbeddingInputs ]
47
51
48
52
49
- def get_max_paligemma_image_tokens (ctx : InputContext ):
50
- hf_config = ctx .get_hf_config (PaliGemmaConfig )
51
- vision_config = hf_config .vision_config
52
-
53
- return get_max_siglip_image_tokens (vision_config )
54
-
55
-
56
- def dummy_data_for_paligemma (ctx : InputContext , seq_len : int ,
57
- mm_counts : Mapping [str , int ]):
58
- hf_config = ctx .get_hf_config (PaliGemmaConfig )
59
- vision_config = hf_config .vision_config
60
- num_images = mm_counts ["image" ]
61
-
62
- seq_data , ranges = dummy_seq_data_for_siglip (
63
- vision_config ,
64
- seq_len ,
65
- num_images ,
66
- image_token_id = hf_config .image_token_index ,
67
- )
68
-
69
- mm_data = dummy_image_for_siglip (vision_config , num_images )
70
- return DummyData (seq_data , mm_data , ranges )
71
-
72
-
73
- def input_processor_for_paligemma (ctx : InputContext ,
74
- inputs : DecoderOnlyInputs ):
53
+ class PaliGemmaMultiModalProjector (nn .Module ):
75
54
76
- """
77
- The correct prompt format needs to be:
78
- '<image>' * image_feature_size + '<bos>' + prompt + '\n '
55
+ def __init__ (self , vision_hidden_size : int , projection_dim : int ):
56
+ super ().__init__ ()
79
57
80
- See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
81
- """ # noqa
58
+ self .linear = nn .Linear (vision_hidden_size , projection_dim , bias = True )
82
59
83
- multi_modal_data = inputs . get ( "multi_modal_data" )
84
- if multi_modal_data is None or "image" not in multi_modal_data :
85
- return inputs
60
+ def forward ( self , image_features : torch . Tensor ) -> torch . Tensor :
61
+ hidden_states = self . linear ( image_features )
62
+ return hidden_states
86
63
87
- model_config = ctx .model_config
88
- hf_config = ctx .get_hf_config (PaliGemmaConfig )
89
64
90
- tokenizer = cached_tokenizer_from_config (model_config )
91
- image_feature_size = hf_config .text_config .num_image_tokens
92
- image_token_str = tokenizer .decode (hf_config .image_token_index )
93
- bos_token = tokenizer .decode (hf_config .bos_token_id )
94
- image_token_str_pad = image_token_str * image_feature_size
95
- image_token_ids_pad = [hf_config .image_token_index ] * image_feature_size
65
+ class PaliGemmaProcessingInfo (BaseProcessingInfo ):
96
66
97
- orig_prompt = inputs . get ( "prompt" )
98
- orig_prompt_ids = inputs . get ( "prompt_token_ids" )
67
+ def get_hf_config ( self ):
68
+ return self . ctx . get_hf_config ( PaliGemmaConfig )
99
69
100
- if orig_prompt is not None and image_token_str in orig_prompt :
101
- logger .warning (
102
- "The image token '%s' was detected in the prompt and "
103
- "will be removed. Please follow the proper prompt format"
104
- " documented on HuggingFace." , image_token_str )
105
- orig_prompt = orig_prompt .replace (image_token_str , "" )
106
- orig_prompt_ids .remove (hf_config .image_token_index )
70
+ def get_supported_mm_limits (self ) -> Mapping [str , Optional [int ]]:
71
+ return {"image" : 1 }
107
72
108
- new_prompt = f"{ image_token_str_pad } { bos_token } { orig_prompt } \n "
73
+ def get_mm_max_tokens_per_item (
74
+ self ,
75
+ seq_len : int ,
76
+ mm_counts : Mapping [str , int ],
77
+ ) -> Mapping [str , int ]:
78
+ return {"image" : self .get_num_image_tokens ()}
109
79
110
- # The PaliGemma 2 tokenizer does not include a starting BOS token
111
- if orig_prompt_ids [0 ] != hf_config .bos_token_id :
112
- orig_prompt_ids = [hf_config .bos_token_id ] + orig_prompt_ids
80
+ def get_num_image_tokens (self ) -> int :
81
+ hf_config = self .get_hf_config ()
82
+ vision_config = hf_config .vision_config
83
+ return get_max_siglip_image_tokens (vision_config )
113
84
114
- new_token_ids = image_token_ids_pad + orig_prompt_ids + [108 ] #newline
115
85
116
- # NOTE: Create a defensive copy of the original inputs
117
- return token_inputs (prompt_token_ids = new_token_ids ,
118
- prompt = new_prompt ,
119
- multi_modal_data = multi_modal_data )
86
+ class PaliGemmaDummyInputsBuilder (
87
+ BaseDummyInputsBuilder [PaliGemmaProcessingInfo ]):
120
88
89
+ def get_dummy_processor_inputs (
90
+ self ,
91
+ seq_len : int ,
92
+ mm_counts : Mapping [str , int ],
93
+ ) -> ProcessorInputs :
94
+ hf_config = self .info .get_hf_config ()
95
+ vision_config = hf_config .vision_config
96
+ max_image_size = vision_config .image_size
97
+
98
+ num_images = mm_counts .get ("image" , 0 )
99
+
100
+ mm_data = {
101
+ "image" :
102
+ self ._get_dummy_images (width = max_image_size ,
103
+ height = max_image_size ,
104
+ num_images = num_images )
105
+ }
106
+
107
+ return ProcessorInputs (
108
+ prompt_text = "" ,
109
+ mm_data = mm_data ,
110
+ )
121
111
122
- class PaliGemmaMultiModalProjector (nn .Module ):
123
112
124
- def __init__ ( self , vision_hidden_size : int , projection_dim : int ):
125
- super (). __init__ ()
113
+ class PaliGemmaMultiModalProcessor (
114
+ BaseMultiModalProcessor [ PaliGemmaProcessingInfo ]):
126
115
127
- self .linear = nn .Linear (vision_hidden_size , projection_dim , bias = True )
116
+ def _call_hf_processor (
117
+ self ,
118
+ prompt : str ,
119
+ mm_data : Mapping [str , object ],
120
+ mm_kwargs : Mapping [str , object ],
121
+ ) -> BatchFeature :
122
+ tokenizer = self .info .get_tokenizer ()
123
+ if not mm_data :
124
+ prompt_ids = tokenizer .encode (prompt )
125
+ return BatchFeature (dict (input_ids = [prompt_ids ]), tensor_type = "pt" )
126
+
127
+ return super ()._call_hf_processor (
128
+ prompt = prompt ,
129
+ mm_data = mm_data ,
130
+ mm_kwargs = mm_kwargs ,
131
+ )
128
132
129
- def forward (self , image_features : torch .Tensor ) -> torch .Tensor :
130
- hidden_states = self .linear (image_features )
131
- return hidden_states
133
+ def _get_mm_fields_config (
134
+ self ,
135
+ hf_inputs : BatchFeature ,
136
+ hf_processor_mm_kwargs : Mapping [str , object ],
137
+ ) -> Mapping [str , MultiModalFieldConfig ]:
138
+ return dict (pixel_values = MultiModalFieldConfig .batched ("image" ))
132
139
140
+ def _get_prompt_updates (
141
+ self ,
142
+ mm_items : MultiModalDataItems ,
143
+ hf_processor_mm_kwargs : Mapping [str , object ],
144
+ out_mm_kwargs : MultiModalKwargs ,
145
+ ) -> list [PromptReplacement ]:
146
+ hf_config = self .info .get_hf_config ()
147
+ image_token_id = hf_config .image_token_index
148
+
149
+ tokenizer = self .info .get_tokenizer ()
150
+ num_image_tokens = self .info .get_num_image_tokens ()
151
+ image_tokens = [image_token_id ] * num_image_tokens
152
+
153
+ bos_token_id = tokenizer .bos_token_id
154
+ assert isinstance (bos_token_id , int )
155
+
156
+ # Paligemma 1 and 2 have different tokenizer.add_bos_token
157
+ # Insert <image>*n + <bos> after <bos> for Paligemma 1
158
+ # Insert <image>*n + <bos> for Paligemma 2
159
+ return [
160
+ PromptInsertion (
161
+ modality = "image" ,
162
+ target = PromptIndexTargets .prefix (
163
+ [bos_token_id ] if tokenizer .add_bos_token else []),
164
+ insertion = PromptUpdateDetails (
165
+ full = image_tokens + [bos_token_id ],
166
+ features = image_tokens ,
167
+ ),
168
+ )
169
+ ]
133
170
134
- @MULTIMODAL_REGISTRY .register_image_input_mapper ()
135
- @MULTIMODAL_REGISTRY .register_max_image_tokens (get_max_paligemma_image_tokens )
136
- @INPUT_REGISTRY .register_dummy_data (dummy_data_for_paligemma )
137
- @INPUT_REGISTRY .register_input_processor (input_processor_for_paligemma )
171
+ def apply (
172
+ self ,
173
+ prompt : Union [str , list [int ]],
174
+ mm_data : MultiModalDataDict ,
175
+ hf_processor_mm_kwargs : Mapping [str , object ],
176
+ ) -> MultiModalInputs :
177
+ mm_inputs = super ().apply (prompt , mm_data , hf_processor_mm_kwargs )
178
+ prompt_token_ids = mm_inputs ["prompt_token_ids" ]
179
+
180
+ tokenizer = self .info .get_tokenizer ()
181
+ newline_prompt = "\n "
182
+ newline_token_id = tokenizer .encode (newline_prompt )[- 1 ] # 108
183
+ # Force to add newline at the end of prompt for paligemma's format
184
+ # This step can NOT be replacemented by current PromptUpdate methods
185
+ if len (prompt_token_ids ) and prompt_token_ids [- 1 ] != newline_token_id :
186
+ prompt_token_ids .append (newline_token_id )
187
+ mm_inputs ["prompt_token_ids" ] = prompt_token_ids
188
+ mm_inputs ["prompt" ] += newline_prompt
189
+
190
+ return mm_inputs
191
+
192
+
193
+ @MULTIMODAL_REGISTRY .register_processor (
194
+ PaliGemmaMultiModalProcessor ,
195
+ info = PaliGemmaProcessingInfo ,
196
+ dummy_inputs = PaliGemmaDummyInputsBuilder )
138
197
class PaliGemmaForConditionalGeneration (nn .Module , SupportsMultiModal ,
139
- SupportsPP , SupportsV0Only ):
198
+ SupportsPP ):
140
199
packed_modules_mapping = {
141
200
"qkv_proj" : [
142
201
"q_proj" ,
0 commit comments