
[Bugfix] Fix profiling OOM and decouple encoder multimodal profiling #14361

Merged
merged 5 commits on Mar 8, 2025
tests/multimodal/test_processing.py (2 changes: 1 addition & 1 deletion)
@@ -873,7 +873,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
     exc_ctx = pytest.raises(ValueError, match="this model only supports")
 
     with exc_ctx:
-        profiler.get_dummy_data(model_config.max_model_len)
+        profiler.get_decoder_dummy_data(model_config.max_model_len)
 
 
 @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
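
Note (not part of the diff): the test above uses exc_ctx as a conditional expectation, i.e. a no-op context when the configuration is valid and pytest.raises(ValueError, ...) when it is not. A minimal, self-contained sketch of that pattern; check_limit and run_case are hypothetical stand-ins for the profiler call and the parametrized test body.

from contextlib import nullcontext

import pytest


def check_limit(limit: int, num_supported: int) -> None:
    # Stand-in for the profiler call that rejects over-limit multimodal items.
    if limit > num_supported:
        raise ValueError("this model only supports fewer items per prompt")


def run_case(limit: int, num_supported: int, is_valid: bool) -> None:
    # Expect success for valid configs, a matching ValueError otherwise.
    exc_ctx = (nullcontext() if is_valid else
               pytest.raises(ValueError, match="this model only supports"))
    with exc_ctx:
        check_limit(limit, num_supported)


run_case(limit=1, num_supported=2, is_valid=True)   # no exception expected
run_case(limit=3, num_supported=2, is_valid=False)  # ValueError raised and matched
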
vllm/inputs/registry.py (6 changes: 4 additions & 2 deletions)
@@ -335,8 +335,10 @@ def dummy_data_for_profiling(
                                                      tokenizer,
                                                      disable_cache=True)
             profiler = MultiModalProfiler(processor)
-            dummy_data = profiler.get_dummy_data(
-                seq_len, is_encoder_data=is_encoder_data)
+            dummy_data_factory = (profiler.get_encoder_dummy_data
+                                  if is_encoder_data else
+                                  profiler.get_decoder_dummy_data)
+            dummy_data = dummy_data_factory(seq_len)
         else:
             model_cls, _ = get_model_architecture(model_config)
             if is_encoder_data:
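
Note (not part of the diff): the registry change above replaces the single get_dummy_data(seq_len, is_encoder_data=...) call with a choice between two bound methods that are then called the same way. A minimal sketch of that factory-selection pattern; the Profiler class here is a stand-in, not vLLM's MultiModalProfiler.

class Profiler:
    def get_encoder_dummy_data(self, seq_len: int) -> list[int]:
        # Encoder-side dummy prompt (stand-in behavior).
        return [0] * seq_len

    def get_decoder_dummy_data(self, seq_len: int) -> list[int]:
        # Decoder-side dummy prompt (stand-in behavior).
        return [1] * seq_len


def dummy_data_for_profiling(profiler: Profiler, seq_len: int,
                             is_encoder_data: bool) -> list[int]:
    # Pick the bound method first, then call it with the same argument,
    # mirroring the structure of the updated registry code.
    dummy_data_factory = (profiler.get_encoder_dummy_data
                          if is_encoder_data else
                          profiler.get_decoder_dummy_data)
    return dummy_data_factory(seq_len)


print(dummy_data_for_profiling(Profiler(), 4, is_encoder_data=True))   # [0, 0, 0, 0]
print(dummy_data_for_profiling(Profiler(), 4, is_encoder_data=False))  # [1, 1, 1, 1]
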
vllm/multimodal/profiling.py (84 changes: 54 additions & 30 deletions)
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, cast
 
 import numpy as np
 import numpy.typing as npt
@@ -13,7 +13,8 @@
 from vllm.inputs import DummyData
 from vllm.logger import init_logger
 
-from .inputs import MultiModalDataDict, MultiModalInputs
+from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
+                     MultiModalInputs)
 from .processing import BaseMultiModalProcessor, BaseProcessingInfo
 
 logger = init_logger(__name__)
@@ -144,14 +145,10 @@ def _get_dummy_mm_inputs(
             hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         )
 
-    def get_dummy_data(
+    def get_and_validate_mm_inputs(
         self,
         seq_len: int,
-        is_encoder_data: bool = False,
-    ) -> DummyData:
-        # Avoid circular import
-        from vllm.sequence import SequenceData
-
+    ) -> tuple[MultiModalInputs, Mapping[str, int]]:
         mm_counts = self.get_mm_limits()
 
         info = self.processing_info
@@ -167,11 +164,6 @@ def get_dummy_data(
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
         placeholders_by_modality = mm_inputs["mm_placeholders"]
-        # For encoder-decoder models, use encoder prompt token ids instead of
-        # decoder prompt to construct dummy seq_data for encoder profiling.
-        prompt_token_ids = (
-            mm_inputs["prompt_token_ids"] if not is_encoder_data else
-            mm_inputs["encoder_prompt_token_ids"])  # type: ignore
 
         total_placeholders_by_modality = {
             modality: sum(item["length"] for item in placeholders)
@@ -187,28 +179,60 @@ def get_dummy_data(
                 f"{total_placeholders_by_modality} placeholder tokens, which "
                 f"is not the expected {expected_placeholders_by_modality} "
                 "tokens.")
+        return mm_inputs, total_placeholders_by_modality
 
+    def get_encoder_dummy_data(
+        self,
+        seq_len: int,
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        mm_inputs, _ = self.get_and_validate_mm_inputs(seq_len)
+        mm_inputs = cast(MultiModalEncDecInputs, mm_inputs)
+
+        # For encoder-decoder models, use encoder prompt token ids instead of
+        # decoder prompt to construct dummy seq_data for encoder profiling.
+        encoder_prompt_token_ids = mm_inputs["encoder_prompt_token_ids"]
+
+        total_len = len(encoder_prompt_token_ids)
+        num_tokens_to_pad = max(total_len, seq_len) - total_len
+        encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
+
+        return DummyData(
+            seq_data=SequenceData.from_seqs(encoder_prompt_token_ids),
+            multi_modal_data=None,
+            multi_modal_placeholders=None,
+        )
+
+    def get_decoder_dummy_data(
+        self,
+        seq_len: int,
+    ) -> DummyData:
+        # Avoid circular import
+        from vllm.sequence import SequenceData
+
+        (mm_inputs, total_placeholders_by_modality
+         ) = self.get_and_validate_mm_inputs(seq_len)
+
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
         total_len = len(prompt_token_ids)
 
         # V0 does not support chunked prefill.
-        if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
-            if total_len > seq_len and not is_encoder_data:
-                logger.warning(
-                    "The context length (%d) of the model is too short "
-                    "to hold the multi-modal embeddings in the worst case "
-                    "(%d tokens in total, out of which %s are reserved for "
-                    "multi-modal embeddings). This may cause certain "
-                    "multi-modal inputs to fail during inference, even when "
-                    "the input text is short. To avoid this, you should "
-                    "increase `max_model_len`, reduce `max_num_seqs`, "
-                    "and/or reduce `mm_counts`.", seq_len, total_len,
-                    total_placeholders_by_modality)
-
-            num_tokens_to_pad = max(total_len, seq_len) - total_len
-            prompt_token_ids.extend([0] * num_tokens_to_pad)
+        if total_len > seq_len and not envs.VLLM_USE_V1:
+            logger.warning(
+                "The context length (%d) of the model is too short "
+                "to hold the multi-modal embeddings in the worst case "
+                "(%d tokens in total, out of which %s are reserved for "
+                "multi-modal embeddings). This may cause certain "
+                "multi-modal inputs to fail during inference, even when "
+                "the input text is short. To avoid this, you should "
+                "increase `max_model_len`, reduce `max_num_seqs`, "
+                "and/or reduce `mm_counts`.", seq_len, total_len,
+                total_placeholders_by_modality)
 
             return DummyData(
-                seq_data=SequenceData.from_seqs(prompt_token_ids),
+                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
                 multi_modal_data=None,
                 multi_modal_placeholders=None,
             )
@@ -218,5 +242,5 @@ def get_dummy_data(
         return DummyData(
             seq_data=SequenceData.from_seqs(prompt_token_ids),
             multi_modal_data=mm_inputs["mm_kwargs"],
-            multi_modal_placeholders=placeholders_by_modality,
+            multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
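
Note (not part of the diff): both new methods pad with num_tokens_to_pad = max(total_len, seq_len) - total_len, which pads a short dummy prompt up to seq_len with zeros and leaves an already-long prompt untouched rather than truncating it. A small standalone illustration of that rule; pad_for_profiling is a hypothetical helper, not a vLLM function.

def pad_for_profiling(prompt_token_ids: list[int], seq_len: int) -> list[int]:
    total_len = len(prompt_token_ids)
    # Pad up to seq_len; never truncate a prompt that is already longer.
    num_tokens_to_pad = max(total_len, seq_len) - total_len
    return prompt_token_ids + [0] * num_tokens_to_pad


assert pad_for_profiling([101, 102, 103], seq_len=5) == [101, 102, 103, 0, 0]
assert pad_for_profiling([101, 102, 103, 104, 105, 106], seq_len=5) == [
    101, 102, 103, 104, 105, 106
]

In the decoder path, the V0 too-short-context case instead returns SequenceData.from_prompt_token_counts((0, seq_len)) with multi_modal_data=None, i.e. a zero-filled dummy sequence of exactly seq_len tokens rather than the oversized multimodal prompt.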