
[Bugfix] Clean up multi-modal processors #14417


Merged: 4 commits, Mar 7, 2025
9 changes: 9 additions & 0 deletions vllm/config.py
@@ -2405,6 +2405,15 @@ def compute_hash(self) -> str:
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str

def get_limit_per_prompt(self, modality: str) -> int:
"""
Get the maximum number of input items allowed per prompt
for the given modality.

If not set by the user, this defaults to `1`.
"""
return self.limit_per_prompt.get(modality, 1)

# TODO: Add configs to init vision tower or not.


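For reference, a standalone sketch of the new helper's behavior. The stand-in class below is hypothetical; only the method body and the default of 1 mirror the diff above.

# Standalone sketch; _MMConfigSketch is illustrative, not vLLM's MultiModalConfig.
class _MMConfigSketch:
    def __init__(self, limit_per_prompt: dict[str, int]) -> None:
        self.limit_per_prompt = limit_per_prompt

    def get_limit_per_prompt(self, modality: str) -> int:
        # Defaults to 1 when the user did not configure the modality.
        return self.limit_per_prompt.get(modality, 1)

cfg = _MMConfigSketch({"image": 4})            # e.g. configured via --limit-mm-per-prompt
assert cfg.get_limit_per_prompt("image") == 4  # explicitly configured
assert cfg.get_limit_per_prompt("video") == 1  # unconfigured, so the default applies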
56 changes: 27 additions & 29 deletions vllm/model_executor/models/deepseek_vl2.py
@@ -14,7 +14,6 @@
from transformers import BatchFeature

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -25,8 +24,8 @@
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
@@ -42,8 +41,6 @@
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)

logger = init_logger(__name__)

# The image token id may be various
_IMAGE_TOKEN = "<image>"

@@ -216,30 +213,6 @@ def get_dummy_processor_inputs(
class DeepseekVL2MultiModalProcessor(
BaseMultiModalProcessor[DeepseekVL2ProcessingInfo]):

def __init__(
self,
info: DeepseekVL2ProcessingInfo,
dummy_inputs: "BaseDummyInputsBuilder[DeepseekVL2ProcessingInfo]",
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(
info,
dummy_inputs,
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)

mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
if self.cache is not None and mm_limit["image"] > 2:
# The processor output depends on the number of images passed,
# making it incompatible with processing cache which is supposed
# to be invariant of how many images are passed per prompt
self.cache = None
logger.warning_once(
f"{type(self).__name__} does not support processing cache with "
"image limit larger than 2.")

def _call_hf_processor(
self,
prompt: str,
@@ -316,6 +289,31 @@ def get_replacement_deepseek_vl2(item_idx: int):
)
]

def _cached_apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 2:
# This code path corresponds to the cache being disabled
return self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
)

return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)


@MULTIMODAL_REGISTRY.register_processor(
DeepseekVL2MultiModalProcessor,
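The override above decides per request whether the processing cache may be used at all. A standalone sketch of that decision follows; the function name and threshold parameter are illustrative, while the real method delegates to `_apply_hf_processor_main` or the superclass as shown in the diff.

# Illustrative only; not the vLLM implementation.
def uses_processing_cache(num_images: int, max_cacheable_images: int = 2) -> bool:
    # DeepSeek-VL2's HF processor output depends on the image count once more
    # than `max_cacheable_images` images are passed, so per-item caching is
    # only safe at or below that threshold.
    return num_images <= max_cacheable_images

assert uses_processing_cache(2)      # common case: the cached path is taken
assert not uses_processing_cache(3)  # processed without the cache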
58 changes: 28 additions & 30 deletions vllm/model_executor/models/h2ovl.py
@@ -8,21 +8,19 @@
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Mapping, Sequence
from typing import Optional
from typing import Optional, Union

import torch
from PIL import Image
from transformers import PretrainedConfig

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .intern_vit import InternVisionModel
@@ -32,8 +30,6 @@
InternVLMultiModalProcessor, build_transform,
find_closest_aspect_ratio, get_internvl_target_ratios)

logger = init_logger(__name__)


def resolve_h2ovl_min_max_num(
*,
@@ -465,29 +461,6 @@ def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
):

def __init__(self,
info: H2OVLProcessingInfo,
dummy_inputs: "BaseDummyInputsBuilder[H2OVLProcessingInfo]",
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(
info,
dummy_inputs,
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)

mm_limit = self.info.ctx.model_config.multimodal_config.limit_per_prompt
if self.cache is not None and mm_limit["image"] >= 2:
# The processor output depends on the number of images passed,
# making it incompatible with processing cache which is supposed
# to be invariant of how many images are passed per prompt
self.cache = None
logger.warning_once(
f"{type(self).__name__} does not support processing cache with "
"multi-image support enabled.")

def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
@@ -543,6 +516,31 @@ def get_replacement_internvl(item_idx: int):
)
]

def _cached_apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 1:
# This code path corresponds to the cache being disabled
return self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
)

return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)


@MULTIMODAL_REGISTRY.register_processor(
H2OVLMultiModalProcessor,
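H2OVL applies the same pattern with a threshold of one image. A minimal, hypothetical sketch of the dispatch shape (bypass the cached path when its invariants do not hold, otherwise defer to the base class); the class names below are illustrative, not the vLLM API.

# Illustrative dispatch pattern only; not the vLLM classes.
class _CachedProcessorSketch:
    def apply(self, num_images: int) -> str:
        return "cached"

class _SingleImageCachedProcessorSketch(_CachedProcessorSketch):
    def apply(self, num_images: int) -> str:
        if num_images > 1:
            # Output depends on the image count, so skip the cache.
            return "uncached"
        return super().apply(num_images)

proc = _SingleImageCachedProcessorSketch()
assert proc.apply(1) == "cached"
assert proc.apply(2) == "uncached"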
2 changes: 1 addition & 1 deletion vllm/model_executor/models/llava_next_video.py
@@ -133,7 +133,7 @@ def _get_max_video_frames(self, max_tokens: int) -> int:

def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_videos = mm_config.get_limit_per_prompt("video")

max_total_frames = self._get_max_video_frames(seq_len)

4 changes: 2 additions & 2 deletions vllm/model_executor/models/llava_onevision.py
@@ -206,8 +206,8 @@ def _get_max_video_frames(self, max_tokens: int) -> int:

def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.get_limit_per_prompt("video")

max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
6 changes: 3 additions & 3 deletions vllm/model_executor/models/minicpmo.py
@@ -201,9 +201,9 @@ def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:

def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_audios = mm_config.limit_per_prompt.get("audio", 1)
max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.get_limit_per_prompt("video")
max_audios = mm_config.get_limit_per_prompt("audio")

# count <image_idx></image_idx> tokens
# which are not in get_max_image_tokens
4 changes: 2 additions & 2 deletions vllm/model_executor/models/minicpmv.py
@@ -446,8 +446,8 @@ def get_max_video_frames(self, max_tokens: int) -> int:

def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.get_limit_per_prompt("video")

# count <image_idx></image_idx> tokens
# which are not in get_max_image_tokens
2 changes: 1 addition & 1 deletion vllm/model_executor/models/pixtral.py
@@ -68,7 +68,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
image_token_id = mm_encoder.special_ids.img

mm_config = ctx.get_mm_config()
num_images = mm_config.limit_per_prompt.get("image", 1)
num_images = mm_config.get_limit_per_prompt("image")

# dummy size
size = 256
4 changes: 2 additions & 2 deletions vllm/model_executor/models/qwen2_vl.py
@@ -911,8 +911,8 @@ def _get_max_video_frames(self, max_tokens: int) -> int:

def get_num_frames_with_most_features(self, seq_len: int) -> int:
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
max_videos = mm_config.limit_per_prompt.get("video", 1)
max_images = mm_config.get_limit_per_prompt("image")
max_videos = mm_config.get_limit_per_prompt("video")

max_image_tokens = self.get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
4 changes: 2 additions & 2 deletions vllm/multimodal/processing.py
@@ -984,10 +984,10 @@ def _to_mm_items(
before passing them to :meth:`_get_hf_mm_data`.
"""
mm_items = self.data_parser.parse_mm_data(mm_data)
mm_config = self.info.ctx.get_mm_config()

mm_limits = self.info.ctx.get_mm_config().limit_per_prompt
for modality, items in mm_items.items():
limit = mm_limits.get(modality, 1)
limit = mm_config.get_limit_per_prompt(modality)
if len(items) > limit:
raise ValueError(
f"You set {modality}={limit} (or defaulted to 1) in "
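A self-contained sketch of the per-modality check that `_to_mm_items` performs with the new helper. The limits, items, and abbreviated error text below are hypothetical; only the comparison against a per-modality limit mirrors the diff.

# Standalone sketch; not the vLLM code path.
limits = {"image": 2}                                    # per-modality limits from the config
mm_items = {"image": ["img0", "img1", "img2"], "audio": ["clip0"]}

try:
    for modality, items in mm_items.items():
        limit = limits.get(modality, 1)                  # default of 1 when unconfigured
        if len(items) > limit:
            raise ValueError(
                f"You set {modality}={limit} (or defaulted to 1), "
                f"but passed {len(items)} {modality} items in one prompt.")
except ValueError as exc:
    print(exc)  # image=2 ... but passed 3 image items in one prompt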
4 changes: 1 addition & 3 deletions vllm/multimodal/profiling.py
@@ -110,12 +110,10 @@ def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:

def get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.processing_info.ctx.get_mm_config()
mm_limit_per_prompt = mm_config.limit_per_prompt

supported_mm_limits = self.processing_info.get_supported_mm_limits()

mm_limits = {
modality: mm_limit_per_prompt.get(modality, 1)
modality: mm_config.get_limit_per_prompt(modality)
for modality in supported_mm_limits
}

2 changes: 1 addition & 1 deletion vllm/multimodal/registry.py
@@ -355,7 +355,7 @@ def init_mm_limits_per_prompt(
# TODO: Automatically determine the limits based on budget
# once more models support multi-image inputs
limits_per_plugin = {
key: config_limits_per_plugin.get(key, 1)
key: multimodal_config.get_limit_per_prompt(key)
for key in self._plugins
}
