From 66efd2a72620786c7e9fdaf1f9b4797b741cf655 Mon Sep 17 00:00:00 2001 From: yonigozlan Date: Tue, 15 Oct 2024 13:19:24 +0000 Subject: [PATCH] change logic handling of single prompt ans multiple images --- .../models/paligemma/processing_paligemma.py | 7 ------ .../pipelines/image_text_to_text.py | 24 ++++++++++++++----- .../test_pipelines_image_text_to_text.py | 10 ++++---- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/paligemma/processing_paligemma.py b/src/transformers/models/paligemma/processing_paligemma.py index fa6594398a2343..c4be519b9c8be2 100644 --- a/src/transformers/models/paligemma/processing_paligemma.py +++ b/src/transformers/models/paligemma/processing_paligemma.py @@ -17,7 +17,6 @@ """ import logging -import warnings from typing import List, Optional, Union from ...feature_extraction_utils import BatchFeature @@ -97,12 +96,6 @@ def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_i image_token (`str`): The image token. num_images (`int`): Number of images in the prompt. """ - if image_token in prompt: - warnings.warn( - f"The image token {image_token} is already present in the prompt. No need to manually add {image_token} in the prompt for this model." - f" Removing all {image_token} and adding ({image_token}) * image_seq_len * num_images at the start of the prompt." - ) - prompt = prompt.replace(image_token, "") return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n" diff --git a/src/transformers/pipelines/image_text_to_text.py b/src/transformers/pipelines/image_text_to_text.py index 4e6c82d723a560..0231b13732d2cb 100644 --- a/src/transformers/pipelines/image_text_to_text.py +++ b/src/transformers/pipelines/image_text_to_text.py @@ -85,6 +85,10 @@ def retrieve_images_in_chat(chat: dict, images: Optional[Union[str, List[str], " if isinstance(content, dict) and content.get("type") == "image": if "image" in content: retrieved_images.append(content["image"]) + elif "url" in content: + retrieved_images.append(content["url"]) + elif "path" in content: + retrieved_images.append(content["path"]) elif idx_images < len(images): retrieved_images.append(images[idx_images]) idx_images += 1 @@ -128,7 +132,7 @@ class ImageTextToTextPipeline(Pipeline): >>> "content": [ >>> { >>> "type": "image", - >>> "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + >>> "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", >>> }, >>> {"type": "text", "text": "Describe this image."}, >>> ], @@ -143,7 +147,7 @@ class ImageTextToTextPipeline(Pipeline): >>> pipe(text=messages, max_new_tokens=20, return_full_text=False) [{'input_text': [{'role': 'user', 'content': [{'type': 'image', - 'image': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'}, + 'url': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'}, {'type': 'text', 'text': 'Describe this image.'}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': 'There is a dog and'}]}], @@ -298,7 +302,7 @@ def __call__( if not isinstance(images, (list, tuple)): images = [images] if isinstance(text, str): - text = [text] * len(images) + text = [text] if not isinstance(text[0], str): raise ValueError("The pipeline does not support nested lists of prompts.") @@ -335,10 +339,18 @@ def __call__( images_reorganized.append(images[:num_images]) images = images[num_images:] images = images_reorganized - # After reorganizing, these should be the same - if len(images) != len(text): - raise ValueError("The number of images and text should be the same.") + elif len(text) == 1 and len(images) > 1: + logger.warning( + "The pipeline detected multiple images for one prompt, but no image tokens in the prompt. " + "The prompt will be repeated for each image." + ) + text = [text[0]] * len(images) + # After reorganizing, these should be the same + if len(text) > 1 and len(images) != len(text): + raise ValueError( + "Undefined behavior, please check the number of images and prompts, and nest the images to match the prompts." + ) return super().__call__([ImageText(image, text_single) for image, text_single in zip(images, text)], **kwargs) def preprocess( diff --git a/tests/pipelines/test_pipelines_image_text_to_text.py b/tests/pipelines/test_pipelines_image_text_to_text.py index 1e33436e9fda2a..c13514c2379183 100644 --- a/tests/pipelines/test_pipelines_image_text_to_text.py +++ b/tests/pipelines/test_pipelines_image_text_to_text.py @@ -172,7 +172,7 @@ def test_model_pt_chat_template_continue_final_message(self): "content": [ { "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", }, {"type": "text", "text": "Describe this image."}, ], @@ -195,7 +195,7 @@ def test_model_pt_chat_template_continue_final_message(self): "content": [ { "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", }, {"type": "text", "text": "Describe this image."}, ], @@ -208,7 +208,7 @@ def test_model_pt_chat_template_continue_final_message(self): "content": [ { "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", }, {"type": "text", "text": "Describe this image."}, ], @@ -237,7 +237,7 @@ def test_model_pt_chat_template_new_text(self): "content": [ { "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", }, {"type": "text", "text": "Describe this image."}, ], @@ -254,7 +254,7 @@ def test_model_pt_chat_template_new_text(self): "content": [ { "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", + "url": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", }, {"type": "text", "text": "Describe this image."}, ],