Skip to content

Commit

Permalink
change logic handling of single prompt and multiple images
Browse files Browse the repository at this point in the history
  • Loading branch information
yonigozlan committed Oct 15, 2024
1 parent 7a6eb23 commit 66efd2a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 18 deletions.
7 changes: 0 additions & 7 deletions src/transformers/models/paligemma/processing_paligemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""

import logging
import warnings
from typing import List, Optional, Union

from ...feature_extraction_utils import BatchFeature
Expand Down Expand Up @@ -97,12 +96,6 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_i
image_token (`str`): The image token.
num_images (`int`): Number of images in the prompt.
"""
if image_token in prompt:
warnings.warn(
f"The image token {image_token} is already present in the prompt. No need to manually add {image_token} in the prompt for this model."
f" Removing all {image_token} and adding ({image_token}) * image_seq_len * num_images at the start of the prompt."
)
prompt = prompt.replace(image_token, "")
return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"


Expand Down
24 changes: 18 additions & 6 deletions src/transformers/pipelines/image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def retrieve_images_in_chat(chat: dict, images: Optional[Union[str, List[str], "
if isinstance(content, dict) and content.get("type") == "image":
if "image" in content:
retrieved_images.append(content["image"])
elif "url" in content:
retrieved_images.append(content["url"])
elif "path" in content:
retrieved_images.append(content["path"])
elif idx_images < len(images):
retrieved_images.append(images[idx_images])
idx_images += 1
Expand Down Expand Up @@ -128,7 +132,7 @@ class ImageTextToTextPipeline(Pipeline):
>>> "content": [
>>> {
>>> "type": "image",
>>> "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
>>> "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
>>> },
>>> {"type": "text", "text": "Describe this image."},
>>> ],
Expand All @@ -143,7 +147,7 @@ class ImageTextToTextPipeline(Pipeline):
>>> pipe(text=messages, max_new_tokens=20, return_full_text=False)
[{'input_text': [{'role': 'user',
'content': [{'type': 'image',
'image': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'},
'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'},
{'type': 'text', 'text': 'Describe this image.'}]},
{'role': 'assistant',
'content': [{'type': 'text', 'text': 'There is a dog and'}]}],
Expand Down Expand Up @@ -298,7 +302,7 @@ def __call__(
if not isinstance(images, (list, tuple)):
images = [images]
if isinstance(text, str):
text = [text] * len(images)
text = [text]
if not isinstance(text[0], str):
raise ValueError("The pipeline does not support nested lists of prompts.")

Expand Down Expand Up @@ -335,10 +339,18 @@ def __call__(
images_reorganized.append(images[:num_images])
images = images[num_images:]
images = images_reorganized
# After reorganizing, these should be the same
if len(images) != len(text):
raise ValueError("The number of images and text should be the same.")
elif len(text) == 1 and len(images) > 1:
logger.warning(
"The pipeline detected multiple images for one prompt, but no image tokens in the prompt. "
"The prompt will be repeated for each image."
)
text = [text[0]] * len(images)

# After reorganizing, these should be the same
if len(text) > 1 and len(images) != len(text):
raise ValueError(
"Undefined behavior, please check the number of images and prompts, and nest the images to match the prompts."
)
return super().__call__([ImageText(image, text_single) for image, text_single in zip(images, text)], **kwargs)

def preprocess(
Expand Down
10 changes: 5 additions & 5 deletions tests/pipelines/test_pipelines_image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_model_pt_chat_template_continue_final_message(self):
"content": [
{
"type": "image",
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
},
{"type": "text", "text": "Describe this image."},
],
Expand All @@ -195,7 +195,7 @@ def test_model_pt_chat_template_continue_final_message(self):
"content": [
{
"type": "image",
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
},
{"type": "text", "text": "Describe this image."},
],
Expand All @@ -208,7 +208,7 @@ def test_model_pt_chat_template_continue_final_message(self):
"content": [
{
"type": "image",
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
},
{"type": "text", "text": "Describe this image."},
],
Expand Down Expand Up @@ -237,7 +237,7 @@ def test_model_pt_chat_template_new_text(self):
"content": [
{
"type": "image",
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
},
{"type": "text", "text": "Describe this image."},
],
Expand All @@ -254,7 +254,7 @@ def test_model_pt_chat_template_new_text(self):
"content": [
{
"type": "image",
"image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
"url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
},
{"type": "text", "text": "Describe this image."},
],
Expand Down

0 comments on commit 66efd2a

Please sign in to comment.