Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
a9a5539
chore:add func and classes to get vid clips from user given paths
RUFFY-369 Aug 4, 2024
d1c33d0
chore:update uniform_chunk_sampling()
RUFFY-369 Aug 4, 2024
53fe080
chore:change chunk duration val and type
RUFFY-369 Aug 4, 2024
99306ab
chore:update uniform_temporal_subsample()
RUFFY-369 Aug 4, 2024
082be8b
chore:update video transforms and few nits
RUFFY-369 Aug 4, 2024
1d6c4ea
fix:bug in image processor call on video paths
RUFFY-369 Aug 4, 2024
64d6c38
chore:revert to original to test for unmatched outputs
RUFFY-369 Aug 6, 2024
558f544
chore:make transformers compliant and few nits
RUFFY-369 Aug 7, 2024
9314a57
style:make fixup
RUFFY-369 Aug 7, 2024
79c4089
fix:make fix copies
RUFFY-369 Aug 7, 2024
f64778d
chore:resolve necessary conflicts
RUFFY-369 Aug 7, 2024
02cb2ab
Merge remote-tracking branch 'imagebind/adding-imagebind' into imageb…
RUFFY-369 Aug 24, 2024
4d0edbf
resolve merge/change conflicts by pull
RUFFY-369 Aug 24, 2024
bc8821f
chore:make everything similar about files
RUFFY-369 Aug 24, 2024
fbbb108
test:add image processor tests
RUFFY-369 Aug 26, 2024
4099c8c
fix:failing image processor tests
RUFFY-369 Aug 26, 2024
2d4cb59
chore:add contributor name for video output matching and image proces…
RUFFY-369 Aug 26, 2024
a283626
test:add Processor kwargs and its test
RUFFY-369 Aug 27, 2024
04a9e07
fix:ProcessorTesterMixin test failures
RUFFY-369 Aug 27, 2024
4b7f5a8
fix:test failure for len of input ids
RUFFY-369 Aug 27, 2024
e2f3064
chore:add custom image and audio kwargs class and some nits
RUFFY-369 Aug 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/imagebind.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ The abstract from the paper is the following:

*We present ImageBind, an approach to learn a joint embedding across six different modalities - images, text, audio, depth, thermal, and IMU data. We show that all combinations of paired data are not necessary to train such a joint embedding, and only image-paired data is sufficient to bind the modalities together. ImageBind can leverage recent large scale vision-language models, and extends their zero-shot capabilities to new modalities just by using their natural pairing with images. It enables novel emergent applications 'out-of-the-box' including cross-modal retrieval, composing modalities with arithmetic, cross-modal detection and generation. The emergent capabilities improve with the strength of the image encoder and we set a new state-of-the-art on emergent zero-shot recognition tasks across modalities, outperforming specialist supervised models. Finally, we show strong few-shot recognition results outperforming prior work, and that ImageBind serves as a new way to evaluate vision models for visual and non-visual tasks.*

This model was contributed by [EduardoPacheco](https://huggingface.co/EduardoPacheco) and [dg845](https://huggingface.co/dg845) and [shehan97](https://huggingface.co/shehan97).
This model was contributed by [EduardoPacheco](https://huggingface.co/EduardoPacheco) and [ruffy369](https://huggingface.co/ruffy369) and [dg845](https://huggingface.co/dg845) and [shehan97](https://huggingface.co/shehan97).
The original code can be found [here](https://github.com/facebookresearch/ImageBind).

## Usage tips
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def __init__(
self.fps = fps
self._valid_processor_keys = [
"images",
"videos",
"do_resize",
"size",
"resample",
Expand All @@ -379,6 +380,7 @@ def __init__(
"do_chunk",
"chunk_duration",
"num_chunks",
"num_frames_per_chunk",
"fps",
"return_tensors",
"data_format",
Expand Down
72 changes: 49 additions & 23 deletions src/transformers/models/imagebind/processing_imagebind.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,38 @@
Image/Text processor class for ImageBind
"""

from ...processing_utils import ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import BatchEncoding

from typing import List, Optional, Union

try:
from typing import Unpack
except ImportError:
from typing_extensions import Unpack

from ...image_utils import ImageInput
from ...processing_utils import AudioKwargs, ImagesKwargs, ProcessingKwargs, ProcessorMixin
from ...tokenization_utils_base import AudioInput, BatchEncoding, PreTokenizedInput, TextInput

class ImageBindProcessorImagesKwargs(ImagesKwargs, total=False):
do_convert_rgb: bool = None
do_chunk: bool = None
chunk_duration: float = None
num_chunks: int = None
num_frames_per_chunk: int = None
fps: int = None

class ImageBindProcessorAudioKwargs(AudioKwargs, total=False):
do_normalize: Optional[bool] = None
mean: Optional[float] = None
std: Optional[float] = None
do_chunk: Optional[bool] = None
chunk_duration: Optional[float] = None
num_chunks: Optional[int] = None

class ImageBindProcessorKwargs(ProcessingKwargs, total=False):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we need to create custom AudioKwargs and ImageKwargs as we have arguments that are not in the default classes

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, need to discuss the custom kwargs so left them out. i will push those changes

# see processing_utils.ProcessingKwargs documentation for usage.
_defaults = {
"text_kwargs": {
"padding": "max_length",
"max_length": 64,
},
}
images_kwargs: ImageBindProcessorImagesKwargs
audio_kwargs: ImageBindProcessorAudioKwargs
_defaults = {}


class ImageBindProcessor(ProcessorMixin):
Expand All @@ -53,23 +73,29 @@ class ImageBindProcessor(ProcessorMixin):
def __init__(self, image_processor, tokenizer, feature_extractor):
super().__init__(image_processor, tokenizer, feature_extractor)

def __call__(self, images=None, text=None, audio=None, return_tensors=None, **kwargs):
def __call__(
self,
images=None,
text=None,
audio=None,
**kwargs: Unpack[ImageBindProcessorKwargs],
) -> BatchEncoding:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to ImageBindTokenizerFast's [`~ImageBindTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
ImageBindImageProcessor's [`~ImageBindImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
of the above two methods for more information.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
images (`ImageInput`, *optional*):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
text (`str`, `List[str]`, `List[List[str]]`):
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`, *optional*):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
audio (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`, `List[List[List[float]]]`):
audio (`AudioInput`, `List[float]`, `List[List[float]]`, `List[List[List[float]]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of numpy
arrays or a (possibly nested) list of float values. The supported input types are as follows:

Expand All @@ -78,12 +104,6 @@ def __call__(self, images=None, text=None, audio=None, return_tensors=None, **kw
- batched with clips: `List[List[List[float]]]`, `List[List[np.ndarray]]` (`ndim=1`), `List[np.ndarray]` (`ndim=2`), np.ndarray (`ndim=3`)

The input will always be interpreted as mono channel audio, not stereo, i.e. a single float per timestep.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
Expand All @@ -97,21 +117,27 @@ def __call__(self, images=None, text=None, audio=None, return_tensors=None, **kw
if text is None and images is None and audio is None:
raise ValueError("You have to specify either text, images or audio. Both cannot be none.")

output_kwargs = self._merge_kwargs(
ImageBindProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)

data = {}

if text is not None:
encoding = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
encoding = self.tokenizer(text, **output_kwargs["text_kwargs"])
data.update(encoding)

if images is not None:
image_features = self.image_processor(images, return_tensors=return_tensors)
image_features = self.image_processor(images, **output_kwargs["images_kwargs"])
data.update(image_features)

if audio is not None:
audio_features = self.feature_extractor(audio, return_tensors=return_tensors)
audio_features = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
data.update(audio_features)

return BatchEncoding(data=data, tensor_type=return_tensors)
return BatchEncoding(data=data, tensor_type=output_kwargs["common_kwargs"].get("return_tensors"))

def batch_decode(self, *args, **kwargs):
"""
Expand Down
Loading