
Commit d9d4f95

DarkLight1337 authored and wuisawesome committed
[Misc] Clean up Qwen2.5-Omni code (vllm-project#17301)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent c7f8500 commit d9d4f95

File tree

2 files changed: +75 −94 lines


vllm/model_executor/models/qwen2_5_omni_thinker.py

Lines changed: 8 additions & 51 deletions
@@ -51,11 +51,9 @@
 from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.hasher import MultiModalHasher
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors)
+                                    MultiModalKwargs, NestedTensors)
 from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
                                    ModalityDataItems, MultiModalDataItems,
                                    MultiModalDataParser)

@@ -279,46 +277,17 @@ def _get_mm_fields_config(
     ) -> Mapping[str, MultiModalFieldConfig]:
         return _qwen2_5_omni_thinker_field_config(hf_inputs)

-    def apply(
+    def _maybe_apply_prompt_updates(
         self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
+        mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        return_mm_hashes: bool = False,
-    ) -> MultiModalInputs:
+        prompt_ids: list[int],
+        mm_kwargs: MultiModalKwargs,
+        is_update_applied: bool,
+    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
         """
         Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
         """
-        mm_items = self._to_mm_items(mm_data)
-
-        # Create MM hashes to be returned (only used in V1)
-        # TODO: Use these hash keys for caching operations in apply_hf_processor
-        # instead of rehashing.
-
-        if return_mm_hashes:
-            model_id = self.info.model_id
-            mm_hashes = {
-                modality: [
-                    MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: item},
-                                                 **hf_processor_mm_kwargs)
-                    for item in items
-                ]
-                for modality, items in mm_items.items()
-            }
-        else:
-            mm_hashes = None
-
-        (
-            prompt_ids,
-            mm_kwargs,
-            is_update_applied,
-        ) = self._cached_apply_hf_processor(
-            prompt,
-            mm_items,
-            hf_processor_mm_kwargs,
-        )
-
         unbound_prompt_updates = self._get_prompt_updates(
             mm_items,
             hf_processor_mm_kwargs,

@@ -364,22 +333,10 @@ def apply(
         tokenizer = self.info.get_tokenizer()
         prompt = decode_tokens(tokenizer, prompt_ids)

-        mm_placeholder_ranges = {
-            modality: [item.to_range() for item in placeholders]
-            for modality, placeholders in mm_placeholders.items()
-        }
-
         if use_audio_in_video:
             mm_kwargs["use_audio_in_video"] = True

-        return MultiModalInputs(
-            type="multimodal",
-            prompt=prompt,
-            prompt_token_ids=prompt_ids,
-            mm_kwargs=mm_kwargs,
-            mm_hashes=mm_hashes,
-            mm_placeholders=mm_placeholder_ranges,
-        )
+        return prompt_ids, prompt, mm_placeholders

     def _get_prompt_updates(
         self,
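
The net effect for this file: the Qwen2.5-Omni processor no longer reimplements all of `apply`, only the narrower `_maybe_apply_prompt_updates` hook that the base class now calls. A minimal, self-contained sketch of that template-method pattern follows; `BaseProcessor` and `OmniProcessor` are toy stand-ins for illustration, not the real vLLM classes.

# Toy sketch of the template-method split this commit introduces.
class BaseProcessor:

    def apply(self, prompt_ids: list[int],
              mm_kwargs: dict[str, object]) -> dict[str, object]:
        # Shared pipeline: the model-specific step is delegated to a hook.
        prompt_ids, prompt, placeholders = self._maybe_apply_prompt_updates(
            prompt_ids, mm_kwargs)
        return {
            "prompt": prompt,
            "prompt_token_ids": prompt_ids,
            "mm_kwargs": mm_kwargs,
            "mm_placeholders": placeholders,
        }

    def _maybe_apply_prompt_updates(self, prompt_ids, mm_kwargs):
        prompt = " ".join(map(str, prompt_ids))  # stand-in for detokenization
        return prompt_ids, prompt, {}


class OmniProcessor(BaseProcessor):
    """Overrides only the hook, mirroring the Qwen2.5-Omni change."""

    def _maybe_apply_prompt_updates(self, prompt_ids, mm_kwargs):
        # Model-specific tweak layered onto the shared pipeline, analogous
        # to the `use_audio_in_video` handling in the diff above.
        mm_kwargs["use_audio_in_video"] = True
        return super()._maybe_apply_prompt_updates(prompt_ids, mm_kwargs)


print(OmniProcessor().apply([151644, 8948], {}))

The payoff is that `apply` itself (see the second file below) lives in one place, so changes to hashing or caching no longer have to be mirrored in the Omni override.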

vllm/multimodal/processing.py

Lines changed: 67 additions & 43 deletions
@@ -1569,56 +1569,35 @@ def _validate_mm_placeholders(
                 "model (usually arising from an inconsistency between "
                 "`_call_hf_processor` and `_get_prompt_updates`).")

-    def apply(
+    def _hash_mm_items(
         self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
+        mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        return_mm_hashes: bool = False,
-    ) -> MultiModalInputs:
-        """
-        Process multi-modal inputs to be used in vLLM.
+    ) -> dict[str, list[str]]:
+        """Create MM hashes to be returned (only used in V1)."""

-        The main steps are:
-
-        1. Apply HF Processor on prompt text and multi-modal data together,
-           outputting token IDs and processed tensors.
-        2. Find and update sequences in the token IDs with placeholder tokens.
-           The number of placeholder tokens equals the feature size of the
-           multi-modal data outputted by the multi-modal encoder.
-        3. Extract information about the placeholder tokens from the
-           processed token IDs.
-        """
-        mm_items = self._to_mm_items(mm_data)
-
-        # Create MM hashes to be returned (only used in V1)
         # TODO: Use these hash keys for caching operations in apply_hf_processor
         # instead of rehashing.
+        model_id = self.info.model_id

-        if return_mm_hashes:
-            model_id = self.info.model_id
-            mm_hashes = {
-                modality: [
-                    MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: item},
-                                                 **hf_processor_mm_kwargs)
-                    for item in items
-                ]
-                for modality, items in mm_items.items()
-            }
-        else:
-            mm_hashes = None
-
-        (
-            prompt_ids,
-            mm_kwargs,
-            is_update_applied,
-        ) = self._cached_apply_hf_processor(
-            prompt,
-            mm_items,
-            hf_processor_mm_kwargs,
-        )
+        return {
+            modality: [
+                MultiModalHasher.hash_kwargs(model_id=model_id,
+                                             **{modality: item},
+                                             **hf_processor_mm_kwargs)
+                for item in items
+            ]
+            for modality, items in mm_items.items()
+        }

+    def _maybe_apply_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        prompt_ids: list[int],
+        mm_kwargs: MultiModalKwargs,
+        is_update_applied: bool,
+    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
         unbound_prompt_updates = self._get_prompt_updates(
             mm_items,
             hf_processor_mm_kwargs,

@@ -1652,6 +1631,51 @@ def apply(
         )
         self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

+        return prompt_ids, prompt, mm_placeholders
+
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
+    ) -> MultiModalInputs:
+        """
+        Process multi-modal inputs to be used in vLLM.
+
+        The main steps are:
+
+        1. Apply HF Processor on prompt text and multi-modal data together,
+           outputting token IDs and processed tensors.
+        2. Find and update sequences in the token IDs with placeholder tokens.
+           The number of placeholder tokens equals the feature size of the
+           multi-modal data outputted by the multi-modal encoder.
+        3. Extract information about the placeholder tokens from the
+           processed token IDs.
+        """
+        mm_items = self._to_mm_items(mm_data)
+
+        mm_hashes = (self._hash_mm_items(mm_items, hf_processor_mm_kwargs)
+                     if return_mm_hashes else None)
+
+        (
+            prompt_ids,
+            mm_kwargs,
+            is_update_applied,
+        ) = self._cached_apply_hf_processor(
+            prompt,
+            mm_items,
+            hf_processor_mm_kwargs,
+        )
+
+        prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            prompt_ids=prompt_ids,
+            mm_kwargs=mm_kwargs,
+            is_update_applied=is_update_applied,
+        )
+
         mm_placeholder_ranges = {
             modality: [item.to_range() for item in placeholders]
             for modality, placeholders in mm_placeholders.items()
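
For context on the extracted `_hash_mm_items` helper: it builds one content hash per item per modality, keyed by the model ID and the HF processor kwargs, and `apply` now computes it only when `return_mm_hashes=True`. Below is a rough, self-contained approximation of that dict comprehension; it uses hashlib as a stand-in for vLLM's `MultiModalHasher.hash_kwargs`, and the inputs in the usage line are hypothetical.

import hashlib
from collections.abc import Mapping


def hash_mm_items_sketch(
    mm_items: Mapping[str, list[bytes]],
    model_id: str,
    **hf_processor_mm_kwargs: object,
) -> dict[str, list[str]]:
    # One hash per item per modality, mirroring the shape of the
    # `_hash_mm_items` return value above.
    def _hash(modality: str, item: bytes) -> str:
        h = hashlib.sha256()
        for part in (model_id, modality,
                     *(f"{k}={v}"
                       for k, v in sorted(hf_processor_mm_kwargs.items()))):
            h.update(part.encode())
        h.update(item)
        return h.hexdigest()

    return {
        modality: [_hash(modality, item) for item in items]
        for modality, items in mm_items.items()
    }


# Hypothetical usage: two images and one audio clip for a single request.
print(hash_mm_items_sketch(
    {"image": [b"img-0", b"img-1"], "audio": [b"clip-0"]},
    model_id="Qwen/Qwen2.5-Omni-7B",
    use_audio_in_video=True,
))

Folding the model ID and processor kwargs into the key means the same image processed under different settings gets a different hash, which is what makes these keys safe to reuse for the caching mentioned in the TODO.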
