Skip to content

Commit

Permalink
[image-to-text pipeline] Add conditional text support + GIT (huggingf…
Browse files Browse the repository at this point in the history
…ace#23362)

* First draft

* Remove print statements

* Add conditional generation

* Add more tests

* Remove scripts

* Remove BLIP specific linkes

* Add support for pix2struct

* Add fast test

* Address comment

* Fix style
  • Loading branch information
NielsRogge authored and sheonhan committed Jun 1, 2023
1 parent dce6509 commit 0b69d43
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,8 @@
[
("blip", "BlipForConditionalGeneration"),
("blip-2", "Blip2ForConditionalGeneration"),
("git", "GitForCausalLM"),
("pix2struct", "Pix2StructForConditionalGeneration"),
("vision-encoder-decoder", "VisionEncoderDecoderModel"),
]
)
Expand Down
49 changes: 45 additions & 4 deletions src/transformers/pipelines/image_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING

if is_torch_available():
import torch

from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -56,8 +58,13 @@ def __init__(self, *args, **kwargs):
TF_MODEL_FOR_VISION_2_SEQ_MAPPING if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING
)

def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None):
def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None):
forward_kwargs = {}
preprocess_params = {}

if prompt is not None:
preprocess_params["prompt"] = prompt

if generate_kwargs is not None:
forward_kwargs["generate_kwargs"] = generate_kwargs
if max_new_tokens is not None:
Expand All @@ -69,7 +76,7 @@ def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None):
" please use only one"
)
forward_kwargs["generate_kwargs"]["max_new_tokens"] = max_new_tokens
return {}, forward_kwargs, {}
return preprocess_params, forward_kwargs, {}

def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Image"]], **kwargs):
"""
Expand Down Expand Up @@ -98,9 +105,43 @@ def __call__(self, images: Union[str, List[str], "Image.Image", List["Image.Imag
"""
return super().__call__(images, **kwargs)

def preprocess(self, image):
def preprocess(self, image, prompt=None):
image = load_image(image)
model_inputs = self.image_processor(images=image, return_tensors=self.framework)

if prompt is not None:
if not isinstance(prompt, str):
raise ValueError(
f"Received an invalid text input, got - {type(prompt)} - but expected a single string. "
"Note also that one single text can be provided for conditional image to text generation."
)

model_type = self.model.config.model_type

if model_type == "git":
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
input_ids = self.tokenizer(text=prompt, add_special_tokens=False).input_ids
input_ids = [self.tokenizer.cls_token_id] + input_ids
input_ids = torch.tensor(input_ids).unsqueeze(0)
model_inputs.update({"input_ids": input_ids})

elif model_type == "pix2struct":
model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors=self.framework)

elif model_type != "vision-encoder-decoder":
# vision-encoder-decoder does not support conditional generation
model_inputs = self.image_processor(images=image, return_tensors=self.framework)
text_inputs = self.tokenizer(prompt, return_tensors=self.framework)
model_inputs.update(text_inputs)

else:
raise ValueError(f"Model type {model_type} does not support conditional text generation")

else:
model_inputs = self.image_processor(images=image, return_tensors=self.framework)

if self.model.config.model_type == "git" and prompt is None:
model_inputs["input_ids"] = None

return model_inputs

def _forward(self, model_inputs, generate_kwargs=None):
Expand Down
76 changes: 76 additions & 0 deletions tests/pipelines/test_pipelines_image_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import unittest

import requests

from transformers import MODEL_FOR_VISION_2_SEQ_MAPPING, TF_MODEL_FOR_VISION_2_SEQ_MAPPING, is_vision_available
from transformers.pipelines import pipeline
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_vision, slow
Expand Down Expand Up @@ -125,6 +127,15 @@ def test_small_model_pt(self):
],
)

@require_torch
def test_small_model_pt_conditional(self):
pipe = pipeline("image-to-text", model="hf-internal-testing/tiny-random-BlipForConditionalGeneration")
image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
prompt = "a photo of"

outputs = pipe(image, prompt=prompt)
self.assertTrue(outputs[0]["generated_text"].startswith(prompt))

@slow
@require_torch
def test_large_model_pt(self):
Expand All @@ -143,6 +154,71 @@ def test_large_model_pt(self):
],
)

@slow
@require_torch
def test_generation_pt_blip(self):
pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
image = Image.open(requests.get(url, stream=True).raw)

outputs = pipe(image)
self.assertEqual(outputs, [{"generated_text": "a pink pokemon pokemon with a blue shirt and a blue shirt"}])

@slow
@require_torch
def test_generation_pt_git(self):
pipe = pipeline("image-to-text", model="microsoft/git-base-coco")
url = "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/pokemon.png"
image = Image.open(requests.get(url, stream=True).raw)

outputs = pipe(image)
self.assertEqual(outputs, [{"generated_text": "a cartoon of a purple character."}])

@slow
@require_torch
def test_conditional_generation_pt_blip(self):
pipe = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)

prompt = "a photography of"

outputs = pipe(image, prompt=prompt)
self.assertEqual(outputs, [{"generated_text": "a photography of a volcano"}])

with self.assertRaises(ValueError):
outputs = pipe([image, image], prompt=[prompt, prompt])

@slow
@require_torch
def test_conditional_generation_pt_git(self):
pipe = pipeline("image-to-text", model="microsoft/git-base-coco")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)

prompt = "a photo of a"

outputs = pipe(image, prompt=prompt)
self.assertEqual(outputs, [{"generated_text": "a photo of a tent with a tent and a tent in the background."}])

with self.assertRaises(ValueError):
outputs = pipe([image, image], prompt=[prompt, prompt])

@slow
@require_torch
def test_conditional_generation_pt_pix2struct(self):
pipe = pipeline("image-to-text", model="google/pix2struct-ai2d-base")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)

prompt = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"

outputs = pipe(image, prompt=prompt)
self.assertEqual(outputs, [{"generated_text": "ash cloud"}])

with self.assertRaises(ValueError):
outputs = pipe([image, image], prompt=[prompt, prompt])

@slow
@require_tf
def test_large_model_tf(self):
Expand Down

0 comments on commit 0b69d43

Please sign in to comment.