
Commit 11500ca

DarkLight1337 authored and mzusman committed
[VLM] Enable tokenized inputs for merged multi-modal processor (vllm-project#11900)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent ad07590 commit 11500ca
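
In practical terms, the merged multi-modal processor can now consume a prompt that is already a list of token ids, so a TokensPrompt may carry multi_modal_data directly instead of being decoded back to text first. A minimal usage sketch (the model id, the chat-template text, and the `image` object are placeholders, not taken from this commit):

from vllm import LLM
from vllm.inputs import TokensPrompt

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # placeholder: any model with a merged multi-modal processor
tokenizer = llm.get_tokenizer()

# Tokenize the text prompt ourselves; the processor accepts the ids as-is.
prompt_ids = tokenizer.encode("USER: <image>\nWhat is in this picture? ASSISTANT:")

outputs = llm.generate(
    TokensPrompt(
        prompt_token_ids=prompt_ids,
        multi_modal_data={"image": image},  # `image`: a PIL image loaded elsewhere
    )
)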

File tree

12 files changed: +207 -77 lines changed


tests/multimodal/test_processing.py

Lines changed: 25 additions & 6 deletions

@@ -649,7 +649,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
     )
 
 
-def _test_processing_cache_correctness(
+def _test_processing_correctness(
     model_id: str,
     modalities: dict[str, bool],
     hit_rate: float,
@@ -691,6 +691,7 @@ def _test_processing_cache_correctness(
     baseline_processor = factories.build_processor(ctx, cache=None)
     cached_processor = factories.build_processor(ctx, cache=cache)
     dummy_inputs = baseline_processor.dummy_inputs
+    tokenizer = baseline_processor.info.get_tokenizer()
 
     rng = np.random.RandomState(0)
 
@@ -747,7 +748,25 @@ def _test_processing_cache_correctness(
         )
 
         assert baseline_result == cached_result, (
-            f"Failed ({batch_idx=}, {mm_data=})")
+            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
+
+        baseline_tokenized_result = baseline_processor.apply(
+            tokenizer.encode(prompt),
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
+
+        assert baseline_result == baseline_tokenized_result, (
+            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
+
+        cached_tokenized_result = cached_processor.apply(
+            tokenizer.encode(prompt),
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
+
+        assert cached_result == cached_tokenized_result, (
+            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
 
 
 # yapf: disable
@@ -771,14 +790,14 @@ def _test_processing_cache_correctness(
 @pytest.mark.parametrize("num_batches", [32])
 @pytest.mark.parametrize("simplify_rate", [1.0])
 # yapf: enable
-def test_processing_cache_correctness(
+def test_processing_correctness(
     model_id: str,
     modalities: dict[str, bool],
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
 ):
-    _test_processing_cache_correctness(
+    _test_processing_correctness(
         model_id,
         modalities,
         hit_rate=hit_rate,
@@ -795,7 +814,7 @@ def test_processing_cache_correctness(
 @pytest.mark.parametrize("num_batches", [32])
 @pytest.mark.parametrize("simplify_rate", [1.0])
 # yapf: enable
-def test_processing_cache_correctness_phi3v(
+def test_processing_correctness_phi3v(
     model_id: str,
     modalities: dict[str, bool],
     hit_rate: float,
@@ -809,7 +828,7 @@ def test_processing_cache_correctness_phi3v(
 
     AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
 
-    _test_processing_cache_correctness(
+    _test_processing_correctness(
         model_id,
         modalities,
         hit_rate=hit_rate,
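
Besides the rename, the test now also verifies that feeding the processor pre-tokenized input yields exactly the same result as feeding it the text prompt. The invariant, condensed into a sketch (processor, prompt, mm_data, and tokenizer are the objects built by the test above):

text_result = processor.apply(prompt, mm_data=mm_data, hf_processor_mm_kwargs={})
token_result = processor.apply(tokenizer.encode(prompt), mm_data=mm_data, hf_processor_mm_kwargs={})
assert text_result == token_result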

vllm/inputs/data.py

Lines changed: 2 additions & 2 deletions

@@ -44,13 +44,13 @@ class TokensPrompt(TypedDict):
 
     multi_modal_data: NotRequired["MultiModalDataDict"]
     """
-    DEPRECATED: Optional multi-modal data to pass to the model,
+    Optional multi-modal data to pass to the model,
    if the model supports it.
     """
 
     mm_processor_kwargs: NotRequired[Dict[str, Any]]
     """
-    DEPRECATED: Optional multi-modal processor kwargs to be forwarded to the
+    Optional multi-modal processor kwargs to be forwarded to the
     multimodal input mapper & processor. Note that if multiple modalities
     have registered mappers etc for the model being considered, we attempt
     to pass the mm_processor_kwargs to each of them.

vllm/inputs/preprocess.py

Lines changed: 0 additions & 4 deletions

@@ -279,10 +279,6 @@ async def _process_multimodal_async(
 
         mm_processor = self.mm_registry.create_processor(
             self.model_config, tokenizer)
-        if isinstance(prompt, list):
-            logger.warning("Passing `multi_modal_data` in TokensPrompt is"
-                           "deprecated and will be removed in a future update")
-            prompt = tokenizer.decode(prompt)
         if mm_processor_kwargs is None:
             mm_processor_kwargs = {}
 
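
With the decode-and-warn fallback gone, a token prompt reaches the merged processor unchanged. Roughly (a simplified sketch of the resulting call, not the literal code in preprocess.py):

mm_inputs = mm_processor.apply(
    prompt,                       # may now be a str or a list[int]
    mm_data=multi_modal_data,
    hf_processor_mm_kwargs=mm_processor_kwargs,
)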

vllm/model_executor/models/blip2.py

Lines changed: 20 additions & 2 deletions

@@ -441,6 +441,24 @@ def get_dummy_processor_inputs(
 
 class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
 
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        if not mm_data:
+            # HF processor always adds placeholders even when there's no image
+            tokenizer = self.info.get_tokenizer()
+            prompt_ids = tokenizer.encode(prompt)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -469,11 +487,11 @@ def _get_prompt_replacements(
 
     def apply(
         self,
-        prompt_text: str,
+        prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
-        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
 
         # Only <image> tokens should be considered as placeholders,
         # so we ignore the trailing bos_token

vllm/model_executor/models/chameleon.py

Lines changed: 30 additions & 2 deletions

@@ -99,6 +99,34 @@ def get_dummy_processor_inputs(
 class ChameleonMultiModalProcessor(
         BaseMultiModalProcessor[ChameleonProcessingInfo]):
 
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        if not mm_data:
+            prompt_ids = self.info.get_tokenizer().encode(prompt)
+            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
+
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+    def _apply_hf_processor_tokens_only(
+        self,
+        prompt_tokens: list[int],
+    ) -> list[int]:
+        # HF processor adds sep token for chat mode
+        tokenizer = self.info.get_tokenizer()
+        sep_token_id: int = \
+            tokenizer.vocab[tokenizer.sep_token]  # type: ignore
+
+        return prompt_tokens + [sep_token_id]
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -128,11 +156,11 @@ def _get_prompt_replacements(
 
     def apply(
         self,
-        prompt_text: str,
+        prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
-        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
 
         # Only <image> tokens should be considered as placeholders,
         # so we ignore the image_start_token and image_end_token
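
Chameleon (and, further down, Fuyu and Ultravox) adds a _apply_hf_processor_tokens_only hook: when the prompt arrives pre-tokenized there is no text to run through the HF processor, so the hook reproduces whatever special tokens the HF processor would otherwise have added. A sketch of the override pattern for a hypothetical model (the hook name and base class come from this diff; the sep-token behaviour mirrors Chameleon's):

class MyModelMultiModalProcessor(BaseMultiModalProcessor[MyModelProcessingInfo]):

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        # Mirror the special tokens the HF processor appends to text input.
        tokenizer = self.info.get_tokenizer()
        return prompt_tokens + [tokenizer.sep_token_id]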

vllm/model_executor/models/fuyu.py

Lines changed: 15 additions & 9 deletions

@@ -16,7 +16,7 @@
 """ PyTorch Fuyu model."""
 import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-                    TypedDict)
+                    TypedDict, Union)
 
 import torch
 import torch.nn as nn
@@ -149,14 +149,10 @@ def _call_hf_processor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-
         if not mm_data:
             # Avoid warning from HF logger for text-only input
-            # Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
-            # Tokenizer won't add boa_token_id by default, we add it manually.
-            tokenizer = self.info.get_tokenizer()
-            boa_token_id: int = tokenizer.vocab["<0x04>"]  # type: ignore
-            prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
+            prompt_ids = self.info.get_tokenizer().encode(prompt)
+            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
             return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
 
         processed_outputs = super()._call_hf_processor(
@@ -181,6 +177,16 @@ def _call_hf_processor(
 
         return processed_outputs
 
+    def _apply_hf_processor_tokens_only(
+        self,
+        prompt_tokens: list[int],
+    ) -> list[int]:
+        # HF processor adds boa_token_id
+        tokenizer = self.info.get_tokenizer()
+        boa_token_id: int = tokenizer.vocab["<0x04>"]  # type: ignore
+
+        return prompt_tokens + [boa_token_id]
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -223,11 +229,11 @@ def get_replacement_fuyu(item_idx: int):
 
     def apply(
         self,
-        prompt_text: str,
+        prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
-        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
 
         # Only |SPEAKER| (image) tokens should be considered as placeholders,
         # so we ignore the trailing bos_token_id

vllm/model_executor/models/interfaces.py

Lines changed: 4 additions & 4 deletions (whitespace-only changes in the docstring)

@@ -39,13 +39,13 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
 
         The output embeddings must be one of the following formats:
 
-        - A list or tuple of 2D tensors, where each tensor corresponds to
-          each input multimodal data item (e.g, image).
+        - A list or tuple of 2D tensors, where each tensor corresponds to
+          each input multimodal data item (e.g, image).
         - A single 3D tensor, with the batch dimension grouping the 2D tensors.
 
         Note:
-            The returned multimodal embeddings must be in the same order as
-            the appearances of their corresponding multimodal data item in the
+            The returned multimodal embeddings must be in the same order as
+            the appearances of their corresponding multimodal data item in the
             input prompt.
         """
         ...

vllm/model_executor/models/llava.py

Lines changed: 4 additions & 4 deletions

@@ -724,7 +724,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
 
     def apply(
         self,
-        prompt_text: str,
+        prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
@@ -737,7 +737,7 @@ def apply(
             image_height=-1,
         )
 
-        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
 
         mm_items = self._to_mm_items(mm_data)
         mm_item_counts = mm_items.get_all_counts()
@@ -760,7 +760,7 @@ def get_replacement_mantis(item_idx: int):
             )
         ])
 
-        prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
+        prompt_ids, prompt, _ = self._apply_prompt_replacements(
            result["prompt_token_ids"],
            mantis_mm_repls,
            mm_item_counts,
@@ -788,7 +788,7 @@ def get_replacement_mantis(item_idx: int):
 
         return MultiModalInputsV2(
             type="multimodal",
-            prompt=prompt_text,
+            prompt=prompt,
             prompt_token_ids=prompt_ids,
             mm_kwargs=mm_kwargs,
             mm_placeholders=mm_placeholder_ranges,

vllm/model_executor/models/phi3v.py

Lines changed: 2 additions & 2 deletions

@@ -481,11 +481,11 @@ def _apply_prompt_replacements(
 
     def apply(
         self,
-        prompt_text: str,
+        prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> MultiModalInputsV2:
-        result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
 
         # Only <|image|> tokens should be considered as placeholders,
         # so we ignore the trailing bos_token_id

vllm/model_executor/models/ultravox.py

Lines changed: 12 additions & 6 deletions

@@ -138,12 +138,8 @@ def _call_hf_processor(
     ) -> BatchFeature:
         # Text-only input not supported in composite processor
         if not mm_data:
-            tokenizer = self.info.get_tokenizer()
-
-            prompt_ids = tokenizer.encode(
-                prompt,
-                add_special_tokens=False,  # type: ignore
-            )
+            prompt_ids = self.info.get_tokenizer().encode(prompt)
+            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
             return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
 
         mm_data = dict(mm_data)
@@ -188,6 +184,16 @@ def _call_hf_processor(
         )
         return BatchFeature(combined_outputs)
 
+    def _apply_hf_processor_tokens_only(
+        self,
+        prompt_tokens: list[int],
+    ) -> list[int]:
+        # HF processor omits bos_token_id by setting add_special_tokens=False
+        tokenizer = self.info.get_tokenizer()
+        assert prompt_tokens[0] == tokenizer.bos_token_id
+
+        return prompt_tokens[1:]
+
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
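
Ultravox is the inverse case: its HF processor tokenizes text with add_special_tokens=False, so a prompt that was tokenized externally (and therefore starts with the bos token) has that token stripped rather than having one appended. In short (a sketch with a generic tokenizer; exact ids depend on the model):

prompt_ids = tokenizer.encode("hello")        # typically [bos_token_id, ...]
assert prompt_ids[0] == tokenizer.bos_token_id
prompt_ids = prompt_ids[1:]                   # what _apply_hf_processor_tokens_only does above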
