[Model] support input embeddings for qwen2vl #8856

Merged
Refactor _expand_pad_tokens function for Qwen2-VL model
whyiug committed Sep 29, 2024
commit 7cf86b60158c07141dc79aa65aabdfa106ece239
86 changes: 53 additions & 33 deletions vllm/model_executor/models/qwen2_vl.py
@@ -24,7 +24,7 @@
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import lru_cache, partial
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, Type,
-                    TypedDict, Union)
+                    TypedDict, Union, Callable, Any)

import torch
import torch.nn as nn
@@ -756,6 +756,48 @@ def _get_llm_num_vision_tokens(
return llm_num_vision_tokens


+def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
+                       data_type_key: str, image_processor: Any,
+                       prompt_token_ids: List[int]) -> List[int]:
+    """
+    Expand pad tokens for multi-modal inputs (e.g., images or videos).
+
+    Args:
+        inputs (list): The multi-modal inputs (e.g., images or videos).
+        token_id (int): The token ID used to represent the multi-modal input.
+        make_batched_fn (Callable): A function to batch the inputs.
+        data_type_key (str): The type of the multi-modal input ("image" or "video").
+        image_processor (Any): The image processor used to process the inputs.
+        prompt_token_ids (List[int]): The list of token IDs in the prompt.
+
+    Returns:
+        List[int]: The list of token IDs with expanded pad tokens for the multi-modal inputs.
+    """
+    indices = [
+        idx for idx, token in enumerate(prompt_token_ids) if token == token_id
+    ]
+    inputs = make_batched_fn(inputs)
+    assert len(indices) == len(inputs)
+
+    prompt_token_ids_with_data = []
+    for cnt, data in enumerate(inputs):
+        num_tokens = _get_llm_num_vision_tokens(
+            [data] if data_type_key == "image" else data,
+            data_type_key=data_type_key,
+            image_processor=image_processor,
+        )
+        if cnt == 0:
+            end_idx = indices[cnt]
+            non_data_tokens = prompt_token_ids[:end_idx]
+        else:
+            non_data_tokens = prompt_token_ids[indices[cnt - 1] +
+                                               1:indices[cnt]]
+        prompt_token_ids_with_data.extend(non_data_tokens)
+        prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
+    prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
+    return prompt_token_ids_with_data
+
+
def input_processor_for_qwen2_vl(ctx: InputContext,
llm_inputs: LLMInputs) -> LLMInputs:
multi_modal_data = llm_inputs.get("multi_modal_data", None)
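
For illustration only (not part of the diff): a minimal standalone sketch of the expansion logic above. Here _get_llm_num_vision_tokens is replaced by a stub that charges a fixed 3 tokens per input, the batching step and image processor are dropped, and the placeholder id IMG is illustrative rather than Qwen2-VL's actual image_token_id.

from typing import List

IMG = 151655  # illustrative placeholder id, stand-in for hf_config.image_token_id

def fake_num_tokens(data) -> int:
    # Stub for _get_llm_num_vision_tokens: pretend every input costs 3 tokens.
    return 3

def expand(inputs: list, token_id: int, prompt_token_ids: List[int]) -> List[int]:
    # Same control flow as _expand_pad_tokens, minus batching and the processor.
    indices = [i for i, t in enumerate(prompt_token_ids) if t == token_id]
    assert len(indices) == len(inputs)
    out: List[int] = []
    for cnt, data in enumerate(inputs):
        num_tokens = fake_num_tokens(data)
        start = 0 if cnt == 0 else indices[cnt - 1] + 1
        out.extend(prompt_token_ids[start:indices[cnt]])  # tokens before this placeholder
        out.extend(token_id for _ in range(num_tokens))   # expanded placeholder run
    out.extend(prompt_token_ids[indices[-1] + 1:])        # tail after the last placeholder
    return out

# Two images, each expanded to 3 pad tokens:
print(expand(["img_a", "img_b"], IMG, [1, IMG, 2, IMG, 3]))
# [1, IMG, IMG, IMG, 2, IMG, IMG, IMG, 3]

Each single placeholder token in the prompt is replaced by as many copies of itself as the vision encoder will emit embeddings for that input; all other tokens are passed through unchanged.
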
@@ -792,32 +834,6 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
)["input_ids"]

# Expand image pad tokens.
-    def expand_pad_tokens(inputs, token_id, make_batched_fn, data_type_key):
-        indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == token_id
-        ]
-        inputs = make_batched_fn(inputs)
-        assert len(indices) == len(inputs)
-
-        prompt_token_ids_with_data = []
-        for cnt, data in enumerate(inputs):
-            num_tokens = _get_llm_num_vision_tokens(
-                [data] if data_type_key == "image" else data,
-                data_type_key=data_type_key,
-                image_processor=image_processor,
-            )
-            if cnt == 0:
-                end_idx = indices[cnt]
-                non_data_tokens = prompt_token_ids[:end_idx]
-            else:
-                non_data_tokens = prompt_token_ids[indices[cnt - 1] +
-                                                   1:indices[cnt]]
-            prompt_token_ids_with_data.extend(non_data_tokens)
-            prompt_token_ids_with_data.extend(token_id
-                                              for _ in range(num_tokens))
-        prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
-        return prompt_token_ids_with_data

if image_inputs is not None:
if isinstance(image_inputs, dict):
@@ -838,14 +854,18 @@ def expand_pad_tokens(inputs, token_id, make_batched_fn, data_type_key):
prompt_token_ids_with_image.append(token)
prompt_token_ids = prompt_token_ids_with_image
else:
-            prompt_token_ids = expand_pad_tokens(image_inputs,
-                                                 hf_config.image_token_id,
-                                                 make_batched_images, "image")
+            prompt_token_ids = _expand_pad_tokens(image_inputs,
+                                                  hf_config.image_token_id,
+                                                  make_batched_images, "image",
+                                                  image_processor,
+                                                  prompt_token_ids)

if video_inputs is not None:
-        prompt_token_ids = expand_pad_tokens(video_inputs,
-                                             hf_config.video_token_id,
-                                             make_batched_videos, "video")
+        prompt_token_ids = _expand_pad_tokens(video_inputs,
+                                              hf_config.video_token_id,
+                                              make_batched_videos, "video",
+                                              image_processor,
+                                              prompt_token_ids)

return LLMInputs(
prompt_token_ids=prompt_token_ids,
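
A note on the design choice: the removed nested expand_pad_tokens closed over image_processor and prompt_token_ids from the enclosing scope, so the handoff from the image pass to the video pass was implicit in a rebinding. The module-level _expand_pad_tokens takes both explicitly, which makes the chaining visible at the call sites. Roughly, with ids, images, and videos as hypothetical stand-ins for the real locals:

ids = _expand_pad_tokens(images, hf_config.image_token_id,
                         make_batched_images, "image", image_processor, ids)
ids = _expand_pad_tokens(videos, hf_config.video_token_id,
                         make_batched_videos, "video", image_processor, ids)

The second call operates on the image-expanded token list, exactly as the old closure-based version did, but the dataflow is now explicit in the signature.
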