Add image text to text pipeline #34170

Open

wants to merge 43 commits into base: main

Changes from all commits (43 commits)
c292fac
Standardize image-text-to-text-models-output
yonigozlan Aug 6, 2024
09e9a8b
nit var name post_process_image_text_to_text udop
yonigozlan Oct 11, 2024
8b635ef
nit fix deprecation warnings
yonigozlan Oct 11, 2024
abaf253
Add image-text-to-text pipeline
yonigozlan Oct 11, 2024
1b3f797
add support for image url in chat template for pipeline
yonigozlan Oct 14, 2024
c6d63a9
Reformat to be fully compatible with chat templates
yonigozlan Oct 14, 2024
4e87977
Add tests chat template
yonigozlan Oct 15, 2024
1cd790d
Fix imports and tests
yonigozlan Oct 15, 2024
ae5eeca
Add pipeline tag
yonigozlan Oct 15, 2024
d229d39
change logic handling of single prompt and multiple images
yonigozlan Oct 15, 2024
7cb7a50
add pipeline mapping to models
yonigozlan Oct 15, 2024
c151d8f
fix batched inference
yonigozlan Oct 15, 2024
5b102b9
fix tests
yonigozlan Oct 15, 2024
9130387
Add manual batching for preprocessing
yonigozlan Oct 16, 2024
fcadd90
Fix outputs with nested images
yonigozlan Oct 16, 2024
0b8152c
Add support for all common processing kwargs
yonigozlan Oct 17, 2024
9f1eb75
Add default padding when multiple text inputs (batch size>1)
yonigozlan Oct 17, 2024
ce6c292
nit change version deprecation warning
yonigozlan Oct 17, 2024
c12d2b7
Add support for text only inference
yonigozlan Oct 17, 2024
143e9ef
add chat_template warnings
yonigozlan Oct 21, 2024
93ddcee
Add pipeline tests and add copied from post process function
yonigozlan Oct 24, 2024
2875f6d
Fix batched pipeline tests
yonigozlan Oct 24, 2024
20545ca
nit
yonigozlan Oct 24, 2024
f947348
Fix pipeline tests blip2
yonigozlan Oct 24, 2024
dc123ca
remove unnecessary max_new_tokens
yonigozlan Oct 24, 2024
6e2db13
revert processing kosmos2 and remove unnecessary max_new_tokens
yonigozlan Oct 24, 2024
f2952e7
fix pipeline tests idefics
yonigozlan Oct 24, 2024
23a08a4
Force try loading processor if pipeline supports it
yonigozlan Oct 25, 2024
248823b
revert load_processor change
yonigozlan Oct 25, 2024
cf6c13e
hardcode loading only processor
yonigozlan Oct 25, 2024
2be9868
remove unnecessary try except
yonigozlan Oct 25, 2024
5dd4f4a
skip imagetexttotext tests for kosmos2 as tiny model causes problems
yonigozlan Oct 25, 2024
f89ae84
Make code clearer
yonigozlan Oct 25, 2024
3ffb6b2
Address review comments
yonigozlan Oct 25, 2024
9449911
remove preprocessing logic from pipeline
yonigozlan Oct 28, 2024
8e5b7fd
fix fuyu
yonigozlan Oct 28, 2024
225950f
add BC resize fuyu
yonigozlan Oct 28, 2024
72e508d
Move post_process_image_text_to_text to ProcessorMixin
yonigozlan Oct 28, 2024
4ebdbbd
add guard in post_process
yonigozlan Oct 28, 2024
5aee30b
fix zero shot object detection pipeline
yonigozlan Oct 29, 2024
a3442c4
add support for generator input in pipeline
yonigozlan Oct 29, 2024
64acc98
nit
yonigozlan Oct 29, 2024
ccbac25
change default image-text-to-text model to llava onevision
yonigozlan Oct 29, 2024
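Taken together, these commits add a new `image-text-to-text` pipeline with chat-template support. A minimal usage sketch follows; the checkpoint, the chat message layout, and the output keys are assumptions based on the commit messages above, not something this diff confirms:

```python
from transformers import pipeline

# "image-text-to-text" is the task tag added by this PR; the llava-onevision checkpoint
# reflects the "change default image-text-to-text model" commit and is an assumption here.
pipe = pipeline("image-text-to-text", model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf")

# Chat-template input with an image URL, per the "add support for image url in chat template" commit.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://example.com/cat.png"},  # placeholder URL
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
outputs = pipe(text=messages, max_new_tokens=32)
print(outputs[0]["generated_text"])
```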
6 changes: 6 additions & 0 deletions docs/source/en/main_classes/pipelines.md
@@ -478,6 +478,12 @@ Pipelines available for multimodal tasks include the following.
- __call__
- all

### ImageTextToTextPipeline

[[autodoc]] ImageTextToTextPipeline
- __call__
- all

### MaskGenerationPipeline

[[autodoc]] MaskGenerationPipeline
6 changes: 6 additions & 0 deletions docs/source/ja/main_classes/pipelines.md
@@ -481,6 +481,12 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
- __call__
- all

### ImageTextToTextPipeline

[[autodoc]] ImageTextToTextPipeline
- __call__
- all

### VisualQuestionAnsweringPipeline

[[autodoc]] VisualQuestionAnsweringPipeline
6 changes: 6 additions & 0 deletions docs/source/zh/main_classes/pipelines.md
@@ -455,6 +455,12 @@ See [`TokenClassificationPipeline`] for all details.
- __call__
- all

### ImageTextToTextPipeline

[[autodoc]] ImageTextToTextPipeline
- __call__
- all

### MaskGenerationPipeline

[[autodoc]] MaskGenerationPipeline
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -868,6 +868,7 @@
"ImageClassificationPipeline",
"ImageFeatureExtractionPipeline",
"ImageSegmentationPipeline",
"ImageTextToTextPipeline",
"ImageToImagePipeline",
"ImageToTextPipeline",
"JsonPipelineDataFormat",
@@ -5794,6 +5795,7 @@
ImageClassificationPipeline,
ImageFeatureExtractionPipeline,
ImageSegmentationPipeline,
ImageTextToTextPipeline,
ImageToImagePipeline,
ImageToTextPipeline,
JsonPipelineDataFormat,
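With both the lazy-import stub and the concrete import in place, the class resolves from the top-level package. A quick sanity check (a sketch, assuming a transformers install that includes this PR):

```python
from transformers import ImageTextToTextPipeline

# The name resolves through the lazy _import_structure entry added above.
print(ImageTextToTextPipeline.__name__)
```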
21 changes: 21 additions & 0 deletions src/transformers/image_utils.py
@@ -385,6 +385,27 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image


def load_images(
    images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]:
    """Loads images, handling different levels of nesting.

    Args:
        images: A single image, a list of images, or a list of lists of images to load.
        timeout: Timeout for loading images.

    Returns:
        A single image, a list of images, or a list of lists of images.
    """
    if isinstance(images, (list, tuple)):
        if len(images) and isinstance(images[0], (list, tuple)):
            return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images]
        else:
            return [load_image(image, timeout=timeout) for image in images]
    else:
        return load_image(images, timeout=timeout)
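To illustrate the three nesting levels the helper accepts (the URLs are placeholders):

```python
from transformers.image_utils import load_images

single = load_images("https://example.com/a.png")  # -> PIL.Image.Image
flat = load_images(["https://example.com/a.png", "https://example.com/b.png"])  # -> List[PIL.Image.Image]
nested = load_images([["https://example.com/a.png"], ["https://example.com/b.png"]])  # -> List[List[PIL.Image.Image]]
```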


def validate_preprocess_arguments(
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
1 change: 1 addition & 0 deletions src/transformers/models/auto/image_processing_auto.py
@@ -114,6 +114,7 @@
("oneformer", ("OneFormerImageProcessor",)),
("owlv2", ("Owlv2ImageProcessor",)),
("owlvit", ("OwlViTImageProcessor",)),
("paligemma", ("SiglipImageProcessor",)),
("perceiver", ("PerceiverImageProcessor",)),
("pix2struct", ("Pix2StructImageProcessor",)),
("pixtral", ("PixtralImageProcessor",)),
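This mapping lets `AutoImageProcessor` resolve PaliGemma checkpoints, which reuse SigLIP's image processor rather than defining their own. A sketch (the checkpoint name is an assumption and the repo is gated on the Hub):

```python
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("google/paligemma-3b-pt-224")
print(type(image_processor).__name__)  # SiglipImageProcessor
```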
17 changes: 17 additions & 0 deletions src/transformers/models/donut/processing_donut.py
@@ -16,6 +16,7 @@
Processor class for Donut.
"""

import logging
import re
import warnings
from contextlib import contextmanager
@@ -24,12 +25,16 @@
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils.deprecation import deprecate_kwarg


class DonutProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}


logger = logging.getLogger(__name__)


class DonutProcessor(ProcessorMixin):
r"""
Constructs a Donut processor which wraps a Donut image processor and an XLMRoBERTa tokenizer into a single
@@ -70,6 +75,7 @@ def __init__(self, image_processor=None, tokenizer=None, **kwargs):
self.current_processor = self.image_processor
self._in_target_context_manager = False

@deprecate_kwarg(old_name="legacy", version="5.0.0")
def __call__(
self,
images: ImageInput = None,
@@ -85,6 +91,15 @@ def __call__(
[`~DonutTokenizer.__call__`]. Please refer to the docstring of the above two methods for more information.
"""
# For backward compatibility
legacy = kwargs.pop("legacy", True)
if legacy:
    # With `add_special_tokens=True`, the performance of Donut is degraded when working with both images and text.
    logger.warning_once(
        "Legacy behavior is being used. The new behavior with legacy=False will be enabled in the future. "
        "In the new behavior, if both images and text are provided, the default value of `add_special_tokens` "
        "will be changed to `False` when calling the tokenizer if `add_special_tokens` is unset."
    )

if self._in_target_context_manager:
    return self.current_processor(images, text, **kwargs)

@@ -100,6 +115,8 @@
if images is not None:
    inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
if text is not None:
    if not legacy and images is not None:
        output_kwargs["text_kwargs"].setdefault("add_special_tokens", False)
    encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])

if text is None:
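A sketch of the behavioral switch this adds (the checkpoint and prompt format are assumptions, not part of the diff):

```python
from PIL import Image
from transformers import DonutProcessor

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")  # assumed checkpoint
image = Image.new("RGB", (1280, 960))
prompt = "<s_docvqa><s_question>What is the title?</s_question><s_answer>"

# Default (legacy=True): tokenizer keeps add_special_tokens=True and the warning above is emitted.
legacy_inputs = processor(images=image, text=prompt, return_tensors="pt")

# New behavior: with both images and text, add_special_tokens defaults to False.
new_inputs = processor(images=image, text=prompt, legacy=False, return_tensors="pt")
```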
3 changes: 2 additions & 1 deletion src/transformers/models/fuyu/image_processing_fuyu.py
@@ -19,7 +19,7 @@

import numpy as np

from ...image_processing_utils import BaseImageProcessor, BatchFeature
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
pad,
resize,
@@ -475,6 +475,7 @@ def preprocess(
input_data_format = infer_channel_dimension_format(batch_images[0][0])

original_image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
size = get_size_dict(size) # for BC

if do_resize:
batch_images = [
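For context, `get_size_dict` normalizes the older tuple/int `size` formats into the dict layout newer processors expect, which is what makes this change backward compatible. A short sketch:

```python
from transformers.image_processing_utils import get_size_dict

print(get_size_dict({"height": 1080, "width": 1920}))  # dict passes through unchanged
print(get_size_dict((1080, 1920)))                     # tuple -> {'height': 1080, 'width': 1920}
```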
30 changes: 28 additions & 2 deletions src/transformers/models/fuyu/processing_fuyu.py
@@ -264,10 +264,10 @@ def _tokenize_prompts_with_image_and_batch(
bos_token = tokenizer.vocab["|ENDOFTEXT|"]
prompts_tokens = [[[bos_token] + x for x in prompt_seq] for prompt_seq in prompts_tokens]
if add_beginning_of_answer_token:
boa = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
beginning_of_answer = tokenizer.vocab[BEGINNING_OF_ANSWER_STRING]
# Only add bbox open token to the last subsequence since that is what will be completed
for token_seq in prompts_tokens:
token_seq[-1].append(boa)
token_seq[-1].append(beginning_of_answer)

# Now we have a list of list of tokens which each list has a different
# size. We want to extend this list to:
@@ -682,6 +682,32 @@

return results

def post_process_image_text_to_text(self, generated_outputs):
    """
    Post-processes the output of `FuyuForConditionalGeneration` to only return the text output.

    Args:
        generated_outputs (`torch.Tensor` or `np.ndarray`):
            The output of the model. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
            containing the token ids of the generated sequences.

    Returns:
        `List[str]`: The decoded text output.
    """
    beginning_of_answer = self.tokenizer.convert_tokens_to_ids(BEGINNING_OF_ANSWER_STRING)
    # get the beginning-of-answer index for each generated sequence,
    # keep only the tokens after it, then pad to a consistent length
    unpadded_output_sequences = [
        seq[(seq == beginning_of_answer).nonzero(as_tuple=True)[0] + 1 :] for seq in generated_outputs
    ]
    max_len = max(len(seq) for seq in unpadded_output_sequences)
    # convert to torch and pad sequences
    padded_output_sequences = torch.full((len(unpadded_output_sequences), max_len), self.pad_token_id)
    for i, seq in enumerate(unpadded_output_sequences):
        padded_output_sequences[i, : len(seq)] = torch.tensor(seq)

    return self.batch_decode(padded_output_sequences, skip_special_tokens=True)
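
A sketch of how the pipeline would call this hook end to end (the checkpoint and prompt format are assumptions; the processor appends the beginning-of-answer token to the prompt, so it appears in the generated sequence):

```python
from PIL import Image
from transformers import FuyuForConditionalGeneration, FuyuProcessor

processor = FuyuProcessor.from_pretrained("adept/fuyu-8b")  # assumed checkpoint
model = FuyuForConditionalGeneration.from_pretrained("adept/fuyu-8b")

image = Image.new("RGB", (300, 200))
inputs = processor(images=image, text="Describe the image:\n", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)

# Only tokens after the beginning-of-answer marker are decoded.
print(processor.post_process_image_text_to_text(generated))
```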

def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
18 changes: 18 additions & 0 deletions src/transformers/models/git/processing_git.py
@@ -16,18 +16,23 @@
Image/Text processor class for GIT
"""

import logging
from typing import List, Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
from ...tokenization_utils_base import PreTokenizedInput, TextInput
from ...utils.deprecation import deprecate_kwarg


class GitProcessorKwargs(ProcessingKwargs, total=False):
_defaults = {}


logger = logging.getLogger(__name__)


class GitProcessor(ProcessorMixin):
r"""
Constructs a GIT processor which wraps a CLIP image processor and a BERT tokenizer into a single processor.
@@ -50,6 +55,7 @@ def __init__(self, image_processor, tokenizer):
super().__init__(image_processor, tokenizer)
self.current_processor = self.image_processor

@deprecate_kwarg(old_name="legacy", version="5.0.0")
def __call__(
self,
images: Optional[ImageInput] = None,
@@ -91,6 +97,14 @@ def __call__(
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
legacy = kwargs.pop("legacy", True)
if legacy:
    logger.warning(
        "Legacy behavior is being used. The new behavior with legacy=False will be enabled in the future. "
        "In the new behavior, if both images and text are provided, the last token (EOS token) "
        "of the input_ids and attention_mask tensors will be removed."
    )

if text is None and images is None:
    raise ValueError("You have to specify either text or images. Both cannot be none.")

@@ -110,6 +124,10 @@
if images is not None:
    image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
    data.update(image_features)
    if not legacy:
        data["input_ids"] = data["input_ids"][:, :-1]
        data["attention_mask"] = data["attention_mask"][:, :-1]

return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors"))

def batch_decode(self, *args, **kwargs):
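A sketch of the effect of the new flag (the checkpoint is an assumption):

```python
from PIL import Image
from transformers import GitProcessor

processor = GitProcessor.from_pretrained("microsoft/git-base")  # assumed checkpoint
image = Image.new("RGB", (224, 224))

old = processor(images=image, text="a photo of", return_tensors="pt")               # legacy default
new = processor(images=image, text="a photo of", legacy=False, return_tensors="pt")

# legacy=False drops the trailing EOS token, so the sequence is one token shorter.
assert new["input_ids"].shape[-1] == old["input_ids"].shape[-1] - 1
```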
15 changes: 15 additions & 0 deletions src/transformers/models/kosmos2/processing_kosmos2.py
@@ -428,6 +428,21 @@ def post_process_generation(self, text, cleanup_and_extract=True):
return clean_text_and_extract_entities_with_bboxes(caption)
return caption

def post_process_image_text_to_text(self, generated_outputs):
    """
    Post-process the output of the model to decode the text.

    Args:
        generated_outputs (`torch.Tensor` or `np.ndarray`):
            The output of the model `generate` function. The output is expected to be a tensor of shape
            `(batch_size, sequence_length)` or `(sequence_length,)`.

    Returns:
        `List[str]`: The decoded text.
    """
    generated_texts = self.batch_decode(generated_outputs, skip_special_tokens=True)
    return [self.post_process_generation(text, cleanup_and_extract=False) for text in generated_texts]

@property
# Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
def model_input_names(self):
16 changes: 16 additions & 0 deletions src/transformers/models/mllama/processing_mllama.py
@@ -342,6 +342,22 @@ def decode(self, *args, **kwargs):
"""
return self.tokenizer.decode(*args, **kwargs)

def post_process_image_text_to_text(self, generated_outputs):
    """
    Post-process the output of the model to decode the text.

    Args:
        generated_outputs (`torch.Tensor` or `np.ndarray`):
            The output of the model `generate` function. The output is expected to be a tensor of shape
            `(batch_size, sequence_length)` or `(sequence_length,)`.

    Returns:
        `List[str]`: The decoded text.
    """
    return self.tokenizer.batch_decode(
        generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
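The hook above is a thin wrapper over `batch_decode`, and the flag choice can be illustrated with any tokenizer (GPT-2 here, purely for demonstration):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
ids = tokenizer("hello   world", return_tensors="pt").input_ids
# skip_special_tokens drops e.g. EOS/padding; clean_up_tokenization_spaces=False keeps spacing as generated.
print(tokenizer.batch_decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
```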
21 changes: 21 additions & 0 deletions src/transformers/models/pix2struct/processing_pix2struct.py
@@ -16,11 +16,13 @@
Processor class for Pix2Struct.
"""

import logging
from typing import List, Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils.deprecation import deprecate_kwarg


class Pix2StructImagesKwargs(ImagesKwargs, total=False):
@@ -48,6 +50,9 @@ class Pix2StructProcessorKwargs(ProcessingKwargs, total=False):
}


logger = logging.getLogger(__name__)


class Pix2StructProcessor(ProcessorMixin):
r"""
Constructs a PIX2STRUCT processor which wraps a BERT tokenizer and PIX2STRUCT image processor into a single
@@ -71,6 +76,7 @@ def __init__(self, image_processor, tokenizer):
tokenizer.return_token_type_ids = False
super().__init__(image_processor, tokenizer)

@deprecate_kwarg(old_name="legacy", version="5.0.0")
def __call__(
self,
images=None,
@@ -85,6 +91,14 @@

Please refer to the docstring of the above two methods for more information.
"""
legacy = kwargs.pop("legacy", True)
if legacy:
    logger.warning(
        "Legacy behavior is being used. The new behavior with legacy=False will be enabled in the future. "
        "In the new behavior, if both images and text are provided, the image processor is not a VQA processor, "
        "and `add_special_tokens` is unset, the default value of `add_special_tokens` will be changed to `False` "
        "when calling the tokenizer."
    )

if images is None and text is None:
    raise ValueError("You have to specify either images or text.")

@@ -93,8 +107,12 @@
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
add_special_tokens = output_kwargs["text_kwargs"].pop("add_special_tokens", None)
# Get only text
if images is None and not self.image_processor.is_vqa:
    output_kwargs["text_kwargs"]["add_special_tokens"] = (
        add_special_tokens if add_special_tokens is not None else True
    )
    self.current_processor = self.tokenizer
    text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
    return text_encoding
@@ -108,6 +126,9 @@
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])

if text is not None and not self.image_processor.is_vqa:
    output_kwargs["text_kwargs"]["add_special_tokens"] = (
        add_special_tokens if add_special_tokens is not None else legacy
    )
    text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])

    if "attention_mask" in text_encoding:
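The resolution order for `add_special_tokens` in both branches above boils down to: an explicit user value always wins, otherwise the default depends on the branch (`True` for text-only, the `legacy` flag for image+text). A standalone sketch of that rule:

```python
def resolve_add_special_tokens(user_value, fallback):
    # Mirrors `add_special_tokens if add_special_tokens is not None else fallback` above.
    return user_value if user_value is not None else fallback

assert resolve_add_special_tokens(None, True) is True    # text-only branch default
assert resolve_add_special_tokens(None, False) is False  # image+text branch with legacy=False
assert resolve_add_special_tokens(False, True) is False  # explicit user choice always wins
```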