[Misc] Clean up Qwen2.5-Omni code #17301

Merged
merged 1 commit on Apr 28, 2025
59 changes: 8 additions & 51 deletions vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -51,11 +51,9 @@
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.hasher import MultiModalHasher
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs,
NestedTensors)
MultiModalKwargs, NestedTensors)
from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
@@ -279,46 +277,17 @@ def _get_mm_fields_config(
) -> Mapping[str, MultiModalFieldConfig]:
return _qwen2_5_omni_thinker_field_config(hf_inputs)

def apply(
def _maybe_apply_prompt_updates(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
return_mm_hashes: bool = False,
) -> MultiModalInputs:
prompt_ids: list[int],
mm_kwargs: MultiModalKwargs,
is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
"""
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
"""
mm_items = self._to_mm_items(mm_data)

# Create MM hashes to be returned (only used in V1)
# TODO: Use these hash keys for caching operations in apply_hf_processor
# instead of rehashing.

if return_mm_hashes:
model_id = self.info.model_id
mm_hashes = {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_items.items()
}
else:
mm_hashes = None

(
prompt_ids,
mm_kwargs,
is_update_applied,
) = self._cached_apply_hf_processor(
prompt,
mm_items,
hf_processor_mm_kwargs,
)

unbound_prompt_updates = self._get_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
Expand Down Expand Up @@ -364,22 +333,10 @@ def apply(
tokenizer = self.info.get_tokenizer()
prompt = decode_tokens(tokenizer, prompt_ids)

mm_placeholder_ranges = {
modality: [item.to_range() for item in placeholders]
for modality, placeholders in mm_placeholders.items()
}

if use_audio_in_video:
mm_kwargs["use_audio_in_video"] = True

return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_placeholders=mm_placeholder_ranges,
)
return prompt_ids, prompt, mm_placeholders

def _get_prompt_updates(
self,
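With this cleanup, the Qwen2.5-Omni thinker processor no longer re-implements the whole `apply` flow just to handle `use_audio_in_video`; it only overrides the new `_maybe_apply_prompt_updates` hook. Below is a minimal, self-contained sketch of that override pattern; the class and method bodies are illustrative toys, not the actual vLLM implementation:

```python
# Toy sketch of the pattern above (not vLLM code): the base class owns the
# common pipeline, and a model-specific subclass overrides only the
# prompt-update hook instead of the whole `apply` method.
class BaseProcessor:

    def _maybe_apply_prompt_updates(self, prompt_ids: list[int]) -> list[int]:
        # Default behaviour: no model-specific updates.
        return prompt_ids

    def apply(self, prompt_ids: list[int]) -> list[int]:
        # The common steps (HF processing, hashing, placeholder search)
        # would live here; only the hook below is model-specific.
        return self._maybe_apply_prompt_updates(prompt_ids)


class OmniLikeProcessor(BaseProcessor):

    def _maybe_apply_prompt_updates(self, prompt_ids: list[int]) -> list[int]:
        # Stand-in for the `use_audio_in_video` handling: prepend a marker token.
        return [0, *prompt_ids]


assert OmniLikeProcessor().apply([1, 2, 3]) == [0, 1, 2, 3]
```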
110 changes: 67 additions & 43 deletions vllm/multimodal/processing.py
@@ -1569,56 +1569,35 @@ def _validate_mm_placeholders(
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_updates`).")

def apply(
def _hash_mm_items(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
return_mm_hashes: bool = False,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
) -> dict[str, list[str]]:
"""Create MM hashes to be returned (only used in V1)."""

The main steps are:

1. Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
2. Find and update sequences in the token IDs with placeholder tokens.
The number of placeholder tokens equals the feature size of the
multi-modal data outputted by the multi-modal encoder.
3. Extract information about the placeholder tokens from the
processed token IDs.
"""
mm_items = self._to_mm_items(mm_data)

# Create MM hashes to be returned (only used in V1)
# TODO: Use these hash keys for caching operations in apply_hf_processor
# instead of rehashing.
model_id = self.info.model_id

if return_mm_hashes:
model_id = self.info.model_id
mm_hashes = {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_items.items()
}
else:
mm_hashes = None

(
prompt_ids,
mm_kwargs,
is_update_applied,
) = self._cached_apply_hf_processor(
prompt,
mm_items,
hf_processor_mm_kwargs,
)
return {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_items.items()
}

def _maybe_apply_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
prompt_ids: list[int],
mm_kwargs: MultiModalKwargs,
is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
unbound_prompt_updates = self._get_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
@@ -1652,6 +1631,51 @@ def apply(
)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

return prompt_ids, prompt, mm_placeholders

def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
return_mm_hashes: bool = False,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.

The main steps are:

1. Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
2. Find and update sequences in the token IDs with placeholder tokens.
The number of placeholder tokens equals the feature size of the
multi-modal data outputted by the multi-modal encoder.
3. Extract information about the placeholder tokens from the
processed token IDs.
"""
mm_items = self._to_mm_items(mm_data)

mm_hashes = (self._hash_mm_items(mm_items, hf_processor_mm_kwargs)
if return_mm_hashes else None)

(
prompt_ids,
mm_kwargs,
is_update_applied,
) = self._cached_apply_hf_processor(
prompt,
mm_items,
hf_processor_mm_kwargs,
)

prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
prompt_ids=prompt_ids,
mm_kwargs=mm_kwargs,
is_update_applied=is_update_applied,
)

mm_placeholder_ranges = {
modality: [item.to_range() for item in placeholders]
for modality, placeholders in mm_placeholders.items()
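Taken together, the refactored `apply` in `processing.py` now reads as three steps: hash the parsed multi-modal items (only needed for V1), run the cached HF processor, then apply prompt updates. The sketch below mimics that flow in a self-contained way; it uses `hashlib` in place of `MultiModalHasher` and trivial stand-ins for the other steps, so every name in it is illustrative rather than part of the vLLM API:

```python
import hashlib
from typing import Any, Mapping


def hash_mm_items(model_id: str, mm_items: Mapping[str, list[Any]],
                  hf_processor_mm_kwargs: Mapping[str, object]) -> dict[str, list[str]]:
    """Rough analogue of `_hash_mm_items`: one hash per item, per modality."""

    def _hash(item: Any) -> str:
        # Mix in the model ID and processor kwargs so the key is unique per config.
        payload = repr((model_id, item, sorted(hf_processor_mm_kwargs.items())))
        return hashlib.sha256(payload.encode()).hexdigest()

    return {
        modality: [_hash(item) for item in items]
        for modality, items in mm_items.items()
    }


def apply(prompt: str, mm_items: Mapping[str, list[Any]],
          hf_processor_mm_kwargs: Mapping[str, object],
          return_mm_hashes: bool = False) -> dict[str, Any]:
    # Step 1: hashes are only needed by the V1 engine.
    mm_hashes = (hash_mm_items("toy-model", mm_items, hf_processor_mm_kwargs)
                 if return_mm_hashes else None)

    # Step 2: stand-in for `_cached_apply_hf_processor` (naive tokenization).
    prompt_ids = [ord(c) for c in prompt]

    # Step 3: stand-in for `_maybe_apply_prompt_updates` (insert one
    # placeholder token per multi-modal item).
    num_items = sum(len(items) for items in mm_items.values())
    prompt_ids = [0] * num_items + prompt_ids

    return {"prompt_token_ids": prompt_ids, "mm_hashes": mm_hashes}


print(apply("hi", {"image": ["img0"]}, {}, return_mm_hashes=True))
```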