Skip to content

Commit 030027d

Browse files
authored
Merge pull request #2 from RUFFY-369/imagebind_hf
Imagebind hf changes
2 parents 8d717d0 + e2f3064 commit 030027d

File tree

5 files changed

+571
-26
lines changed

5 files changed

+571
-26
lines changed

docs/source/en/model_doc/imagebind.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ The abstract from the paper is the following:
2222

2323
*We present ImageBind, an approach to learn a joint embedding across six different modalities - images, text, audio, depth, thermal, and IMU data. We show that all combinations of paired data are not necessary to train such a joint embedding, and only image-paired data is sufficient to bind the modalities together. ImageBind can leverage recent large scale vision-language models, and extends their zero-shot capabilities to new modalities just by using their natural pairing with images. It enables novel emergent applications 'out-of-the-box' including cross-modal retrieval, composing modalities with arithmetic, cross-modal detection and generation. The emergent capabilities improve with the strength of the image encoder and we set a new state-of-the-art on emergent zero-shot recognition tasks across modalities, outperforming specialist supervised models. Finally, we show strong few-shot recognition results outperforming prior work, and that ImageBind serves as a new way to evaluate vision models for visual and non-visual tasks.*
2424

25-
This model was contributed by [EduardoPacheco](https://huggingface.co/EduardoPacheco) and [dg845](https://huggingface.co/dg845) and [shehan97](https://huggingface.co/shehan97).
25+
This model was contributed by [EduardoPacheco](https://huggingface.co/EduardoPacheco), [ruffy369](https://huggingface.co/ruffy369), [dg845](https://huggingface.co/dg845), and [shehan97](https://huggingface.co/shehan97).
2626
The original code can be found [here](https://github.com/facebookresearch/ImageBind).
2727

2828
## Usage tips

src/transformers/models/imagebind/image_processing_imagebind.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def __init__(
365365
self.fps = fps
366366
self._valid_processor_keys = [
367367
"images",
368+
"videos",
368369
"do_resize",
369370
"size",
370371
"resample",
@@ -379,6 +380,7 @@ def __init__(
379380
"do_chunk",
380381
"chunk_duration",
381382
"num_chunks",
383+
"num_frames_per_chunk",
382384
"fps",
383385
"return_tensors",
384386
"data_format",

src/transformers/models/imagebind/processing_imagebind.py

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,38 @@
1515
Image/Text/Audio processor class for ImageBind
1616
"""
1717

18-
from ...processing_utils import ProcessingKwargs, ProcessorMixin
19-
from ...tokenization_utils_base import BatchEncoding
20-
18+
from typing import List, Optional, Union
19+
20+
try:
21+
from typing import Unpack
22+
except ImportError:
23+
from typing_extensions import Unpack
24+
25+
from ...image_utils import ImageInput
26+
from ...processing_utils import AudioKwargs, ImagesKwargs, ProcessingKwargs, ProcessorMixin
27+
from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput
28+
29+
class ImageBindProcessorImagesKwargs(ImagesKwargs, total=False):
30+
do_convert_rgb: bool = None
31+
do_chunk: bool = None
32+
chunk_duration: float = None
33+
num_chunks: int = None
34+
num_frames_per_chunk: int = None
35+
fps: int = None
36+
37+
class ImageBindProcessorAudioKwargs(AudioKwargs, total=False):
38+
do_normalize: Optional[bool] = None
39+
mean: Optional[float] = None
40+
std: Optional[float] = None
41+
do_chunk: Optional[bool] = None
42+
chunk_duration: Optional[float] = None
43+
num_chunks: Optional[int] = None
2144

2245
class ImageBindProcessorKwargs(ProcessingKwargs, total=False):
2346
# see processing_utils.ProcessingKwargs documentation for usage.
24-
_defaults = {
25-
"text_kwargs": {
26-
"padding": "max_length",
27-
"max_length": 64,
28-
},
29-
}
47+
images_kwargs: ImageBindProcessorImagesKwargs
48+
audio_kwargs: ImageBindProcessorAudioKwargs
49+
_defaults = {}
3050

3151

3252
class ImageBindProcessor(ProcessorMixin):
@@ -53,23 +73,29 @@ class ImageBindProcessor(ProcessorMixin):
5373
def __init__(self, image_processor, tokenizer, feature_extractor):
5474
super().__init__(image_processor, tokenizer, feature_extractor)
5575

56-
def __call__(self, images=None, text=None, audio=None, return_tensors=None, **kwargs):
76+
def __call__(
77+
self,
78+
images=None,
79+
text=None,
80+
audio=None,
81+
**kwargs: Unpack[ImageBindProcessorKwargs],
82+
) -> BatchEncoding:
5783
"""
5884
Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
5985
and `kwargs` arguments to ImageBindTokenizerFast's [`~ImageBindTokenizerFast.__call__`] if `text` is not `None` to encode
6086
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
6187
ImageBindImageProcessor's [`~ImageBindImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
6288
of the above two methods for more information.
6389
Args:
64-
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
90+
images (`ImageInput`, *optional*):
6591
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
6692
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
6793
number of channels, H and W are image height and width.
68-
text (`str`, `List[str]`, `List[List[str]]`):
94+
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*):
6995
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
7096
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
7197
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
72-
audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, `List[List[List[float]]]`):
98+
audio (`AudioInput`, `List[float]`, `List[List[float]]`, `List[List[List[float]]]`):
7399
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of numpy
74100
arrays or a (possibly nested) list of float values. The supported input types are as follows:
75101
@@ -78,12 +104,6 @@ def __call__(self, images=None, text=None, audio=None, return_tensors=None, **kw
78104
- batched with clips: `List[List[List[float]]]`, `List[List[np.ndarray]]` (`ndim=1`), `List[np.ndarray]` (`ndim=2`), `np.ndarray` (`ndim=3`)
79105
80106
The input will always be interpreted as mono channel audio, not stereo, i.e. a single float per timestep.
81-
return_tensors (`str` or [`~utils.TensorType`], *optional*):
82-
If set, will return tensors of a particular framework. Acceptable values are:
83-
- `'tf'`: Return TensorFlow `tf.constant` objects.
84-
- `'pt'`: Return PyTorch `torch.Tensor` objects.
85-
- `'np'`: Return NumPy `np.ndarray` objects.
86-
- `'jax'`: Return JAX `jnp.ndarray` objects.
87107
Returns:
88108
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
89109
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
@@ -97,21 +117,27 @@ def __call__(self, images=None, text=None, audio=None, return_tensors=None, **kw
97117
if text is None and images is None and audio is None:
98118
raise ValueError("You have to specify either text, images or audio. Both cannot be none.")
99119

120+
output_kwargs = self._merge_kwargs(
121+
ImageBindProcessorKwargs,
122+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
123+
**kwargs,
124+
)
125+
100126
data = {}
101127

102128
if text is not None:
103-
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
129+
encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
104130
data.update(encoding)
105131

106132
if images is not None:
107-
image_features = self.image_processor(images, return_tensors=return_tensors)
133+
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
108134
data.update(image_features)
109135

110136
if audio is not None:
111-
audio_features = self.feature_extractor(audio, return_tensors=return_tensors)
137+
audio_features = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
112138
data.update(audio_features)
113139

114-
return BatchEncoding(data=data, tensor_type=return_tensors)
140+
return BatchEncoding(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors"))
115141

116142
def batch_decode(self, *args, **kwargs):
117143
"""

0 commit comments

Comments
 (0)