
[Model] support input embeddings for qwen2vl #8856

Merged
refactor _expand_pad_tokens function
whyiug committed Sep 27, 2024
commit d906fe7bdca492c4214cc9aeb1ffe4824b3bcc62
vllm/model_executor/models/qwen2_vl.py (127 changes: 69 additions & 58 deletions)
@@ -24,7 +24,7 @@
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from functools import lru_cache, partial
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, Type,
TypedDict, Union)
TypedDict, Union, Callable)

import torch
import torch.nn as nn
@@ -753,6 +753,58 @@ def _get_llm_num_vision_tokens(
     return llm_num_vision_tokens


+def _expand_pad_tokens(
+    inputs: list,
+    token_id: int,
+    make_batched_fn: Callable,
+    data_type_key: str,
+    image_processor: Any,
+    prompt_token_ids: List[int],
+) -> List[int]:
+    """
+    Expand pad tokens for multi-modal inputs (e.g., images or videos).
+
+    Args:
+        inputs:
+            The multi-modal inputs (e.g., image_inputs or video_inputs).
+        token_id:
+            The pad token ID to expand (image_token_id or video_token_id).
+        make_batched_fn:
+            The batching function for this modality (make_batched_images or
+            make_batched_videos).
+        data_type_key:
+            The type of the multi-modal input ("image" or "video").
+        image_processor:
+            The HuggingFace image processor used to compute how many vision
+            tokens each input occupies.
+        prompt_token_ids:
+            The prompt token IDs, containing one pad token per multi-modal
+            input.
+
+    Returns:
+        List[int]: The prompt token IDs with each pad token repeated once per
+        vision token of the corresponding input.
+    """
+    indices = [
+        idx for idx, token in enumerate(prompt_token_ids) if token == token_id
+    ]
+    inputs = make_batched_fn(inputs)
+    assert len(indices) == len(inputs)
+
+    prompt_token_ids_with_data = []
+    for cnt, data in enumerate(inputs):
+        num_tokens = _get_llm_num_vision_tokens(
+            [data] if data_type_key == "image" else data,
+            data_type_key=data_type_key,
+            image_processor=image_processor,
+        )
+        if cnt == 0:
+            end_idx = indices[cnt]
+            non_data_tokens = prompt_token_ids[:end_idx]
+        else:
+            non_data_tokens = prompt_token_ids[indices[cnt - 1] + 1:
+                                               indices[cnt]]
+        prompt_token_ids_with_data.extend(non_data_tokens)
+        prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens))
+    prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:])
+    return prompt_token_ids_with_data


 def input_processor_for_qwen2_vl(ctx: InputContext,
                                  llm_inputs: LLMInputs) -> LLMInputs:
     multi_modal_data = llm_inputs.get("multi_modal_data", None)
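The helper walks the prompt left to right: for each pad token it copies the tokens since the previous pad, then repeats the pad token once per vision token that the corresponding image or video occupies. Below is a minimal standalone sketch of that behavior (not part of the diff: the per-input token counts are supplied directly instead of calling _get_llm_num_vision_tokens, and the token IDs are made up).

from typing import List

IMG = 9  # stand-in for hf_config.image_token_id


def expand_pad_tokens_demo(prompt_token_ids: List[int], token_id: int,
                           num_tokens_per_input: List[int]) -> List[int]:
    # Positions of the pad tokens, one per multi-modal input.
    indices = [i for i, t in enumerate(prompt_token_ids) if t == token_id]
    assert len(indices) == len(num_tokens_per_input)
    out: List[int] = []
    prev = 0
    for idx, n in zip(indices, num_tokens_per_input):
        out.extend(prompt_token_ids[prev:idx])  # tokens before this pad
        out.extend([token_id] * n)              # expanded pad block
        prev = idx + 1
    out.extend(prompt_token_ids[indices[-1] + 1:])  # trailing tokens
    return out


# Two images: the first occupies 4 vision tokens, the second 2.
assert expand_pad_tokens_demo([1, IMG, 2, 3, IMG, 4], IMG, [4, 2]) == \
    [1, IMG, IMG, IMG, IMG, 2, 3, IMG, IMG, 4]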
@@ -790,12 +842,12 @@ def input_processor_for_qwen2_vl(ctx: InputContext,

     # Expand image pad tokens.
     if image_inputs is not None:
-        image_indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == hf_config.image_token_id
-        ]
         if isinstance(image_inputs, dict):
             prompt_token_ids_with_image = []
+            image_indices = [
+                idx for idx, token in enumerate(prompt_token_ids)
+                if token == hf_config.image_token_id
+            ]
             image_cnt = len(image_indices)
             embed_dim = image_inputs.get('image_embeds').size(0)
             assert embed_dim % image_cnt == 0
@@ -805,63 +857,22 @@
                     prompt_token_ids_with_image.extend([token] * num_pad_tokens)
                 else:
                     prompt_token_ids_with_image.append(token)
 
             prompt_token_ids = prompt_token_ids_with_image
         else:
-            image_inputs = make_batched_images(image_inputs)
-            assert len(image_indices) == len(image_inputs)
-
-            prompt_token_ids_with_image = []
-            for image_cnt, image in enumerate(image_inputs):
-                num_image_tokens = _get_llm_num_vision_tokens(
-                    [image],
-                    data_type_key="image",
-                    image_processor=image_processor,
-                )
-                if image_cnt == 0:
-                    end_idx = image_indices[image_cnt]
-                    non_image_tokens = prompt_token_ids[:end_idx]
-                else:
-                    non_image_tokens = prompt_token_ids[
-                        image_indices[image_cnt - 1] + 1:image_indices[image_cnt]
-                    ]
-                prompt_token_ids_with_image.extend(non_image_tokens)
-                prompt_token_ids_with_image.extend(
-                    hf_config.image_token_id for _ in range(num_image_tokens))
-            prompt_token_ids_with_image.extend(
-                prompt_token_ids[image_indices[-1] + 1:])
-            prompt_token_ids = prompt_token_ids_with_image
+            prompt_token_ids = _expand_pad_tokens(
+                image_inputs,
+                hf_config.image_token_id,
+                make_batched_images,
+                "image",
+                image_processor,
+                prompt_token_ids,
+            )
 
     # Expand video pad tokens.
     if video_inputs is not None:
-        video_indices = [
-            idx for idx, token in enumerate(prompt_token_ids)
-            if token == hf_config.video_token_id
-        ]
-        video_inputs = make_batched_videos(video_inputs)
-        assert len(video_indices) == len(video_inputs)
-
-        prompt_token_ids_with_video = []
-        for video_cnt, video in enumerate(video_inputs):
-            num_video_tokens = _get_llm_num_vision_tokens(
-                video,
-                data_type_key="video",
-                image_processor=image_processor,
-            )
-            if video_cnt == 0:
-                end_idx = video_indices[video_cnt]
-                non_video_tokens = prompt_token_ids[:end_idx]
-            else:
-                non_video_tokens = prompt_token_ids[
-                    video_indices[video_cnt - 1] + 1:video_indices[video_cnt]
-                ]
-            prompt_token_ids_with_video.extend(non_video_tokens)
-            prompt_token_ids_with_video.extend(
-                hf_config.video_token_id for _ in range(num_video_tokens))
-        prompt_token_ids_with_video.extend(
-            prompt_token_ids[video_indices[-1] + 1:])
-        prompt_token_ids = prompt_token_ids_with_video
+        prompt_token_ids = _expand_pad_tokens(
+            video_inputs,
+            hf_config.video_token_id,
+            make_batched_videos,
+            "video",
+            image_processor,
+            prompt_token_ids,
+        )
 
     return LLMInputs(
         prompt_token_ids=prompt_token_ids,
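For the isinstance(image_inputs, dict) branch above (caller-supplied image_embeds), the pad count is pure arithmetic: the stacked embedding rows are split evenly across the image pads in the prompt. A standalone sketch with toy values follows (the token ID and hidden size are illustrative, not taken from the model config; note that embed_dim, mirroring the source, holds the row count of the tensor, i.e. the total number of vision tokens).

import torch

IMG = 9  # stand-in for hf_config.image_token_id
prompt_token_ids = [1, IMG, 2, IMG, 3]
image_embeds = torch.zeros(6, 3584)  # (total vision tokens, hidden size)

image_indices = [i for i, t in enumerate(prompt_token_ids) if t == IMG]
image_cnt = len(image_indices)
embed_dim = image_embeds.size(0)
assert embed_dim % image_cnt == 0  # rows must divide evenly across images
num_pad_tokens = embed_dim // image_cnt  # 6 // 2 = 3 pads per image

expanded = []
for idx, token in enumerate(prompt_token_ids):
    if idx in image_indices:
        expanded.extend([token] * num_pad_tokens)
    else:
        expanded.append(token)

assert expanded == [1, IMG, IMG, IMG, 2, IMG, IMG, IMG, 3]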