Skip to content

Commit

Permalink
[Bugfix][VLM] Fix mixed-modality inference backward compatibility for…
Browse files Browse the repository at this point in the history
… V0 (vllm-project#12313)

Signed-off-by: Roger Wang <ywang@roblox.com>
  • Loading branch information
ywang96 authored and vllmellm committed Jan 27, 2025
1 parent 42951eb commit 2f40e8c
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 28 deletions.
53 changes: 44 additions & 9 deletions vllm/model_executor/models/llava_onevision.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,7 @@ def apply_pooling(self, image_features, stride=2):
return image_feature

def get_multimodal_embeddings(
self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return None
Expand All @@ -842,8 +842,7 @@ def get_multimodal_embeddings(
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[List[Tuple[NestedTensors,
str]]] = None,
multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
Expand All @@ -852,6 +851,34 @@ def get_input_embeddings(
[self.config.image_token_index, self.config.video_token_index])
return inputs_embeds

def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
image_input: Optional[NestedTensors] = None,
video_input: Optional[NestedTensors] = None,
) -> torch.Tensor:

inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_index,
)

if video_input is not None:
video_embeds = self._process_video_pixels(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
video_embeds,
placeholder_token_id=self.config.video_token_index,
)

return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
Expand All @@ -871,13 +898,21 @@ def forward(
if intermediate_tensors is not None:
inputs_embeds = None

# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
# NOTE: In v1, inputs_embeds is always generated at model runner from
# `get_multimodal_embeddings` and `get_input_embeddings`, this
# condition is only for v0 compatibility.
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
input_ids = None
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)

if image_input is None and video_input is None:
inputs_embeds = None
else:
inputs_embeds = self.get_input_embeddings_v0(
input_ids,
image_input=image_input,
video_input=video_input)
input_ids = None

hidden_states = self.language_model.model(input_ids,
positions,
Expand Down
67 changes: 48 additions & 19 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalFieldConfig, MultiModalKwargs,
NestedTensors, VideoItem)
VideoItem)
from vllm.multimodal.parse import (ImageSize, ModalityDataItems,
MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
Expand Down Expand Up @@ -1233,7 +1233,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
return modalities

def get_multimodal_embeddings(
self, **kwargs) -> Optional[List[Tuple[NestedTensors, str]]]:
self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]:

modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
Expand All @@ -1260,8 +1260,7 @@ def get_multimodal_embeddings(
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[List[Tuple[NestedTensors,
str]]] = None,
multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
Expand All @@ -1270,6 +1269,33 @@ def get_input_embeddings(
[self.config.image_token_id, self.config.video_token_id])
return inputs_embeds

def get_input_embeddings_v0(
self,
input_ids: torch.Tensor,
image_input: Optional[tuple[torch.Tensor, ...]] = None,
video_input: Optional[tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:

inputs_embeds = self.get_input_embeddings(input_ids)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_id,
)

if video_input is not None:
video_embeds = self._process_video_input(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
video_embeds,
placeholder_token_id=self.config.video_token_id,
)
return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
Expand Down Expand Up @@ -1303,22 +1329,25 @@ def forward(
if intermediate_tensors is not None:
inputs_embeds = None

# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
# NOTE: In v1, inputs_embeds is always generated at model runner from
# `get_multimodal_embeddings` and `get_input_embeddings`, this
# condition is only for v0 compatibility.
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)

# We need to check for usage of mrope here in case there is
# multimodal data.
# TODO (ywang96): move this to model runner in V1.
if multimodal_embeddings is not None and uses_mrope(self.config):
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}")

inputs_embeds = self.get_input_embeddings(input_ids,
multimodal_embeddings)
input_ids = None
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)

if image_input is None and video_input is None:
inputs_embeds = None
else:
if uses_mrope(self.config):
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}")
inputs_embeds = self.get_input_embeddings_v0(
input_ids,
image_input=image_input,
video_input=video_input)
input_ids = None

hidden_states = self.language_model.model(
input_ids=input_ids,
Expand Down

0 comments on commit 2f40e8c

Please sign in to comment.