[Model] support input embeddings for qwen2vl #8856

Merged
support input embeddings for qwen2vl
whyiug committed Sep 27, 2024
commit 797fee9fc51aee35837f54bd49f66411a6008e0d
79 changes: 41 additions & 38 deletions vllm/model_executor/models/qwen2_vl.py
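Before the diff, a minimal usage sketch of what this commit enables: handing vLLM precomputed Qwen2-VL vision features instead of raw images. This is an illustration only; the model name, prompt template, file path, and tensor shapes are assumptions, and only the dict keys (image_embeds, image_grid_thw) come from the code changed below.

import torch
from vllm import LLM, SamplingParams

# Precomputed vision features for one image, e.g. exported offline from the
# Qwen2-VL vision tower. Shapes are illustrative; the last dim must match the
# language backbone's hidden size.
image_embeds = torch.load("qwen2vl_image_embeds.pt")            # (num_features, hidden_size)
image_grid_thw = torch.tensor([[1, 48, 64]], dtype=torch.long)  # (num_images, 3)

llm = LLM(model="Qwen/Qwen2-VL-7B-Instruct")
prompt = ("<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
          "Describe the image.<|im_end|>\n<|im_start|>assistant\n")

outputs = llm.generate(
    {
        "prompt": prompt,
        # Dict form handled by the new branch in mm_input_mapper_for_qwen2_vl.
        "multi_modal_data": {
            "image": {
                "image_embeds": image_embeds,
                "image_grid_thw": image_grid_thw,
            }
        },
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)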
@@ -92,7 +92,6 @@ class Qwen2VLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
     data: torch.Tensor
     """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
-
     `hidden_size` must match the hidden size of language model backbone.
     """

@@ -577,13 +576,11 @@ def mm_input_mapper_for_qwen2_vl(
     data_type_key: str,
 ) -> MultiModalInputs:
     """Input mapper for Qwen2-VL."""
-    if isinstance(data, torch.Tensor):
-        pass
-        # return MultiModalInputs({
-        #     "image_embeds": data,
-        #     # "image_grid_thw": torch.tensor([[1, 24, 82]], dtype=torch.int32),
-        # })
-
+    if data_type_key == "image" and isinstance(data, dict):
+        return MultiModalInputs({
+            "image_embeds": data.get("image_embeds"),
+            "image_grid_thw": data.get("image_grid_thw"),
+        })
     model_config = ctx.model_config
     image_processor = cached_get_image_processor(
         model_config.model, trust_remote_code=model_config.trust_remote_code)
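In other words, the mapper now forwards a caller-supplied dict untouched instead of running the HF image processor. A small sketch of the contract that branch appears to expect (key names from the diff; sizes and dtype are assumptions):

import torch

hidden_size = 3584    # assumption: must match the language model backbone
num_features = 1536   # rows of precomputed vision features for this image

# What a caller would place under multi_modal_data["image"] to bypass pixel
# preprocessing entirely:
image_embed_input = {
    "image_embeds": torch.randn(num_features, hidden_size, dtype=torch.float16),
    "image_grid_thw": torch.tensor([[1, 48, 64]], dtype=torch.long),
}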
@@ -765,13 +762,6 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
     image_inputs = multi_modal_data.get("image", None)
     video_inputs = multi_modal_data.get("video", None)
 
-    if isinstance(image_inputs, torch.Tensor):
-        pass
-        # return LLMInputs(
-        #     prompt_token_ids=prompt_token_ids,
-        #     # prompt=llm_inputs["prompt"],
-        #     multi_modal_data=multi_modal_data,
-        # )
     processor = cached_get_processor(ctx.model_config.model)
     image_processor = processor.image_processor
     hf_config = ctx.get_hf_config(Qwen2VLConfig)
@@ -804,27 +794,40 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
             idx for idx, token in enumerate(prompt_token_ids)
             if token == hf_config.image_token_id
         ]
-        image_inputs = make_batched_images(image_inputs)
-        assert len(image_indices) == len(image_inputs)
-
-        prompt_token_ids_with_image = []
-        for image_cnt, image in enumerate(image_inputs):
-            num_image_tokens = _get_llm_num_vision_tokens(
-                [image],
-                data_type_key="image",
-                image_processor=image_processor,
-            )
-            if image_cnt == 0:
-                non_image_tokens = prompt_token_ids[:image_indices[image_cnt]]
-            else:
-                non_image_tokens = prompt_token_ids[image_indices[image_cnt -
-                                                                  1] +
-                                                    1:image_indices[image_cnt]]
-            prompt_token_ids_with_image.extend(non_image_tokens)
-            prompt_token_ids_with_image.extend(
-                hf_config.image_token_id for _ in range(num_image_tokens))
-        prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] +
-                                                            1:])
+        if isinstance(image_inputs, dict):
+            prompt_token_ids_with_image = []
+            image_cnt = len(image_indices)
+            embed_dim = image_inputs.get('image_embeds').size(0)
+            assert embed_dim % image_cnt == 0
+            num_pad_tokens = embed_dim // image_cnt
+            for idx, token in enumerate(prompt_token_ids):
+                if idx in image_indices:
+                    prompt_token_ids_with_image.extend([token] * num_pad_tokens)
+                else:
+                    prompt_token_ids_with_image.append(token)
+
+        else:
+            image_inputs = make_batched_images(image_inputs)
+            assert len(image_indices) == len(image_inputs)
+
+            prompt_token_ids_with_image = []
+            for image_cnt, image in enumerate(image_inputs):
+                num_image_tokens = _get_llm_num_vision_tokens(
+                    [image],
+                    data_type_key="image",
+                    image_processor=image_processor,
+                )
+                if image_cnt == 0:
+                    non_image_tokens = prompt_token_ids[:image_indices[image_cnt]]
+                else:
+                    non_image_tokens = prompt_token_ids[image_indices[image_cnt -
+                                                                      1] +
+                                                        1:image_indices[image_cnt]]
+                prompt_token_ids_with_image.extend(non_image_tokens)
+                prompt_token_ids_with_image.extend(
+                    hf_config.image_token_id for _ in range(num_image_tokens))
+            prompt_token_ids_with_image.extend(prompt_token_ids[image_indices[-1] +
+                                                                1:])
         prompt_token_ids = prompt_token_ids_with_image
 
     # Expand video pad tokens.
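A standalone walkthrough of the placeholder expansion added above, with toy values (the token id and sizes are assumptions, not taken from the diff): each image pad position is repeated once per embedding row attributed to that image, so the language model sees one token slot per row of image_embeds.

import torch

IMAGE_PAD_TOKEN_ID = 151655  # assumption: Qwen2-VL image pad token id
prompt_token_ids = [101, 102, IMAGE_PAD_TOKEN_ID, 103, IMAGE_PAD_TOKEN_ID, 104]
image_indices = [i for i, t in enumerate(prompt_token_ids)
                 if t == IMAGE_PAD_TOKEN_ID]

# Two images, five embedding rows each -> ten rows total.
image_embeds = torch.randn(10, 3584)
image_cnt = len(image_indices)            # 2
embed_dim = image_embeds.size(0)          # 10
assert embed_dim % image_cnt == 0         # rows must split evenly across images
num_pad_tokens = embed_dim // image_cnt   # 5 repetitions per placeholder

expanded = []
for idx, token in enumerate(prompt_token_ids):
    if idx in image_indices:
        expanded.extend([token] * num_pad_tokens)
    else:
        expanded.append(token)

# 6 original tokens - 2 placeholders + 10 embedding rows = 14 positions.
assert len(expanded) == 14

Note that this scheme assumes every image contributes the same number of embedding rows; the assert in the diff enforces exactly that.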
@@ -986,9 +989,9 @@ def _parse_and_validate_video_input(
     def _process_image_input(self,
                              image_input: Qwen2VLImageInputs) -> torch.Tensor:
         if image_input["type"] == "image_embeds":
-            return image_input["data"]
+            return image_input["data"].type(self.visual.dtype)
 
-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+        pixel_values = image_input["data"].type(self.visual.dtype)
         image_embeds = self.visual(pixel_values,
                                    grid_thw=image_input["image_grid_thw"])
         return image_embeds
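One note on this last hunk: the added .type(self.visual.dtype) cast matters because user-provided embeddings typically arrive as float32, while the model weights are usually loaded in half precision. A tiny illustration of the effect (dtype and shapes assumed, not from the diff):

import torch

visual_dtype = torch.bfloat16               # assumption: model loaded in bf16
user_embeds = torch.randn(1536, 3584)       # torch.randn defaults to float32
converted = user_embeds.type(visual_dtype)  # what the added line does
assert user_embeds.dtype == torch.float32 and converted.dtype == torch.bfloat16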