
[VLM] Support caching in merged multi-modal processor #11396

Merged: 82 commits, Dec 27, 2024
Commits
faa9b84
Refactor multi-modal processor to support caching
DarkLight1337 Dec 19, 2024
9711a15
Clean up
DarkLight1337 Dec 19, 2024
29e3fcd
Fix cached result being mutated
DarkLight1337 Dec 19, 2024
ab64e85
Rename
DarkLight1337 Dec 19, 2024
81215a2
Fix docs
DarkLight1337 Dec 19, 2024
cf52b3b
Fix a typo
DarkLight1337 Dec 19, 2024
a4a8eb9
Fix unhandled sampling rate in initialization
DarkLight1337 Dec 19, 2024
c48f7c5
format
DarkLight1337 Dec 19, 2024
b84ff42
Change the delimiter
DarkLight1337 Dec 19, 2024
c3f1bde
Fix extra dimension
DarkLight1337 Dec 19, 2024
32e5197
Update
DarkLight1337 Dec 19, 2024
7264d4e
Use the inner processor to enable fine-grained caching
DarkLight1337 Dec 20, 2024
02ea829
Make the cache optional
DarkLight1337 Dec 20, 2024
b981a9d
Fix invalid kwargs being passed to tokenizer
DarkLight1337 Dec 20, 2024
5dde7d0
Fix Phi3V prompt replacement
DarkLight1337 Dec 20, 2024
7339ab8
Refine
DarkLight1337 Dec 20, 2024
509411d
Enable fine-grained caching for audio models
DarkLight1337 Dec 20, 2024
c0454f5
Add fallback
DarkLight1337 Dec 20, 2024
d50ef03
Fix typo
DarkLight1337 Dec 20, 2024
81f7d61
Fix video processor for Qwen2-VL
DarkLight1337 Dec 20, 2024
13eede3
Merge branch 'main' into mm-processor-cache
DarkLight1337 Dec 20, 2024
affbc5c
Fix a bunch of type errors
DarkLight1337 Dec 20, 2024
b4ddfb1
Fix qwen2-vl
DarkLight1337 Dec 20, 2024
4b3db32
Fix
DarkLight1337 Dec 20, 2024
dafbc7f
Simplify Pixtral-HF
DarkLight1337 Dec 21, 2024
38aaff8
Cleanup
DarkLight1337 Dec 21, 2024
5fcb5d6
Fix Pixtral-HF
DarkLight1337 Dec 21, 2024
f86e148
Enable caching outside the processing loop
DarkLight1337 Dec 21, 2024
337f0d2
Make debugging easier
DarkLight1337 Dec 21, 2024
c01d38a
Update
DarkLight1337 Dec 21, 2024
84f02fb
Fix ultravox
DarkLight1337 Dec 21, 2024
9f417c2
Revert some unnecessary changes
DarkLight1337 Dec 21, 2024
00b765b
Merge branch 'main' into mm-fields
DarkLight1337 Dec 22, 2024
2ed431e
Add test and fix some issues
DarkLight1337 Dec 23, 2024
baaf551
Update
DarkLight1337 Dec 23, 2024
f5dbcb8
Fix
DarkLight1337 Dec 23, 2024
afd3f4f
Rework
DarkLight1337 Dec 23, 2024
6172450
Rename the test
DarkLight1337 Dec 23, 2024
416943d
Update count
DarkLight1337 Dec 23, 2024
86f2786
Rename
DarkLight1337 Dec 23, 2024
f5b6214
Some fixes
DarkLight1337 Dec 23, 2024
8a68e87
Cleanup
DarkLight1337 Dec 23, 2024
ab7e84b
Skip unspecified fields
DarkLight1337 Dec 23, 2024
9f2cdaa
Fix equality checking
DarkLight1337 Dec 23, 2024
d11e833
Consolidate common code
DarkLight1337 Dec 23, 2024
5fee280
Improve error message
DarkLight1337 Dec 23, 2024
6182fd6
Cleanup
DarkLight1337 Dec 23, 2024
e1214cf
Fix Pixtral-HF
DarkLight1337 Dec 23, 2024
c717bce
Fix missing mm_count key
DarkLight1337 Dec 23, 2024
023890e
Fix qwen2-vl
DarkLight1337 Dec 23, 2024
b5e5b8a
Fix Qwen2-VL
DarkLight1337 Dec 23, 2024
cf24a1b
Fix Qwen2-VL and Qwen2-Audio
DarkLight1337 Dec 23, 2024
73271e9
Debug Phi3V
DarkLight1337 Dec 23, 2024
e30deec
Consolidate common code
DarkLight1337 Dec 23, 2024
ea6f8b5
Try to fix Phi3V and Ultravox
DarkLight1337 Dec 23, 2024
10ae755
Remove benchmark
DarkLight1337 Dec 23, 2024
85c5e2c
Fix token mismatch in Phi3V and Ultravox
DarkLight1337 Dec 23, 2024
4873ff8
Update max image tokens
DarkLight1337 Dec 23, 2024
4dbb5a3
Strictly check the number of placeholder tokens
DarkLight1337 Dec 23, 2024
6dbae81
Fix doc failure
DarkLight1337 Dec 23, 2024
fb51c9b
Test and fix Mantis processor
DarkLight1337 Dec 24, 2024
91cbd63
Fix embedding inputs
DarkLight1337 Dec 24, 2024
6bee6ba
Update entrypoints tests
DarkLight1337 Dec 24, 2024
cfa2ce8
Merge branch 'main' into mm-fields
DarkLight1337 Dec 24, 2024
fa54292
Clean up
DarkLight1337 Dec 24, 2024
cbf79be
Avoid extra placeholder in phi3v
DarkLight1337 Dec 24, 2024
9cd38b1
Fix OOM
DarkLight1337 Dec 24, 2024
14dcdd5
Fix mantis processor
DarkLight1337 Dec 24, 2024
b8bd2d4
Merge branch 'main' into mm-fields
DarkLight1337 Dec 24, 2024
5045d93
Remove redundant code
DarkLight1337 Dec 24, 2024
4cac998
Still need Mantis repo for testing
DarkLight1337 Dec 24, 2024
e8afd10
Merge branch 'main' into mm-fields
DarkLight1337 Dec 25, 2024
93bba0a
Fix incorrect max image tokens (Updated in #11258)
DarkLight1337 Dec 25, 2024
ea9f888
Also cache by model ID
DarkLight1337 Dec 25, 2024
58747f6
Format
DarkLight1337 Dec 25, 2024
323657a
Update link
DarkLight1337 Dec 25, 2024
695c79e
Merge branch 'main' into mm-fields
DarkLight1337 Dec 26, 2024
c67efda
Address some comments
DarkLight1337 Dec 26, 2024
d4abec7
Move `MultiModalDataItems` to `inputs` module to address more comments
DarkLight1337 Dec 26, 2024
9f4a8be
Add documentation
DarkLight1337 Dec 26, 2024
1d5b56d
Fix circular import
DarkLight1337 Dec 26, 2024
e4c7a14
Update docs
DarkLight1337 Dec 26, 2024
Use the inner processor to enable fine-grained caching
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
DarkLight1337 committed Dec 20, 2024
commit 7264d4e1647d6313727b6a7e4ee8cf1fd2509df7
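In outline, the change below routes the prompt text and each multi-modal item through the matching sub-processor of the HuggingFace processor and caches every item's output individually, so repeated images, videos, or audio clips are only processed once. A minimal standalone sketch of that idea (the helper names and the plain dict cache are illustrative, not vLLM's actual API, which uses an LRU cache keyed by a hash of the item and the processor kwargs):

import pickle
from hashlib import blake2b
from typing import Any, Callable

# Illustrative per-item cache; the real implementation uses an LRUCache.
_item_cache: dict[bytes, Any] = {}

def _item_key(modality: str, item: Any, mm_kwargs: dict) -> bytes:
    """Hash one multi-modal item plus the kwargs it will be processed with."""
    h = blake2b()
    h.update(modality.encode())
    h.update(pickle.dumps(item))  # the real code hashes arrays/tensors directly
    h.update(pickle.dumps(sorted(mm_kwargs.items())))
    return h.digest()

def process_item_cached(sub_processor: Callable[..., Any], modality: str,
                        item: Any, mm_kwargs: dict) -> Any:
    """Run a modality-specific sub-processor on a single item, reusing the
    cached output when the same item was seen before."""
    key = _item_key(modality, item, mm_kwargs)
    if key not in _item_cache:
        _item_cache[key] = sub_processor(**{modality: [item]}, **mm_kwargs)
    return _item_cache[key]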
41 changes: 37 additions & 4 deletions vllm/inputs/registry.py
@@ -1,7 +1,7 @@
 import functools
 from collections import UserDict
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Any, Callable, Mapping, NamedTuple,
+from typing import (TYPE_CHECKING, Any, Callable, Literal, Mapping, NamedTuple,
                     Optional, Protocol, Union)
 
 from torch import nn
@@ -111,6 +111,39 @@ def get_hf_processor(
 
         return hf_processor
 
+    def get_modality_processor(
+        self,
+        hf_processor: ProcessorMixin,
+        modality_data_key: Literal["text", "images", "videos", "audios"],
+    ) -> Callable[..., BatchFeature]:
+        """
+        Get the HuggingFace modality-specific processor which is
+        a child of a :class:`transformers.ProcessorMixin`, identified by
+        the corresponding keyword argument in its `__call__` method.
+        """
+        if modality_data_key == "text":
+            attributes = ["tokenizer"]
+        elif modality_data_key == "images":
+            attributes = ["image_processor"]
+        elif modality_data_key == "videos":
+            attributes = ["video_processor"]
+        elif modality_data_key == "audios":
+            attributes = ["audio_processor", "feature_extractor"]
+        else:
+            assert_never(modality_data_key)
+
+        modality_processor = next(
+            (getattr(hf_processor, attr)
+             for attr in attributes if hasattr(hf_processor, attr)),
+            None,
+        )
+        if modality_processor is None:
+            raise AttributeError(
+                f"Cannot found HuggingFace processor for "
+                f"{modality_data_key} inside {type(hf_processor)}")
+
+        return modality_processor
+
 
 @dataclass(frozen=True)
 class InputProcessingContext(InputContext):
@@ -131,15 +164,15 @@ def get_hf_processor(
 
     def call_hf_processor(
         self,
-        hf_processor: ProcessorMixin,
+        hf_processor: Union[ProcessorMixin, Callable[..., BatchFeature]],
         data: Mapping[str, object],
         kwargs: Optional[Mapping[str, object]] = None,
     ) -> BatchFeature:
-        assert callable(hf_processor)
-
         if kwargs is None:
             kwargs = {}
 
+        assert callable(hf_processor)
+
         base_kwargs = self.model_config.mm_processor_kwargs
         if base_kwargs is None:
            base_kwargs = {}
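The lookup performed by get_modality_processor above can be restated on its own: each modality keyword maps to a known attribute of the composite HuggingFace processor (for example, LlavaProcessor exposes a tokenizer and an image_processor). A simplified, self-contained sketch of that dispatch (the function name and error message below are illustrative):

from typing import Callable, Literal

# Which ProcessorMixin attributes may serve each modality keyword.
_MODALITY_ATTRS = {
    "text": ["tokenizer"],
    "images": ["image_processor"],
    "videos": ["video_processor"],
    "audios": ["audio_processor", "feature_extractor"],
}

def resolve_modality_processor(
    hf_processor: object,
    modality_data_key: Literal["text", "images", "videos", "audios"],
) -> Callable:
    """Return the sub-processor that handles the given modality keyword."""
    for attr in _MODALITY_ATTRS[modality_data_key]:
        if hasattr(hf_processor, attr):
            return getattr(hf_processor, attr)
    raise AttributeError(f"No sub-processor for {modality_data_key!r} "
                         f"inside {type(hf_processor)}")

Because the returned object is itself callable (a tokenizer, image processor, or feature extractor), call_hf_processor can accept either the full ProcessorMixin or one of these sub-processors, which is what the widened Union type above allows.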
106 changes: 53 additions & 53 deletions vllm/multimodal/processing.py
@@ -4,7 +4,8 @@
 from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
 from dataclasses import dataclass, field
 from functools import lru_cache, partial
-from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union, cast
+from typing import (Any, Literal, NamedTuple, Optional, Protocol, TypeVar,
+                    Union, cast)
 
 import numpy as np
 import torch
@@ -616,8 +617,8 @@ def maybe_log_cache_stats(self, cache: LRUCache, name: str) -> None:
     def _iter_bytes_to_hash(self, key: str, obj: object) -> Iterable[bytes]:
         # Recursive cases
         if isinstance(obj, (list, tuple)):
-            for elem in obj:
-                yield from self._iter_bytes_to_hash(key, elem)
+            for i, elem in enumerate(obj):
+                yield from self._iter_bytes_to_hash(f"{key}.{i}", elem)
             return
         if isinstance(obj, dict):
            for k, v in obj.items():
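The hunk above folds each element's list index into the key that is fed to the recursive hash, presumably so that differently grouped or nested lists cannot collapse to the same key/value byte stream. A toy version of the recursion (the dict branch is cut off above, so its key scheme here is an assumption):

from collections.abc import Iterable

def iter_key_item_pairs(key: str, obj: object) -> Iterable[tuple[str, object]]:
    """Yield (key, leaf) pairs, tagging list elements with their index."""
    if isinstance(obj, (list, tuple)):
        for i, elem in enumerate(obj):
            yield from iter_key_item_pairs(f"{key}.{i}", elem)
    elif isinstance(obj, dict):
        for k, v in obj.items():  # dict key scheme assumed for this sketch
            yield from iter_key_item_pairs(f"{key}.{k}", v)
    else:
        yield (key, obj)

# [('images.0', 'img_a'), ('images.1', 'img_b')]
print(list(iter_key_item_pairs("images", ["img_a", "img_b"])))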
@@ -664,66 +665,64 @@ def _cached_call_fine(
         self,
         ctx: InputProcessingContext,
         hf_processor: ProcessorMixin,
-        prompt: str,
-        mm_data: Mapping[str, list[object]],
+        text: str,
+        mm_data: Mapping[Literal["images", "videos", "audios"], list[Any]],
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        processed_mm_items = defaultdict[str, list[torch.Tensor]]()
-
-        num_items = len(next(iter(mm_data.values())))
-        for idx in range(num_items):
-            mm_item = {k: [v[idx]] for k, v in mm_data.items()}
-
-            self.maybe_log_cache_stats(self._fine_mm_cache, "fine_mm_cache")
-
-            processed_mm_item = self._fine_mm_cache.get_or_put(
-                self._hash_kwargs(**mm_item, **mm_kwargs),
-                default_factory=partial(
-                    ctx.call_hf_processor,
-                    hf_processor,
-                    mm_item,
-                    mm_kwargs,
-                ),
-            )
-
-            for k, v in processed_mm_item.items():
-                # Remove the extra batch dimension
-                processed_mm_items[k].append(v[0])
-
-        # NOTE: Some processors (e.g. llava) do not accept mm-only input,
-        # in which case we have to fallback to processing `prompt` and `mm_data`
-        # together. Therefore, we place the text processing last to avoid
-        # redundant computation
         self.maybe_log_cache_stats(self._fine_text_cache, "fine_text_cache")
 
         processed_text = self._fine_text_cache.get_or_put(
-            prompt,
+            text,
             default_factory=partial(
                 ctx.call_hf_processor,
-                hf_processor,
-                dict(text=prompt),
+                ctx.get_modality_processor(hf_processor, "text"),
+                dict(text=text),
             ),
         )
 
-        processed_data = dict(**processed_text, **processed_mm_items)
+        processed_data = dict(**processed_text)
+        for data_key, items in mm_data.items():
+            processed_modal_items = defaultdict[str, list[torch.Tensor]](list)
+
+            for item in items:
+                self.maybe_log_cache_stats(self._fine_mm_cache,
+                                           "fine_mm_cache")
+
+                modal_item = cast(Mapping[str, object], {data_key: item})
+                processed_modal_item = self._fine_mm_cache.get_or_put(
+                    self._hash_kwargs(**modal_item, **mm_kwargs),
+                    default_factory=partial(
+                        ctx.call_hf_processor,
+                        ctx.get_modality_processor(hf_processor, data_key),
+                        modal_item,
+                        mm_kwargs,
+                    ),
+                )
+
+                for k, v in processed_modal_item.items():
+                    # Remove the extra batch dimension
+                    processed_modal_items[k].append(v[0])
+
+            processed_data.update(processed_modal_items)
+
         return BatchFeature(processed_data)
 
     def _cached_call_coarse(
         self,
         ctx: InputProcessingContext,
         hf_processor: ProcessorMixin,
-        prompt: str,
+        text: str,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         self.maybe_log_cache_stats(self._coarse_cache, "coarse_cache")
 
         processed_data = self._coarse_cache.get_or_put(
-            self._hash_kwargs(text=prompt, **mm_data, **mm_kwargs),
+            self._hash_kwargs(text=text, **mm_data, **mm_kwargs),
             default_factory=partial(
                 ctx.call_hf_processor,
                 hf_processor,
-                dict(text=prompt, **mm_data),
+                dict(text=text, **mm_data),
                 mm_kwargs,
             ),
         )
@@ -737,34 +736,35 @@ def call_hf_processor(
         ctx: InputProcessingContext,
         # Assumes that hf_processor has been initialized according to kwargs
         hf_processor: ProcessorMixin,
-        prompt: str,
+        text: str,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         # Try to cache each item separately to improve hit rate
-        if mm_data and all(isinstance(v, list) for v in mm_data.values()):
+        extra_keys = mm_data.keys() - {"images", "videos", "audios"}
+        if (mm_data and not extra_keys
+                and all(isinstance(v, list) for v in mm_data.values())):
             try:
                 return self._cached_call_fine(
                     ctx,
                     hf_processor,
-                    prompt,
-                    cast(Mapping[str, list[object]], mm_data),
-                    mm_kwargs,
+                    text=text,
+                    mm_data=mm_data,  # type: ignore[arg-type]
+                    mm_kwargs=mm_kwargs,
                 )
             except Exception:
-                # Failures are expected; see NOTE in `_cached_call_fine`
-                logger.debug(
-                    "Failed to apply processor on each item separately",
+                logger.exception(
+                    "Failed to apply processor on each item separately! "
+                    "Falling back to coarse caching.",
                     stack_info=True,
                 )
-                pass
 
         return self._cached_call_coarse(
             ctx,
             hf_processor,
-            prompt,
-            mm_data,
-            mm_kwargs,
+            text=text,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
         )
 
 
@@ -872,9 +872,9 @@ def _call_hf_processor(
         return self.cache.call_hf_processor(
             self.ctx,
             self._get_hf_processor(**mm_kwargs),
-            prompt,
-            mm_data,
-            mm_kwargs,
+            text=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
         )
 
     def _apply_hf_processor(
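Taken together, the cached entry point now works like this: when the request's multi-modal data contains only list-valued "images"/"videos"/"audios" entries, it first tries the fine-grained path (text through the tokenizer, each item through its modality sub-processor, every piece cached separately); if that raises, it logs the failure and falls back to a single coarse cache entry for the combined call. A compressed sketch of that control flow, with cached_call_fine and cached_call_coarse standing in for the private methods shown above:

from collections.abc import Mapping
from typing import Any, Callable

def call_with_cache(cached_call_fine: Callable[..., Any],
                    cached_call_coarse: Callable[..., Any],
                    text: str,
                    mm_data: Mapping[str, Any],
                    mm_kwargs: Mapping[str, Any]) -> Any:
    """Sketch of the fine-grained-first, coarse-fallback dispatch."""
    extra_keys = mm_data.keys() - {"images", "videos", "audios"}
    if mm_data and not extra_keys and all(
            isinstance(v, list) for v in mm_data.values()):
        try:
            # Cache the text and each image/video/audio item separately.
            return cached_call_fine(text, mm_data, mm_kwargs)
        except Exception:
            # Some HF processors reject modality-only input; fall back to
            # caching the whole combined call as one entry.
            pass
    return cached_call_coarse(text, mm_data, mm_kwargs)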