Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uniformize kwargs for processors - GroundingDINO #31964

Merged
merged 23 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1580,7 +1580,7 @@ def _set_gradient_checkpointing(self, module, value=False):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.

Indices can be obtained using [`AutoTokenizer`]. See [`GroundingDinoTokenizer.__call__`] for details.
Indices can be obtained using [`AutoTokenizer`]. See [`BertTokenizer.__call__`] for details.

token_type_ids (`torch.LongTensor` of shape `(batch_size, text_sequence_length)`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
Expand Down
89 changes: 53 additions & 36 deletions src/transformers/models/grounding_dino/processing_grounding_dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,32 @@
Processor class for Grounding DINO.
"""

from typing import List, Optional, Tuple, Union
import pathlib
import sys
from typing import Dict, List, Optional, Tuple, Union

from ...image_processing_utils import BatchFeature
from ...image_transforms import center_to_corners_format
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...image_utils import AnnotationFormat, ImageInput
from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin


if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack

from ...tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput
from ...utils import TensorType, is_torch_available


if is_torch_available():
import torch


AnnotationType = Dict[str, Union[int, str, List[Dict]]]


def get_phrases_from_posmap(posmaps, input_ids):
"""Get token ids of phrases from posmaps and input_ids.

Expand All @@ -56,6 +68,31 @@ def get_phrases_from_posmap(posmaps, input_ids):
return token_ids


class GroundingDinoImagesKwargs(ImagesKwargs, total=False):
annotations: Optional[Union[AnnotationType, List[AnnotationType]]]
return_segmentation_masks: Optional[bool]
masks_path: Optional[Union[str, pathlib.Path]]
do_convert_annotations: Optional[bool]
format: Optional[Union[str, AnnotationFormat]]


class GroundingDinoProcessorKwargs(ProcessingKwargs, total=False):
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
images_kwargs: GroundingDinoImagesKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
"padding": False,
"stride": 0,
"return_overflowing_tokens": False,
"return_special_tokens_mask": False,
"return_offsets_mapping": False,
"return_token_type_ids": True,
"return_length": False,
"verbose": True,
}
}


class GroundingDinoProcessor(ProcessorMixin):
r"""
Constructs a Grounding DINO processor which wraps a Deformable DETR image processor and a BERT tokenizer into a
Expand Down Expand Up @@ -83,21 +120,9 @@ def __call__(
self,
images: ImageInput = None,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
return_overflowing_tokens: bool = False,
return_special_tokens_mask: bool = False,
return_offsets_mapping: bool = False,
return_token_type_ids: bool = True,
return_length: bool = False,
verbose: bool = True,
return_tensors: Optional[Union[str, TensorType]] = None,
**kwargs,
audio=None,
videos=None,
**kwargs: Unpack[GroundingDinoProcessorKwargs],
) -> BatchEncoding:
"""
This method uses [`GroundingDinoImageProcessor.__call__`] method to prepare image(s) for the model, and
Expand All @@ -106,32 +131,24 @@ def __call__(
Please refer to the docstring of the above two methods for more information.
"""
if images is None and text is None:
raise ValueError("You have to specify either images or text.")
raise ValueError("You must specify either text or images.")

output_kwargs = self._merge_kwargs(
GroundingDinoProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

# Get only text
if images is not None:
encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)
encoding_image_processor = self.image_processor(images, **output_kwargs["images_kwargs"])
else:
encoding_image_processor = BatchFeature()

if text is not None:
text_encoding = self.tokenizer(
text=text,
add_special_tokens=add_special_tokens,
padding=padding,
truncation=truncation,
max_length=max_length,
stride=stride,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_overflowing_tokens=return_overflowing_tokens,
return_special_tokens_mask=return_special_tokens_mask,
return_offsets_mapping=return_offsets_mapping,
return_token_type_ids=return_token_type_ids,
return_length=return_length,
verbose=verbose,
return_tensors=return_tensors,
**kwargs,
**output_kwargs["text_kwargs"],
)
else:
text_encoding = BatchEncoding()
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/processing_utils.py
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -736,12 +736,12 @@ def _merge_kwargs(
The order of operations is as follows:
1) kwargs passed as before have highest priority to preserve BC.
```python
high_priority_kwargs = {"crop_size" = (224, 224), "padding" = "max_length"}
high_priority_kwargs = {"crop_size" = {"height": 222, "width": 222}, "padding" = "max_length"}
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
processor(..., **high_priority_kwargs)
```
2) kwargs passed as modality-specific kwargs have second priority. This is the recommended API.
```python
processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": (224, 224)}})
processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": {"height": 222, "width": 222}}})
```
3) kwargs passed during instantiation of a modality processor have fourth priority.
```python
Expand Down
188 changes: 187 additions & 1 deletion tests/models/grounding_dino/test_processor_grounding_dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import IMAGE_PROCESSOR_NAME, is_torch_available, is_vision_available

from ...test_processing_common import ProcessorTesterMixin


if is_torch_available():
import torch
Expand All @@ -40,7 +42,10 @@

@require_torch
@require_vision
class GroundingDinoProcessorTest(unittest.TestCase):
class GroundingDinoProcessorTest(ProcessorTesterMixin, unittest.TestCase):
from_pretrained_id = "IDEA-Research/grounding-dino-base"
processor_class = GroundingDinoProcessor

def setUp(self):
self.tmpdirname = tempfile.mkdtemp()

Expand All @@ -63,6 +68,13 @@ def setUp(self):
with open(self.image_processor_file, "w", encoding="utf-8") as fp:
json.dump(image_processor_map, fp)

image_processor = GroundingDinoImageProcessor()
tokenizer = BertTokenizer.from_pretrained(self.from_pretrained_id)

processor = GroundingDinoProcessor(image_processor, tokenizer)

processor.save_pretrained(self.tmpdirname)

self.batch_size = 7
self.num_queries = 5
self.embed_dim = 5
Expand Down Expand Up @@ -251,3 +263,177 @@ def test_model_input_names(self):
inputs = processor(text=input_str, images=image_input)

self.assertListEqual(list(inputs.keys()), processor.model_input_names)

@require_torch
@require_vision
def test_image_processor_defaults_preserved_by_image_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size={"height": 234, "width": 234})
tokenizer = self.get_component("tokenizer", max_length=117)

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input)
self.assertEqual(len(inputs["pixel_values"][0][0]), 234)

@require_vision
@require_torch
def test_kwargs_overrides_default_tokenizer_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117)

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(
text=input_str, images=image_input, return_tensors="pt", padding="max_length", max_length=112
)
self.assertEqual(len(inputs["input_ids"][0]), 112)

@require_vision
@require_torch
def test_tokenizer_defaults_preserved_by_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer", max_length=117)

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, return_tensors="pt", padding="max_length")
self.assertEqual(len(inputs["input_ids"][0]), 117)

@require_torch
@require_vision
def test_kwargs_overrides_default_image_processor_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor", size=(234, 234))
tokenizer = self.get_component("tokenizer", max_length=117)

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

inputs = processor(text=input_str, images=image_input, size=[224, 224])
SangbumChoi marked this conversation as resolved.
Show resolved Hide resolved
self.assertEqual(len(inputs["pixel_values"][0][0]), 224)

@require_torch
@require_vision
def test_structured_kwargs_nested(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.skip_processor_without_typed_kwargs(processor)

self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)

@require_torch
@require_vision
def test_structured_kwargs_nested_from_dict(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")

image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)
input_str = "lower newer"
image_input = self.prepare_image_inputs()

# Define the kwargs for each modality
all_kwargs = {
"common_kwargs": {"return_tensors": "pt"},
"images_kwargs": {"size": {"height": 214, "width": 214}},
"text_kwargs": {"padding": "max_length", "max_length": 76},
}

inputs = processor(text=input_str, images=image_input, **all_kwargs)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 76)

@require_torch
@require_vision
def test_unstructured_kwargs(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")

processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = "lower newer"
image_input = self.prepare_image_inputs()
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
size={"height": 214, "width": 214},
padding="max_length",
max_length=76,
)

self.assertEqual(inputs["pixel_values"].shape[2], 214)
self.assertEqual(len(inputs["input_ids"][0]), 76)

@require_torch
@require_vision
def test_unstructured_kwargs_batched(self):
if "image_processor" not in self.processor_class.attributes:
self.skipTest(f"image_processor attribute not present in {self.processor_class}")
image_processor = self.get_component("image_processor")
tokenizer = self.get_component("tokenizer")
if not tokenizer.pad_token:
tokenizer.pad_token = "[TEST_PAD]"
processor = self.processor_class(tokenizer=tokenizer, image_processor=image_processor)
self.skip_processor_without_typed_kwargs(processor)

input_str = ["lower newer", "upper older longer string"]
image_input = self.prepare_image_inputs() * 2
inputs = processor(
text=input_str,
images=image_input,
return_tensors="pt",
crop_size={"height": 214, "width": 214},
size={"height": 214, "width": 214},
padding="longest",
max_length=76,
)
self.assertEqual(inputs["pixel_values"].shape[2], 214)

self.assertEqual(len(inputs["input_ids"][0]), 6)
Loading