From 0b69d43c8f351e52e3be1ea484449aa190d9bf57 Mon Sep 17 00:00:00 2001
From: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Date: Mon, 22 May 2023 21:45:50 +0200
Subject: [PATCH] [image-to-text pipeline] Add conditional text support + GIT (#23362)

* First draft

* Remove print statements

* Add conditional generation

* Add more tests

* Remove scripts

* Remove BLIP specific links

* Add support for pix2struct

* Add fast test

* Address comment

* Fix style
---
 src/transformers/models/auto/modeling_auto.py |  2 +
 src/transformers/pipelines/image_to_text.py   | 49 +++++++++++-
 .../pipelines/test_pipelines_image_to_text.py | 76 +++++++++++++++++++
 3 files changed, 123 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index b7a45fb504f618..0388380c3cfcf6 100755
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -530,6 +530,8 @@
     [
         ("blip", "BlipForConditionalGeneration"),
         ("blip-2", "Blip2ForConditionalGeneration"),
+        ("git", "GitForCausalLM"),
+        ("pix2struct", "Pix2StructForConditionalGeneration"),
         ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
     ]
 )
diff --git a/src/transformers/pipelines/image_to_text.py b/src/transformers/pipelines/image_to_text.py
index f34dad3cef8142..1c082c2ecb38b4 100644
--- a/src/transformers/pipelines/image_to_text.py
+++ b/src/transformers/pipelines/image_to_text.py
@@ -20,6 +20,8 @@
     from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING

 if is_torch_available():
+    import torch
+
     from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING

 logger = logging.get_logger(__name__)
@@ -56,8 +58,13 @@ def __init__(self, *args, **kwargs):
             TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING
         )

-    def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None):
+    def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None):
         forward_kwargs = {}
+        preprocess_params = {}
+
+        if prompt is not None:
+            preprocess_params["prompt"] = prompt
+
         if generate_kwargs is not None:
             forward_kwargs["generate_kwargs"] = generate_kwargs
         if max_new_tokens is not None:
@@ -69,7 +76,7 @@ def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None):
                     " please use only one"
                 )
             forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
-        return {}, forward_kwargs, {}
+        return preprocess_params, forward_kwargs, {}

     def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
         """
@@ -98,9 +105,43 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag
         """
         return super().__call__(images, **kwargs)

-    def preprocess(self, image):
+    def preprocess(self, image, prompt=None):
         image = load_image(image)
-        model_inputs = self.image_processor(images=image, return_tensors=self.framework)
+
+        if prompt is not None:
+            if not isinstance(prompt, str):
+                raise ValueError(
+                    f"Received an invalid text input, got {type(prompt)}, but expected a single string. "
+                    "Note that only a single text prompt can be provided for conditional image-to-text generation."
+                )
+
+            model_type = self.model.config.model_type
+
+            if model_type == "git":
+                model_inputs = self.image_processor(images=image, return_tensors=self.framework)
+                input_ids = self.tokenizer(text=prompt, add_special_tokens=False).input_ids
+                input_ids = [self.tokenizer.cls_token_id] + input_ids
+                input_ids = torch.tensor(input_ids).unsqueeze(0)
+                model_inputs.update({"input_ids": input_ids})
+
+            elif model_type == "pix2struct":
+                model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors=self.framework)
+
+            elif model_type != "vision-encoder-decoder":
+                # vision-encoder-decoder does not support conditional generation
+                model_inputs = self.image_processor(images=image, return_tensors=self.framework)
+                text_inputs = self.tokenizer(prompt, return_tensors=self.framework)
+                model_inputs.update(text_inputs)
+
+            else:
+                raise ValueError(f"Model type {model_type} does not support conditional text generation")
+
+        else:
+            model_inputs = self.image_processor(images=image, return_tensors=self.framework)
+
+        if self.model.config.model_type == "git" and prompt is None:
+            model_inputs["input_ids"] = None
+
         return model_inputs

     def _forward(self, model_inputs, generate_kwargs=None):
diff --git a/tests/pipelines/test_pipelines_image_to_text.py b/tests/pipelines/test_pipelines_image_to_text.py
index 97fe3a398f5813..2a73206f1ba600 100644
--- a/tests/pipelines/test_pipelines_image_to_text.py
+++ b/tests/pipelines/test_pipelines_image_to_text.py
@@ -14,6 +14,8 @@

 import unittest

+import requests
+
 from transformers import MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, is_vision_available
 from transformers.pipelines import pipeline
 from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_vision, slow
@@ -125,6 +127,15 @@ def test_small_model_pt(self):
             ],
         )

+    @require_torch
+    def test_small_model_pt_conditional(self):
+        pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-BlipForConditionalGeneration")
+        image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
+        prompt = "a photo of"
+
+        outputs = pipe(image, prompt=prompt)
+        self.assertTrue(outputs[0]["generated_text"].startswith(prompt))
+
     @slow
     @require_torch
     def test_large_model_pt(self):
@@ -143,6 +154,71 @@ def test_large_model_pt(self):
             ],
         )

+    @slow
+    @require_torch
+    def test_generation_pt_blip(self):
+        pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
+        url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        outputs = pipe(image)
+        self.assertEqual(outputs, [{"generated_text": "a pink pokemon pokemon with a blue shirt and a blue shirt"}])
+
+    @slow
+    @require_torch
+    def test_generation_pt_git(self):
+        pipe = pipeline("image-to-text", model="microsoft/git-base-coco")
+        url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        outputs = pipe(image)
+        self.assertEqual(outputs, [{"generated_text": "a cartoon of a purple character."}])
+
+    @slow
+    @require_torch
+    def test_conditional_generation_pt_blip(self):
+        pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
+        url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+
+        prompt = "a photography of"
+
+        outputs = pipe(image, prompt=prompt)
self.assertEqual(outputs, [{"generated_text": "a photography of a volcano"}]) + + with self.assertRaises(ValueError): + outputs = pipe([image, image], prompt=[prompt, prompt]) + + @slow + @require_torch + def test_conditional_generation_pt_git(self): + pipe = pipeline("image-to-text", model="microsoft/git-base-coco") + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + prompt = "a photo of a" + + outputs = pipe(image, prompt=prompt) + self.assertEqual(outputs, [{"generated_text": "a photo of a tent with a tent and a tent in the background."}]) + + with self.assertRaises(ValueError): + outputs = pipe([image, image], prompt=[prompt, prompt]) + + @slow + @require_torch + def test_conditional_generation_pt_pix2struct(self): + pipe = pipeline("image-to-text", model="google/pix2struct-ai2d-base") + url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg" + image = Image.open(requests.get(url, stream=True).raw) + + prompt = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud" + + outputs = pipe(image, prompt=prompt) + self.assertEqual(outputs, [{"generated_text": "ash cloud"}]) + + with self.assertRaises(ValueError): + outputs = pipe([image, image], prompt=[prompt, prompt]) + @slow @require_tf def test_large_model_tf(self):