
[VLM] Keep track of whether prompt replacements have been applied #13215

Merged · 12 commits · Feb 14, 2025
Clean up DictEmbeddingItems
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
DarkLight1337 committed Feb 14, 2025
commit 98a4d473dda01bf5f4530fbda23a4230442c104a
90 changes: 51 additions & 39 deletions vllm/model_executor/models/minicpmo.py
@@ -27,8 +27,8 @@
                    Tuple, TypedDict, Union)

import torch
-import torch.types
from torch import nn
+from transformers import BatchFeature
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.whisper.modeling_whisper import (
    ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder)
@@ -37,23 +37,21 @@
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig
-from vllm.multimodal.parse import (ModalityData, ModalityDataItems,
-                                   MultiModalDataItems, MultiModalDataParser,
-                                   VideoItem)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        PromptReplacement)
+from vllm.multimodal.parse import (AudioItem, DictEmbeddingItems, ModalityData,
+                                   ModalityDataItems, MultiModalDataItems,
+                                   MultiModalDataParser)
+from vllm.multimodal.processing import PromptReplacement
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors

from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
-                       MiniCPMVEmbeddingItems, MiniCPMVMultiModalDataParser,
-                       MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo)
+                       MiniCPMVMultiModalDataParser,
+                       MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
+                       _minicpmv_field_config)
from .utils import AutoWeightsLoader, maybe_prefix

CPU_DEVICE = torch.device("cpu")

-MiniCPMOEmbeddingItems = MiniCPMVEmbeddingItems


class MiniCPMOAudioFeatureInputs(TypedDict):
    type: Literal["audio_features"]
@@ -103,28 +101,49 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict):
MiniCPMOAudioEmbeddingInputs]


-class MiniCPMOAudioEmbeddingItems(MiniCPMOEmbeddingItems):
+def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+    audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
+
+    return dict(
+        **_minicpmv_field_config(hf_inputs),
+        audio_features=MultiModalFieldConfig.flat_from_sizes(
+            "audio", audio_num_slices),
+        audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
+            "audio", audio_num_slices),
+        audio_num_slices=MultiModalFieldConfig.batched("audio"),
+        audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
+        audio_embeds=MultiModalFieldConfig.flat_from_sizes(
+            "audio", audio_num_slices),
+    )
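A note on the field types used above: `flat_from_sizes`-style fields are stacked flat across all slices of all audio items and are split back per item using `audio_num_slices`, while `batched` fields already carry one entry per item. A minimal plain-PyTorch sketch of that partitioning (shapes and counts are made up for illustration; this is not vLLM's implementation):

```python
import torch

# Hypothetical batch: 2 audios producing 2 and 3 feature slices respectively.
audio_num_slices = torch.tensor([2, 3])
audio_features = torch.randn(5, 80)  # 5 = 2 + 3 slices stacked flat

# flat_from_sizes-style split: cumulative sums give the slice boundaries.
ends = torch.cumsum(audio_num_slices, dim=0)
starts = torch.cat([torch.zeros(1, dtype=torch.long), ends[:-1]])
per_item = [
    audio_features[s:e] for s, e in zip(starts.tolist(), ends.tolist())
]

assert [t.shape[0] for t in per_item] == [2, 3]  # one chunk per audio item
```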


-    def __init__(self, data: Dict) -> None:
-        super().__init__(data, "audio")
-        audio_embeds = self.data.get("audio_embeds", None)
-        if audio_embeds is None:
-            raise ValueError("Incorrect type of video_embeds",
-                             "Got type: None")
-        self.data["audio_embeds"] = audio_embeds
-
-    def get(self, index: int) -> object:
-        return self.data["audio_embeds"][index]
+class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
+
+    def __init__(
+        self,
+        data: Mapping[str, torch.Tensor],
+        fields_config: Mapping[str, MultiModalFieldConfig],
+    ) -> None:
+        super().__init__(
+            data,
+            modality="image",
+            fields_config=fields_config,
+            required_fields={"audio_embeds"},
+        )


class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):

    def _parse_audio_data(
        self,
-        data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
+        data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
    ) -> ModalityDataItems[Any, Any]:
        if isinstance(data, dict):
-            return MiniCPMOAudioEmbeddingItems(data)
+            return MiniCPMOAudioEmbeddingItems(
+                data,
+                fields_config=_minicpmo_field_config(data),
+            )

        return super()._parse_audio_data(data)
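For orientation, the dict branch above replaces the old ad-hoc `None` check with the up-front `required_fields` validation that `DictEmbeddingItems` provides, while non-dict payloads still fall through to the base parser. A rough sketch of the combined dispatch-and-validate behavior, with illustrative stand-in names rather than vLLM's API:

```python
from collections.abc import Mapping
from typing import Any, Union

REQUIRED_FIELDS = {"audio_embeds"}  # mirrors required_fields above

def parse_audio(data: Union[Mapping[str, Any], list]) -> str:
    if isinstance(data, Mapping):
        missing = REQUIRED_FIELDS - data.keys()
        if missing:  # fail fast on malformed embedding dicts
            raise ValueError(f"Missing required fields: {missing}")
        return "embedding-items path"
    return "raw-audio path"  # waveforms fall through to the base parser

assert parse_audio({"audio_embeds": [[0.0, 0.1]]}) == "embedding-items path"
assert parse_audio([0.0, 0.1]) == "raw-audio path"
```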


@@ -167,6 +186,10 @@ def get_max_audio_tokens_per_chunk(self) -> int:
    def get_max_audio_chunks_with_most_features(self) -> int:
        return 30

+    def get_max_audio_tokens(self) -> int:
+        return self.get_max_audio_tokens_per_chunk(
+        ) * self.get_max_audio_chunks_with_most_features()
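The new `get_max_audio_tokens` is just the product of the per-chunk token budget and the maximum chunk count (30, per the method above). A quick sanity-check sketch; the per-chunk value below is a placeholder, since `get_max_audio_tokens_per_chunk` is defined outside this hunk:

```python
max_chunks = 30          # from get_max_audio_chunks_with_most_features()
tokens_per_chunk = 100   # placeholder, not MiniCPM-O's real per-chunk budget

# Upper bound on audio tokens a single request can contribute.
max_audio_tokens = tokens_per_chunk * max_chunks
assert max_audio_tokens == 3000
```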

    def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
        sampling_rate = self.get_default_audio_sampling_rate()
        # exclude <audio> </audio>
@@ -194,7 +217,8 @@ def get_num_frames_with_most_features(self, seq_len: int) -> int:
        return num_frames


-class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder):
+class MiniCPMODummyInputsBuilder(
+        MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):

    def get_dummy_processor_inputs(
            self, seq_len: int, mm_counts: Mapping[str,
@@ -222,8 +246,7 @@ def get_dummy_processor_inputs(


class MiniCPMOMultiModalProcessor(
-        MiniCPMVMultiModalProcessor,
-        BaseMultiModalProcessor[MiniCPMOProcessingInfo]):
+        MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):

    def _get_data_parser(self) -> MultiModalDataParser:
        return MiniCPMOMultiModalDataParser(
@@ -369,21 +392,10 @@ def get_replacement_minicpmv(item_idx: int, modality: str):

    def _get_mm_fields_config(
        self,
-        hf_inputs,
+        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
-        audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
-
-        return dict(
-            **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
-            audio_features=MultiModalFieldConfig.flat_from_sizes(
-                "audio", audio_num_slices),
-            audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
-                "audio", audio_num_slices),
-            audio_num_slices=MultiModalFieldConfig.batched("audio"),
-            audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
-            audio_embeds=MultiModalFieldConfig.flat_from_sizes(
-                "audio", audio_num_slices))
+        return _minicpmo_field_config(hf_inputs)
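After this change, `_get_mm_fields_config` delegates to the module-level `_minicpmo_field_config`, so the processor and the data parser share one definition of the field layout, and the MiniCPM-O config splats the MiniCPM-V config before adding its audio entries. A small sketch of that dict-composition pattern, with hypothetical function names:

```python
from collections.abc import Mapping

def parent_field_config(hf_inputs: Mapping) -> dict:
    # stand-in for _minicpmv_field_config: image/video fields only
    return {"image_embeds": "image-field", "video_embeds": "video-field"}

def child_field_config(hf_inputs: Mapping) -> dict:
    # stand-in for _minicpmo_field_config: inherit parent fields, add audio
    return {
        **parent_field_config(hf_inputs),
        "audio_embeds": "audio-field",
    }

assert set(child_field_config({})) == {
    "image_embeds", "video_embeds", "audio_embeds"
}
```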


class MultiModalProjector(nn.Module):
@@ -406,7 +418,7 @@ def forward(self, audio_features: torch.Tensor) -> torch.Tensor:

class MiniCPMWhisperEncoderLayer(nn.Module):

-    def __init__(self, config: WhisperConfig, layer_idx: int = None):
+    def __init__(self, config: WhisperConfig, layer_idx: int):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = WHISPER_ATTENTION_CLASSES[