Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix mm_limits access for merged multi-modal processor #12252

Merged
merged 1 commit on Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[Bugfix] Fix mm_limits access for merged multi-modal processor
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
  • Loading branch information
DarkLight1337 committed Jan 21, 2025
commit 1841279462efa19af95b4478087a2b4371bdf04c
4 changes: 2 additions & 2 deletions vllm/multimodal/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def processing_info(self) -> BaseProcessingInfo:
def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
return self.processor.dummy_inputs

def _get_mm_limits(self) -> Mapping[str, int]:
def get_mm_limits(self) -> Mapping[str, int]:
mm_config = self.processing_info.ctx.get_mm_config()
mm_limit_per_prompt = mm_config.limit_per_prompt

Expand Down Expand Up @@ -146,7 +146,7 @@ def get_dummy_data(self, seq_len: int) -> DummyData:
# Avoid circular import
from vllm.sequence import SequenceData

mm_counts = self._get_mm_limits()
mm_counts = self.get_mm_limits()

info = self.processing_info
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len)
Expand Down
19 changes: 14 additions & 5 deletions vllm/multimodal/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
ProcessingCache)
from .profiling import BaseDummyInputsBuilder
from .profiling import BaseDummyInputsBuilder, MultiModalProfiler
from .utils import cached_get_tokenizer
from .video import VideoPlugin

Expand Down Expand Up @@ -282,13 +282,13 @@ def get_max_tokens_per_item_by_nonzero_modality(
This is currently directly used only in V1 for profiling the memory
usage of a model.
"""
limits_per_plugin = self._limits_by_model[model_config]
mm_limits = self.get_mm_limits_per_prompt(model_config)

return {
key: max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
if limits_per_plugin[key] > 0
if mm_limits[key] > 0
}

def get_max_tokens_by_modality(
Expand All @@ -304,10 +304,10 @@ def get_max_tokens_by_modality(
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
limits_per_plugin = self._limits_by_model[model_config]
mm_limits = self.get_mm_limits_per_prompt(model_config)

return {
key: limits_per_plugin[key] * max_tokens_per_mm_item
key: mm_limits[key] * max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
}
Expand Down Expand Up @@ -371,6 +371,15 @@ def get_mm_limits_per_prompt(
Note:
This should be called after :meth:`init_mm_limits_per_prompt`.
"""
if self.has_processor(model_config):
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
processor = self.create_processor(model_config, tokenizer)
profiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()

return self._limits_by_model[model_config]

def register_processor(
Expand Down
Loading