[Misc] Clean up multi-modal processor (vllm-project#11207)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
DarkLight1337 authored Dec 15, 2024
1 parent a1c0205 commit b10609e
Showing 3 changed files with 32 additions and 38 deletions.
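In short: the multi-modal prompt-replacement helpers only ever needed the number of items per modality, so the full MultiModalDataItems container in their signatures is replaced by a plain Mapping[str, int] of per-modality counts, which callers build once. A minimal standalone sketch of the old and new calling conventions (the toy data is illustrative, not from the repository):

    from collections.abc import Mapping

    # Before: callers wrapped per-modality item lists in MultiModalDataItems
    # just so the helpers could call len() on each list.
    mm_items = {"image": ["img0", "img1"], "audio": ["clip0"]}

    # After: callers derive the counts once and pass the mapping directly
    # (this mirrors the line added to `apply` in the last file below).
    mm_item_counts: Mapping[str, int] = {
        m: len(items)
        for m, items in mm_items.items()
    }

    assert mm_item_counts == {"image": 2, "audio": 1}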
5 changes: 1 addition & 4 deletions examples/offline_inference_vision_language.py
@@ -92,10 +92,7 @@ def run_fuyu(question: str, modality: str):
 def run_phi3v(question: str, modality: str):
     assert modality == "image"
 
-    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n" # noqa: E501
-    # Note: The default setting of max_num_seqs (256) and
-    # max_model_len (128k) for this model may cause OOM.
-    # You may lower either to run this example on lower-end GPUs.
+    prompt = f"<|user|>\n<|image_1|>\n{question}<|end|>\n<|assistant|>\n"
 
     # num_crops is an override kwarg to the multimodal image processor;
     # For some models, e.g., Phi-3.5-vision-instruct, it is recommended
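For context, the retained comment refers to num_crops, which can be overridden per model via the engine's multi-modal processor kwargs. Below is a sketch of how such an example typically constructs the engine for this model; the exact argument values are assumptions, not part of this diff:

    from vllm import LLM

    # Sketch only: argument values are assumptions, not part of this diff.
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        trust_remote_code=True,
        max_model_len=4096,  # well below the 128k default, to avoid OOM
        max_num_seqs=2,      # lowered from the 256 default for small GPUs
        # num_crops is an override kwarg to the multimodal image processor
        mm_processor_kwargs={"num_crops": 16},
    )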
17 changes: 8 additions & 9 deletions tests/multimodal/test_processing.py
@@ -2,10 +2,9 @@
 
 import pytest
 
-from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
-                                        _PlaceholderInfo, find_text_matches,
-                                        find_token_matches, iter_placeholders,
-                                        iter_token_matches,
+from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
+                                        find_text_matches, find_token_matches,
+                                        iter_placeholders, iter_token_matches,
                                         replace_text_matches,
                                         replace_token_matches)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -314,8 +313,8 @@ def test_find_replace_text(
     result = replace_text_matches(
         prompt,
         matches,
-        MultiModalDataItems({key: [None] * mm_count
-                             for key in repl_by_key}),
+        {key: mm_count
+         for key in repl_by_key},
     )
 
     # Only displayed on error
@@ -380,8 +379,8 @@ def test_find_replace_tokens(
     result = replace_token_matches(
         prompt,
         matches,
-        MultiModalDataItems({key: [None] * mm_count
-                             for key in repl_by_key}),
+        {key: mm_count
+         for key in repl_by_key},
     )
 
     # Only displayed on error
@@ -476,7 +475,7 @@ def test_iter_placeholders(
            prompt_repls,
            prompt,
            # Effectively match all occurrences in the prompt
-           MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
+           {key: 3 for key in repl_by_key},
        ))
 
     # Only displayed on error
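Every test update above follows the same pattern: a fixture that previously built MultiModalDataItems({key: [None] * n for key in ...}) now passes {key: n for key in ...} directly. A self-contained illustration (repl_by_key here is a stand-in, not the real fixture):

    # Stand-in replacement table; the real tests bind token/text replacements.
    repl_by_key = {"image": [32000], "pattern_2": [32001]}
    mm_count = 2

    # Old argument: MultiModalDataItems({key: [None] * mm_count
    #                                    for key in repl_by_key})
    # New argument:
    mm_item_counts = {key: mm_count for key in repl_by_key}

    assert mm_item_counts == {"image": 2, "pattern_2": 2}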
48 changes: 23 additions & 25 deletions vllm/multimodal/processing.py
@@ -403,18 +403,17 @@ def _resolve_matches(
 def _replace_matches(
     prompt: _S,
     matches: Sequence[_PromptReplacementMatch],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
 ) -> list[_S]:
     out_seqs = list[_S]()
     prev_end_idx = 0
-    next_idx_by_modality = {modality: 0 for modality in mm_items}
+    next_idx_by_modality = {modality: 0 for modality in mm_item_counts}
 
     for match in _resolve_matches(prompt, matches):
         modality = match.modality
-        modal_items = mm_items[modality]
 
         item_idx = next_idx_by_modality[modality]
-        if item_idx >= len(modal_items):
+        if item_idx >= mm_item_counts[modality]:
             continue
 
         start_idx = match.start_idx
@@ -441,27 +440,27 @@ def _replace_matches(
 def replace_token_matches(
     prompt: list[int],
     matches: Sequence[_PromptReplacementTokenMatch],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
 ) -> list[int]:
     """Apply :code:`prompt_repls` to :code:`prompt`."""
     if not matches:
         return prompt
 
-    token_id_seqs = _replace_matches(prompt, matches, mm_items)
+    token_id_seqs = _replace_matches(prompt, matches, mm_item_counts)
 
     return flatten_2d_lists(token_id_seqs)
 
 
 def replace_text_matches(
     prompt: str,
     matches: Sequence[_PromptReplacementTextMatch],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
 ) -> str:
     """Apply :code:`prompt_repls` to :code:`prompt`."""
     if not matches:
         return prompt
 
-    texts = _replace_matches(prompt, matches, mm_items)
+    texts = _replace_matches(prompt, matches, mm_item_counts)
 
     return "".join(texts)
 
@@ -470,9 +469,9 @@ def _iter_modality_placeholders(
     prompt: list[int],
     modality: str,
     modality_repls: Sequence[_BoundPromptReplacement],
-    modal_items: list[Any],
+    modal_item_count: int,
 ) -> Iterable[_PlaceholderInfo]:
-    if len(modal_items) == 0:
+    if modal_item_count == 0:
         return
 
     prompt_len = len(prompt)
@@ -499,7 +498,7 @@
             )
 
             item_index += 1
-            if item_index >= len(modal_items):
+            if item_index >= modal_item_count:
                 return
 
             # Exclude overlapping matches
@@ -514,7 +513,7 @@
 def iter_placeholders(
     prompt_repls: Sequence[_BoundPromptReplacement],
     prompt: list[int],
-    mm_items: MultiModalDataItems,
+    mm_item_counts: Mapping[str, int],
 ) -> Iterable[_PlaceholderInfo]:
     """
     Yield each set of placeholder tokens found in :code:`prompt`.
@@ -523,13 +522,13 @@
     """
     repls_by_modality = dict(full_groupby_modality(prompt_repls))
 
-    for modality, modal_items in mm_items.items():
+    for modality, modal_item_count in mm_item_counts.items():
         if modality in repls_by_modality:
             yield from _iter_modality_placeholders(
                 prompt,
                 modality,
                 repls_by_modality[modality],
-                modal_items,
+                modal_item_count,
             )
 
 
@@ -590,10 +589,10 @@ def _find_placeholders(
         self,
         all_prompt_repls: Sequence[_BoundPromptReplacement],
         new_token_ids: list[int],
-        mm_items: MultiModalDataItems,
+        mm_item_counts: Mapping[str, int],
     ) -> list[_PlaceholderInfo]:
         return list(
-            iter_placeholders(all_prompt_repls, new_token_ids, mm_items))
+            iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts))
 
     def _apply_hf_processor(
         self,
@@ -655,10 +654,9 @@ def _bind_prompt_replacements(
 
     def _apply_prompt_replacements(
         self,
-        mm_items: MultiModalDataItems,
-        hf_inputs: BatchFeature,
         token_ids: list[int],
         prompt_repls: Sequence[_BoundPromptReplacement],
+        mm_item_counts: Mapping[str, int],
     ) -> tuple[list[int], str, list[_PlaceholderInfo]]:
         tokenizer = self._get_tokenizer()
 
@@ -675,13 +673,13 @@
         # of the search text in the prompt, we instead perform string
         # replacement on the decoded token IDs, then encode them back.
         if all(
-            len(matches) >= len(mm_items[modality])
+            len(matches) >= mm_item_counts[modality]
            for modality, matches in full_groupby_modality(token_matches)
         ):  # yapf: disable
             token_ids = replace_token_matches(
                 token_ids,
                 token_matches,
-                mm_items,
+                mm_item_counts,
             )
 
             text = _decode(tokenizer, token_ids)
@@ -693,14 +691,14 @@
             text = replace_text_matches(
                 text,
                 text_matches,
-                mm_items,
+                mm_item_counts,
             )
 
             token_ids = _encode(tokenizer, text)
             matched_repls = [match.prompt_repl for match in text_matches]
 
         placeholders = self._find_placeholders(matched_repls, token_ids,
-                                               mm_items)
+                                               mm_item_counts)
 
         return token_ids, text, placeholders
 
@@ -737,8 +735,9 @@ def apply(
 
         # If HF processor already inserts placeholder tokens,
         # there is no need for us to insert them
+        mm_item_counts = {m: len(items) for m, items in mm_items.items()}
         all_placeholders = self._find_placeholders(all_prompt_repls,
-                                                   prompt_ids, mm_items)
+                                                   prompt_ids, mm_item_counts)
 
         if all_placeholders:
             prompt_text = _decode(tokenizer, prompt_ids)
@@ -748,10 +747,9 @@
             prompt_text,
             all_placeholders,
         ) = self._apply_prompt_replacements(
-            mm_items,
-            hf_inputs,
             prompt_ids,
             all_prompt_repls,
+            mm_item_counts,
         )
 
         mm_placeholders = {
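The counting logic introduced above can be seen in miniature below: the scan stops once modal_item_count placeholders have been found, and overlapping matches are excluded by jumping past each hit. This is a simplified standalone sketch, not vLLM's actual _iter_modality_placeholders (which matches bound prompt replacements rather than a literal token subsequence):

    from collections.abc import Iterable


    def iter_placeholder_spans(
        prompt: list[int],
        placeholder: list[int],
        modal_item_count: int,
    ) -> Iterable[tuple[int, int]]:
        """Yield up to modal_item_count (start, end) spans of placeholder."""
        if modal_item_count == 0:
            return

        found = 0
        start_idx = 0
        length = len(placeholder)
        while start_idx <= len(prompt) - length:
            if prompt[start_idx:start_idx + length] == placeholder:
                yield (start_idx, start_idx + length)

                found += 1
                if found >= modal_item_count:
                    return

                # Exclude overlapping matches
                start_idx += length
            else:
                start_idx += 1


    # Two image placeholders [9, 9] in a toy token prompt:
    assert list(iter_placeholder_spans([1, 9, 9, 2, 9, 9], [9, 9], 2)) == [
        (1, 3),
        (4, 6),
    ]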
