Commit ec79b67 (parent: 32985be)

[Misc][V1] Avoid using envs.VLLM_USE_V1 in mm processing (#14256)

Signed-off-by: Roger Wang <ywang@roblox.com>
File tree: 7 files changed (+38, -8 lines)

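At a high level, the commit removes a hidden dependency on global engine state from multimodal processing: instead of reading `envs.VLLM_USE_V1` deep inside `BaseMultiModalProcessor.apply`, the processor now takes an explicit `return_mm_hashes: bool = False` parameter that the V1 engine threads down through `InputPreprocessor`. A minimal, runnable before/after sketch of the pattern (this class and its `_hash_items` helper are hypothetical stand-ins, not vLLM's real code):

    import os
    from typing import Any, Mapping, Optional

    class ProcessorSketch:
        """Minimal sketch of the pattern; not vLLM's real class."""

        def _hash_items(self, mm_data: Mapping[str, Any]) -> dict:
            # Hypothetical stand-in for the real content hashing.
            return {modality: ["<hash>"] for modality in mm_data}

        # Before: behavior silently keyed off a global environment flag.
        def apply_before(self, prompt: str, mm_data: Mapping[str, Any]):
            use_v1 = os.environ.get("VLLM_USE_V1") == "1"  # hidden global read
            mm_hashes = self._hash_items(mm_data) if use_v1 else None
            return prompt, mm_hashes

        # After: the caller opts in explicitly, one boolean per call.
        def apply_after(self, prompt: str, mm_data: Mapping[str, Any],
                        return_mm_hashes: bool = False):
            mm_hashes = self._hash_items(mm_data) if return_mm_hashes else None
            return prompt, mm_hashes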

vllm/inputs/preprocess.py

Lines changed: 22 additions & 2 deletions

@@ -254,6 +254,7 @@ def _process_multimodal(
         mm_data: MultiModalDataDict,
         mm_processor_kwargs: Optional[Mapping[str, object]],
         lora_request: Optional[LoRARequest],
+        return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
         """
         Apply the model's multi-modal processor to a multi-modal prompt,
@@ -274,14 +275,16 @@ def _process_multimodal(
         if mm_processor_kwargs is None:
             mm_processor_kwargs = {}

-        return mm_processor.apply(prompt, mm_data, mm_processor_kwargs)
+        return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
+                                  return_mm_hashes)

     async def _process_multimodal_async(
         self,
         prompt: Union[str, List[int]],
         mm_data: MultiModalDataDict,
         mm_processor_kwargs: Optional[Mapping[str, object]],
         lora_request: Optional[LoRARequest],
+        return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
         """Async version of :meth:`_process_multimodal`."""
         # At the moment on model (PrithviGeoSpatialMAE) requires to be
@@ -299,13 +302,15 @@ async def _process_multimodal_async(
         if mm_processor_kwargs is None:
             mm_processor_kwargs = {}

-        return mm_processor.apply(prompt, mm_data, mm_processor_kwargs)
+        return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
+                                  return_mm_hashes)

     def _prompt_to_llm_inputs(
         self,
         prompt: SingletonPrompt,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
+        return_mm_hashes: bool = False,
     ) -> SingletonInputs:
         """
         Extract the singleton inputs from a prompt.
@@ -315,6 +320,7 @@ def _prompt_to_llm_inputs(
         * request_id
         * prompt: single encoder or decoder input prompt
         * lora_request: this is only valid for decoder prompts
+        * return_mm_hashes: whether to return multimodal hashes

         Returns:
@@ -349,6 +355,7 @@ def _prompt_to_llm_inputs(
                 multi_modal_data,
                 mm_processor_kwargs,
                 lora_request=lora_request,
+                return_mm_hashes=return_mm_hashes,
             )

         return token_inputs(
@@ -695,6 +702,7 @@ def _process_decoder_only_prompt(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        return_mm_hashes: bool = False,
     ) -> DecoderOnlyInputs:
         """
         For decoder-only models:
@@ -706,6 +714,7 @@ def _process_decoder_only_prompt(
         * request_id
         * lora_request
         * prompt_adapter_request
+        * return_mm_hashes

         Returns:
@@ -729,6 +738,7 @@ async def _process_decoder_only_prompt_async(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        return_mm_hashes: bool = False,
     ) -> DecoderOnlyInputs:
         """Async version of :meth:`_process_decoder_only_prompt`."""
         prompt_comps = await self._prompt_to_llm_inputs_async(
@@ -748,9 +758,13 @@ def preprocess(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        return_mm_hashes: bool = False,
     ) -> ProcessorInputs:
         """Preprocess the input prompt."""
         if self.model_config.is_encoder_decoder:
+            assert not return_mm_hashes, (
+                "Multimodal hashes for encoder-decoder models should not be ",
+                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return self._process_encoder_decoder_prompt(
@@ -768,6 +782,7 @@ def preprocess(
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
+            return_mm_hashes=return_mm_hashes,
         )

     async def preprocess_async(
@@ -776,9 +791,13 @@ async def preprocess_async(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        return_mm_hashes: bool = False,
     ) -> ProcessorInputs:
         """Async version of :meth:`preprocess`."""
         if self.model_config.is_encoder_decoder:
+            assert not return_mm_hashes, (
+                "Multimodal hashes for encoder-decoder models should not be ",
+                "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return await self._process_encoder_decoder_prompt_async(
@@ -796,4 +815,5 @@ async def preprocess_async(
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
+            return_mm_hashes=return_mm_hashes,
         )
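
Taken together, these hunks thread `return_mm_hashes` through every layer of `InputPreprocessor` and guard the encoder-decoder path, which cannot honor the flag until V1 supports it. A condensed, runnable sketch of the resulting call chain (heavily abbreviated; the class below is an illustration, not the real code):

    class PreprocessSketch:
        """Condensed sketch of the threading inside InputPreprocessor."""

        def __init__(self, is_encoder_decoder: bool):
            self.is_encoder_decoder = is_encoder_decoder

        def preprocess(self, prompt: str, return_mm_hashes: bool = False):
            if self.is_encoder_decoder:
                # Enc-dec models cannot return hashes until V1 supports them.
                assert not return_mm_hashes, (
                    "mm hashes unsupported for encoder-decoder models")
                return self._process_encoder_decoder(prompt)
            return self._process_decoder_only(prompt, return_mm_hashes)

        def _process_decoder_only(self, prompt: str, return_mm_hashes: bool):
            # In the real code this forwards through _prompt_to_llm_inputs
            # and _process_multimodal down to mm_processor.apply(...).
            return self._process_multimodal(prompt, return_mm_hashes)

        def _process_multimodal(self, prompt: str, return_mm_hashes: bool):
            return {"prompt": prompt,
                    "mm_hashes": {} if return_mm_hashes else None}

        def _process_encoder_decoder(self, prompt: str):
            return {"prompt": prompt, "mm_hashes": None}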

vllm/model_executor/models/llava.py

Lines changed: 3 additions & 1 deletion

@@ -767,6 +767,7 @@ def apply(
         prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
         hf_config = self.info.get_hf_config()
         image_token_id = hf_config.image_token_index
@@ -777,7 +778,8 @@ def apply(
             image_height=-1,
         )

-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
+                               return_mm_hashes)

         mm_items = self._to_mm_items(mm_data)
         mm_item_counts = mm_items.get_all_counts()
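
Model-specific processors override `apply`, so each override must accept the new keyword and forward it to `super().apply(...)`; otherwise the base-class default of `False` would silently drop a caller's request for hashes. The minicpmv, mllama, and prithvi_geospatial_mae diffs below repeat this same forwarding change. A runnable sketch of the pattern on hypothetical classes (neither class is from the commit):

    from typing import Any, Mapping

    class BaseProcessorSketch:
        """Hypothetical stand-in for BaseMultiModalProcessor."""

        def apply(self, prompt: str, mm_data: Mapping[str, Any],
                  hf_processor_mm_kwargs: Mapping[str, object],
                  return_mm_hashes: bool = False) -> dict:
            mm_hashes = ({m: ["<hash>"] for m in mm_data}
                         if return_mm_hashes else None)
            return {"prompt": prompt, "mm_hashes": mm_hashes}

    class MyModelProcessorSketch(BaseProcessorSketch):
        """An override must accept and forward the new keyword; relying on
        the base default (False) would swallow the caller's request."""

        def apply(self, prompt, mm_data, hf_processor_mm_kwargs,
                  return_mm_hashes: bool = False) -> dict:
            result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
                                   return_mm_hashes)
            # ... model-specific post-processing of result ...
            return result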

vllm/model_executor/models/minicpmv.py

Lines changed: 3 additions & 1 deletion

@@ -780,6 +780,7 @@ def apply(
         prompt: Union[str, List[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
         supported_mm_modalities = self.info.get_supported_mm_modalities()
         if isinstance(prompt, list):
@@ -791,7 +792,8 @@ def apply(
                 [index for index, m in enumerate(matches) if m == modality])
             for modality in supported_mm_modalities
         }
-        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
+                               return_mm_hashes)
         # Exclude <image_id>x</image_id> from placeholders
         if "image" in result["mm_placeholders"] and \
                 self.info.get_model_version() == (2, 6):

vllm/model_executor/models/mllama.py

Lines changed: 3 additions & 1 deletion

@@ -175,8 +175,10 @@ def apply(
         prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
     ) -> MultiModalEncDecInputs:
-        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+        mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs,
+                                  return_mm_hashes)

         # Check that the number of image tokens in the decoder prompt matches
         # the number of images provided in mm_data

vllm/model_executor/models/prithvi_geospatial_mae.py

Lines changed: 1 addition & 0 deletions

@@ -93,6 +93,7 @@ def apply(
         prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
         mm_kwargs = {}

vllm/multimodal/processing.py

Lines changed: 5 additions & 3 deletions

@@ -14,7 +14,6 @@
 from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
 from typing_extensions import assert_never

-import vllm.envs as envs
 from vllm.inputs import InputProcessingContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
@@ -1435,6 +1434,7 @@ def apply(
         prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
     ) -> MultiModalInputs:
         """
         Process multi-modal inputs to be used in vLLM.
@@ -1451,11 +1451,11 @@ def apply(
         """
         mm_items = self._to_mm_items(mm_data)

-        # Create MM hashes (only used in V1)
+        # Create MM hashes to be returned (only used in V1)
         # TODO: Use these hash keys for caching operations in apply_hf_processor
         # instead of rehashing.

-        if envs.VLLM_USE_V1:
+        if return_mm_hashes:
             model_id = self.info.model_id
             mm_hashes = {
                 modality: [
@@ -1554,6 +1554,7 @@ def apply(
         prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
         hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
     ) -> MultiModalEncDecInputs:
         """
         Process multi-modal inputs to be used in vLLM.
@@ -1567,6 +1568,7 @@ def apply(
             encoder_prompt,
             mm_data,
             hf_processor_mm_kwargs,
+            return_mm_hashes,
         )

         tokenizer = self.info.get_tokenizer()
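
This file is the heart of the change: the `envs.VLLM_USE_V1` check becomes a check on the new parameter, so hash computation is driven by the caller rather than by an environment variable. A simplified, runnable sketch of the gated branch (`hash_mm_item` is a hypothetical stand-in for vLLM's actual hashing helper):

    import hashlib
    from typing import Any, Mapping, Optional

    def hash_mm_item(model_id: str, modality: str, item: Any) -> str:
        # Hypothetical stand-in: the real code hashes the item's content.
        payload = f"{model_id}:{modality}:{item!r}".encode()
        return hashlib.sha256(payload).hexdigest()

    def compute_mm_hashes(model_id: str,
                          mm_items: Mapping[str, list],
                          return_mm_hashes: bool) -> Optional[dict]:
        # Hashes are computed per modality only when the caller asked.
        if not return_mm_hashes:
            return None
        return {
            modality: [hash_mm_item(model_id, modality, item)
                       for item in items]
            for modality, items in mm_items.items()
        }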

vllm/v1/engine/processor.py

Lines changed: 1 addition & 0 deletions

@@ -131,6 +131,7 @@ def process_inputs(
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
+            return_mm_hashes=self.use_hash,
         )
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
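
On the V1 side, the engine-level `Processor` now decides whether hashes are needed and passes its `use_hash` flag explicitly, rather than letting the lower layers infer it from `VLLM_USE_V1`. An illustrative call site (the variable names around the call are assumptions, not code from the commit; the `preprocess` signature matches the diff above):

    # A caller that needs multimodal hashes asks for them explicitly
    # instead of toggling a global environment variable.
    preprocessed = input_preprocessor.preprocess(
        prompt,
        request_id=request_id,
        return_mm_hashes=True,  # previously implied by VLLM_USE_V1=1
    )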
