Commit f690a2a

zucchini-nlp, qubvel, and github-actions[bot] authored
[video processors] decode only sampled videos -> less RAM and faster processing (#39600)
* draft update two models for now
* batch update all VLMs first
* update some more image processors
* update
* fix a few tests
* just make CI green for now
* fix copies
* update once more
* update
* unskip the test
* fix these two
* fix torchcodec audio loading
* maybe
* yay, i fixed torchcodec installation and now can actually test it
* fix copies deepseek
* make sure the metadata is returrned when users request it
* add docs
* update
* fixup
* Update src/transformers/audio_utils.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
* Update src/transformers/models/glm4v/video_processing_glm4v.py Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
* update
* what if we set some metadata attr to `None`
* fix CI
* fix one test
* fix 4 channel test
* fix glm timestemps
* rebase gone wrong
* raise warning once
* fixup
* typo
* fix copies
* ifx smolvlm test
* this is why torch's official benchmark was faster, set threads to `0`
* Apply style fixes
---------
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 64ae6e6 commit f690a2a

74 files changed: +753 -605 lines changed

docs/source/en/main_classes/image_processor.md

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,7 @@ rendered properly in your Markdown viewer.
1616

1717
# Image Processor
1818

19-
An image processor is in charge of preparing input features for vision models and post processing their outputs. This includes transformations such as resizing, normalization, and conversion to Numpy and PyTorch tensors. It may also include model specific post-processing such as converting logits to segmentation masks.
20-
19+
An image processor is in charge of optionally loading images, preparing input features for vision models, and post-processing their outputs. This includes transformations such as resizing, normalization, and conversion to PyTorch and NumPy tensors. It may also include model-specific post-processing such as converting logits to segmentation masks.
2120
Fast image processors are available for a few models and more will be added in the future. They are based on the [torchvision](https://pytorch.org/vision/stable/index.html) library and provide a significant speed-up, especially when processing on GPU.
2221
They have the same API as the base image processors and can be used as drop-in replacements.
2322
To use a fast image processor, you need to install the `torchvision` library, and set the `use_fast` argument to `True` when instantiating the image processor:

docs/source/en/main_classes/video_processor.md

Lines changed: 42 additions & 2 deletions
@@ -14,10 +14,9 @@ rendered properly in your Markdown viewer.
1414
1515
-->
1616

17-
1817
# Video Processor
1918

20-
A **Video Processor** is a utility responsible for preparing input features for video models, as well as handling the post-processing of their outputs. It provides transformations such as resizing, normalization, and conversion into PyTorch.
19+
A **Video Processor** is a utility responsible for preparing input features for video models, as well as handling the post-processing of their outputs. It provides transformations such as resizing, normalization, and conversion into PyTorch tensors. Along with these transformations, the `VideoProcessor` class handles video decoding from local paths or URLs (requires [`torchcodec`](https://pypi.org/project/torchcodec/)) and frame sampling according to model-specific strategies.
2120

2221
The video processor extends the functionality of image processors by allowing Vision Large Language Models (VLMs) to handle videos with a distinct set of arguments compared to images. It serves as the bridge between raw video data and the model, ensuring that input features are optimized for the VLM.
2322

@@ -48,6 +47,47 @@ processor = torch.compile(processor)
4847
processed_video = processor(video, return_tensors="pt")
4948
```
5049

50+
#### Sampling behavior
51+
52+
The video processor can also sample video frames using the technique best suited for the given model. Sampling behavior is controlled with the `do_sample_frames` argument and can be configured through model-specific parameters such as `num_frames` or `fps` (the rate at which the video will be sampled). If the input video is given as a local path or URL (`str`), the processor will decode it automatically. To obtain metadata about the decoded video, such as sampled frame indices, original dimensions, duration, and fps, pass `return_metadata=True` to the processor.
53+
54+
<Tip warning={false}>
55+
56+
- Specifying `num_frames` does not guarantee the output will contain exactly that number of frames. Depending on the model, the sampler may enforce minimum or maximum frame limits.
57+
58+
- The default decoder is [`torchcodec`](https://pypi.org/project/torchcodec/), which must be installed.
59+
60+
</Tip>
61+
62+
63+
```python
64+
from transformers import AutoVideoProcessor
65+
66+
processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda")
67+
processed_video_inputs = processor(videos=["video_path.mp4"], return_metadata=True, do_sample_frames=True, return_tensors="pt")
68+
video_metadata = processed_video_inputs["video_metadata"]
69+
70+
# See how many frames the original video had and what was the original FPS
71+
print(video_metadata.total_num_frames, video_metadata.fps)
72+
```
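To illustrate why `num_frames` is not a guarantee, here is a minimal sketch of uniform index sampling with model-imposed limits. The function name `sample_frame_indices` and the `min_frames`/`max_frames` parameters are hypothetical and are not part of the `transformers` API; each model's actual sampler may behave differently.

```python
def sample_frame_indices(total_num_frames, num_frames, min_frames=1, max_frames=None):
    """Return evenly spaced frame indices (hypothetical helper, not a transformers API)."""
    # Clamp the request to the limits a model might impose and to the video length.
    target = max(num_frames, min_frames)
    if max_frames is not None:
        target = min(target, max_frames)
    target = min(target, total_num_frames)
    # Spread the indices evenly over the whole video.
    step = total_num_frames / target
    return [int(i * step) for i in range(target)]


indices = sample_frame_indices(total_num_frames=100, num_frames=10)
clamped = sample_frame_indices(total_num_frames=100, num_frames=64, max_frames=32)
print(indices)       # ten evenly spaced indices
print(len(clamped))  # fewer frames than requested because of the assumed cap
```

Only the frames at these indices need to be decoded, which is the idea behind the lower RAM usage this commit describes.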
73+
74+
If you pass an already decoded video array but still want to enable model-specific frame sampling, it is strongly recommended to provide `video_metadata`. This allows the sampler to know the original video's duration and FPS. You can pass metadata as a `VideoMetadata` object or as a plain dict.
75+
```python
import torch

from transformers import AutoVideoProcessor
from transformers.video_utils import VideoMetadata

processor = AutoVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", device="cuda")
my_decoded_video = torch.randint(0, 255, size=(100, 3, 1280, 1280))  # short video of 100 frames
video_metadata = VideoMetadata(
    total_num_frames=100,
    fps=24,
    duration=4.1,  # in seconds
)
# Pass the decoded array (not a file path) together with its metadata
processed_video_inputs = processor(videos=[my_decoded_video], video_metadata=video_metadata, do_sample_frames=True, num_frames=10, return_tensors="pt")
print(processed_video_inputs.pixel_values_videos.shape)
# >>> [10, 3, 384, 384]
```
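The reason metadata matters can be seen in a small sketch: sampling at a target `fps` requires the original frame rate, which a bare array does not carry. `indices_for_target_fps` is a hypothetical helper written for illustration and is not a `transformers` function.

```python
def indices_for_target_fps(total_num_frames, video_fps, target_fps):
    """Pick frame indices so the sampled clip plays at roughly `target_fps`.

    Hypothetical helper for illustration; not part of transformers.
    """
    if target_fps <= 0 or video_fps <= 0:
        raise ValueError("fps values must be positive")
    duration = total_num_frames / video_fps  # seconds
    # Never request more frames than exist, and always keep at least one.
    num_samples = max(1, min(total_num_frames, int(duration * target_fps)))
    step = total_num_frames / num_samples
    return [int(i * step) for i in range(num_samples)]


# A 100-frame clip recorded at 25 fps lasts 4 seconds; sampling at 2 fps keeps 8 frames.
indices = indices_for_target_fps(total_num_frames=100, video_fps=25, target_fps=2)
print(indices)
```

Without `video_fps`, the same 100-frame array could be 2 seconds or 10 seconds long, so the number of frames to keep would be a guess.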
5191

5292
## BaseVideoProcessor
5393

src/transformers/audio_utils.py

Lines changed: 63 additions & 13 deletions
@@ -21,7 +21,7 @@
2121
import os
2222
import warnings
2323
from io import BytesIO
24-
from typing import Any, Optional, Union
24+
from typing import Any, Optional, Sequence, Union
2525

2626
import numpy as np
2727
import requests
@@ -31,6 +31,7 @@
3131
is_numpy_array,
3232
is_soundfile_available,
3333
is_torch_tensor,
34+
is_torchcodec_available,
3435
requires_backends,
3536
)
3637

@@ -44,6 +45,12 @@
4445
# TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa
4546
import soxr
4647

48+
if is_torchcodec_available():
49+
from torchcodec.decoders import AudioDecoder
50+
51+
52+
AudioInput = Union[np.ndarray, "torch.Tensor", Sequence[np.ndarray], Sequence["torch.Tensor"]] # noqa: F821
53+
4754

4855
def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray:
4956
"""
@@ -61,14 +68,14 @@ def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None)
6168
Returns:
6269
`np.ndarray`: A numpy array representing the audio.
6370
"""
64-
requires_backends(load_audio, ["librosa"])
65-
6671
if isinstance(audio, str):
67-
# Load audio from URL (e.g https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav)
68-
if audio.startswith("http://") or audio.startswith("https://"):
69-
audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0]
70-
elif os.path.isfile(audio):
71-
audio = librosa.load(audio, sr=sampling_rate)[0]
72+
# Try to load with `torchcodec` but do not force users to install it. If it is not found,
73+
# fall back to `librosa`. If using an audio-only model, `torchcodec` most probably won't be
74+
# needed.
75+
if is_torchcodec_available():
76+
audio = load_audio_torchcodec(audio, sampling_rate=sampling_rate)
77+
else:
78+
audio = load_audio_librosa(audio, sampling_rate=sampling_rate, timeout=timeout)
7279
elif isinstance(audio, np.ndarray):
7380
audio = audio
7481
else:
@@ -78,6 +85,54 @@ def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None)
7885
return audio
7986

8087

88+
def load_audio_torchcodec(audio: Union[str, np.ndarray], sampling_rate=16000) -> np.ndarray:
89+
"""
90+
Loads `audio` to an np.ndarray object using `torchcodec`.
91+
92+
Args:
93+
audio (`str` or `np.ndarray`):
94+
The audio to be loaded to the numpy array format.
95+
sampling_rate (`int`, *optional*, defaults to 16000):
96+
The sampling rate to be used when loading the audio. It should be the same as the
97+
sampling rate the model you will be using further was trained with.
98+
99+
Returns:
100+
`np.ndarray`: A numpy array representing the audio.
101+
"""
102+
requires_backends(load_audio_torchcodec, ["torchcodec"])
103+
104+
# Set `num_channels` to `1`, which is what most models expect and is the default in librosa
105+
decoder = AudioDecoder(audio, sample_rate=sampling_rate, num_channels=1)
106+
audio = decoder.get_all_samples().data[0].numpy() # NOTE: feature extractors don't accept torch tensors
107+
return audio
108+
109+
110+
def load_audio_librosa(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray:
111+
"""
112+
Loads `audio` to an np.ndarray object using `librosa`.
113+
114+
Args:
115+
audio (`str` or `np.ndarray`):
116+
The audio to be loaded to the numpy array format.
117+
sampling_rate (`int`, *optional*, defaults to 16000):
118+
The sampling rate to be used when loading the audio. It should be the same as the
119+
sampling rate the model you will be using further was trained with.
120+
timeout (`float`, *optional*):
121+
The timeout value in seconds for the URL request.
122+
123+
Returns:
124+
`np.ndarray`: A numpy array representing the audio.
125+
"""
126+
requires_backends(load_audio_librosa, ["librosa"])
127+
128+
# Load audio from URL (e.g https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav)
129+
if audio.startswith("http://") or audio.startswith("https://"):
130+
audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0]
131+
elif os.path.isfile(audio):
132+
audio = librosa.load(audio, sr=sampling_rate)[0]
133+
return audio
134+
135+
81136
def load_audio_as(
82137
audio: str,
83138
return_format: str,
@@ -157,11 +212,6 @@ def load_audio_as(
157212
raise ValueError(f"Error loading audio: {e}")
158213

159214

160-
AudioInput = Union[
161-
np.ndarray, "torch.Tensor", list[np.ndarray], tuple[np.ndarray], list["torch.Tensor"], tuple["torch.Tensor"] # noqa: F821
162-
]
163-
164-
165215
def is_valid_audio(audio):
166216
return is_numpy_array(audio) or is_torch_tensor(audio)
167217
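The `load_audio` change above prefers `torchcodec` and falls back to `librosa` when it is missing. A minimal sketch of the same "try the preferred backend, fall back otherwise" pattern, using only the standard library, might look like this. `pick_audio_backend` is a hypothetical name; `transformers` uses its own `is_torchcodec_available()` helper instead.

```python
import importlib.util


def pick_audio_backend(preferred=("torchcodec", "librosa")):
    """Return the name of the first importable backend (hypothetical helper)."""
    for name in preferred:
        # find_spec returns None when the top-level package is not installed.
        if importlib.util.find_spec(name) is not None:
            return name
    raise ImportError(f"None of the audio backends are installed: {', '.join(preferred)}")


# Standard-library names are used here only so the demo runs anywhere.
print(pick_audio_backend(preferred=("surely_missing_pkg_xyz", "json")))
```

Checking availability without importing keeps the optional dependency truly optional: nothing heavy is loaded until a backend is actually chosen.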

src/transformers/image_processing_base.py

Lines changed: 5 additions & 17 deletions
@@ -17,14 +17,13 @@
1717
import json
1818
import os
1919
import warnings
20-
from io import BytesIO
2120
from typing import Any, Optional, TypeVar, Union
2221

2322
import numpy as np
24-
import requests
2523

2624
from .dynamic_module_utils import custom_object_save
2725
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
26+
from .image_utils import is_valid_image, load_image
2827
from .utils import (
2928
IMAGE_PROCESSOR_NAME,
3029
PushToHubMixin,
@@ -33,15 +32,10 @@
3332
download_url,
3433
is_offline_mode,
3534
is_remote_url,
36-
is_vision_available,
3735
logging,
3836
)
3937

4038

41-
if is_vision_available():
42-
from PIL import Image
43-
44-
4539
ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin")
4640

4741

@@ -514,25 +508,19 @@ def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
514508

515509
cls._auto_class = auto_class
516510

517-
def fetch_images(self, image_url_or_urls: Union[str, list[str]]):
511+
def fetch_images(self, image_url_or_urls: Union[str, list[str], list[list[str]]]):
518512
"""
519513
Convert a single or a list of urls into the corresponding `PIL.Image` objects.
520514
521515
If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
522516
returned.
523517
"""
524-
headers = {
525-
"User-Agent": (
526-
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
527-
" Safari/537.36"
528-
)
529-
}
530518
if isinstance(image_url_or_urls, list):
531519
return [self.fetch_images(x) for x in image_url_or_urls]
532520
elif isinstance(image_url_or_urls, str):
533-
response = requests.get(image_url_or_urls, stream=True, headers=headers)
534-
response.raise_for_status()
535-
return Image.open(BytesIO(response.content))
521+
return load_image(image_url_or_urls)
522+
elif is_valid_image(image_url_or_urls):
523+
return image_url_or_urls
536524
else:
537525
raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
538526
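The updated `fetch_images` recurses through nested lists, loads strings, and passes already valid images through. The sketch below restates that control flow with toy stand-ins so it runs without any image library. `fetch_nested`, `load`, and `is_loaded` are illustrative names; in the diff the corresponding calls are `load_image` and `is_valid_image`.

```python
def fetch_nested(items, load, is_loaded):
    """Recursively replace strings with loaded objects, keeping the nesting.

    `load` and `is_loaded` are stand-ins for `load_image` and `is_valid_image`.
    """
    if isinstance(items, list):
        return [fetch_nested(x, load, is_loaded) for x in items]
    if isinstance(items, str):
        return load(items)
    if is_loaded(items):
        return items  # already decoded, pass through unchanged
    raise TypeError(f"only a single or a list of entries is supported but got type={type(items)}")


# Toy stand-ins: "loading" wraps the path in a tuple, and tuples count as loaded images.
result = fetch_nested(
    [["a.png", ("img", "b.png")], "c.png"],
    load=lambda path: ("img", path),
    is_loaded=lambda x: isinstance(x, tuple),
)
print(result)
```

Because already loaded inputs are returned untouched, the function can be called unconditionally at the start of preprocessing, which is how the model files in this commit use it.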

src/transformers/image_processing_utils_fast.py

Lines changed: 15 additions & 16 deletions
@@ -85,7 +85,7 @@ def validate_fast_preprocess_arguments(
8585
crop_size: Optional[SizeDict] = None,
8686
do_resize: Optional[bool] = None,
8787
size: Optional[SizeDict] = None,
88-
resample: Optional["PILImageResampling"] = None,
88+
interpolation: Optional["F.InterpolationMode"] = None,
8989
return_tensors: Optional[Union[str, TensorType]] = None,
9090
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
9191
):
@@ -105,7 +105,7 @@ def validate_fast_preprocess_arguments(
105105
crop_size=crop_size,
106106
do_resize=do_resize,
107107
size=size,
108-
resample=resample,
108+
interpolation=interpolation,
109109
)
110110
# Extra checks for ImageProcessorFast
111111
if return_tensors is not None and return_tensors != "pt":
@@ -469,6 +469,8 @@ def _prepare_images_structure(
469469
Returns:
470470
`ImageInput`: The images with a valid nesting.
471471
"""
472+
# Checks for `str` in case of URL/local path and optionally loads images
473+
images = self.fetch_images(images)
472474
return make_flat_list_of_images(images, expected_ndims=expected_ndims)
473475

474476
def _process_image(
@@ -582,11 +584,19 @@ def _further_process_kwargs(
582584

583585
kwargs["size"] = size
584586
kwargs["crop_size"] = crop_size
585-
kwargs["default_to_square"] = default_to_square
586587
kwargs["image_mean"] = image_mean
587588
kwargs["image_std"] = image_std
588589
kwargs["data_format"] = data_format
589590

591+
# torch resize uses interpolation instead of resample
592+
# Check if resample is an int before checking if it's an instance of PILImageResampling
593+
# because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
594+
# Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
595+
resample = kwargs.pop("resample")
596+
kwargs["interpolation"] = (
597+
pil_torch_interpolation_mapping[resample] if isinstance(resample, (PILImageResampling, int)) else resample
598+
)
599+
590600
return kwargs
591601

592602
def _validate_preprocess_kwargs(
@@ -600,7 +610,7 @@ def _validate_preprocess_kwargs(
600610
size: Optional[SizeDict] = None,
601611
do_center_crop: Optional[bool] = None,
602612
crop_size: Optional[SizeDict] = None,
603-
resample: Optional[Union["PILImageResampling", "F.InterpolationMode"]] = None,
613+
interpolation: Optional["F.InterpolationMode"] = None,
604614
return_tensors: Optional[Union[str, TensorType]] = None,
605615
data_format: Optional[ChannelDimension] = None,
606616
**kwargs,
@@ -618,7 +628,7 @@ def _validate_preprocess_kwargs(
618628
size=size,
619629
do_center_crop=do_center_crop,
620630
crop_size=crop_size,
621-
resample=resample,
631+
interpolation=interpolation,
622632
return_tensors=return_tensors,
623633
data_format=data_format,
624634
)
@@ -646,18 +656,7 @@ def preprocess(self, images: ImageInput, *args, **kwargs: Unpack[DefaultFastImag
646656
# Validate kwargs
647657
self._validate_preprocess_kwargs(**kwargs)
648658

649-
# torch resize uses interpolation instead of resample
650-
resample = kwargs.pop("resample")
651-
652-
# Check if resample is an int before checking if it's an instance of PILImageResampling
653-
# because if pillow < 9.1.0, resample is an int and PILImageResampling is a module.
654-
# Checking PILImageResampling will fail with error `TypeError: isinstance() arg 2 must be a type or tuple of types`.
655-
kwargs["interpolation"] = (
656-
pil_torch_interpolation_mapping[resample] if isinstance(resample, (int, PILImageResampling)) else resample
657-
)
658-
659659
# Pop kwargs that are not needed in _preprocess
660-
kwargs.pop("default_to_square")
661660
kwargs.pop("data_format")
662661

663662
return self._preprocess_image_like_inputs(
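The hunks above move the `resample` to `interpolation` translation into `_further_process_kwargs`, so validation and preprocessing both see `interpolation`. A simplified sketch of that translation follows. The `Resampling` enum and the `TO_INTERPOLATION` table are toy stand-ins for `PILImageResampling` and `pil_torch_interpolation_mapping`, with plain strings in place of torchvision interpolation modes.

```python
from enum import IntEnum


class Resampling(IntEnum):
    # Toy stand-in for PIL's resampling enum; values mirror Pillow's integer codes.
    NEAREST = 0
    BILINEAR = 2
    BICUBIC = 3


# Toy stand-in for `pil_torch_interpolation_mapping`; strings replace torchvision modes.
TO_INTERPOLATION = {
    Resampling.NEAREST: "nearest",
    Resampling.BILINEAR: "bilinear",
    Resampling.BICUBIC: "bicubic",
}


def resample_to_interpolation(kwargs):
    """Pop `resample` and store the translated value under `interpolation`."""
    kwargs = dict(kwargs)  # do not mutate the caller's dict
    resample = kwargs.pop("resample")
    # Plain ints and enum members are translated; anything else passes through unchanged.
    if isinstance(resample, (Resampling, int)):
        kwargs["interpolation"] = TO_INTERPOLATION[resample]
    else:
        kwargs["interpolation"] = resample
    return kwargs


print(resample_to_interpolation({"resample": 2, "size": 224}))
print(resample_to_interpolation({"resample": "bicubic"}))
```

The lookup works for bare integers because `IntEnum` members hash and compare like their integer values.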

src/transformers/image_utils.py

Lines changed: 8 additions & 2 deletions
@@ -535,6 +535,7 @@ def validate_preprocess_arguments(
535535
do_resize: Optional[bool] = None,
536536
size: Optional[dict[str, int]] = None,
537537
resample: Optional["PILImageResampling"] = None,
538+
interpolation: Optional["InterpolationMode"] = None,
538539
):
539540
"""
540541
Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
@@ -559,8 +560,13 @@ def validate_preprocess_arguments(
559560
if do_center_crop and crop_size is None:
560561
raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")
561562

562-
if do_resize and (size is None or resample is None):
563-
raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
563+
if interpolation is not None and resample is not None:
564+
raise ValueError(
565+
"Only one of `interpolation` and `resample` should be specified, depending on image processor type."
566+
)
567+
568+
if do_resize and not (size is not None and (resample is not None or interpolation is not None)):
569+
raise ValueError("`size` and `resample/interpolation` must be specified if `do_resize` is `True`.")
564570

565571

566572
# In the future we can add a TF implementation here when we have TF models.
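The two checks added to `validate_preprocess_arguments` can be restated in isolation as follows. `check_resize_args` is an illustrative name and covers only the resize-related arguments, with shortened error messages; it is a sketch, not the library function.

```python
def check_resize_args(do_resize, size=None, resample=None, interpolation=None):
    """Simplified re-statement of the two checks added above (illustrative only)."""
    # Slow processors use `resample`, fast ones use `interpolation`; never both.
    if interpolation is not None and resample is not None:
        raise ValueError("Only one of `interpolation` and `resample` should be specified.")
    # Resizing needs a target size plus exactly one way to interpolate.
    if do_resize and (size is None or (resample is None and interpolation is None)):
        raise ValueError("`size` and `resample/interpolation` must be specified if `do_resize` is `True`.")


# Passes: a size and one interpolation argument are given.
check_resize_args(do_resize=True, size={"height": 224, "width": 224}, interpolation="bilinear")

try:
    check_resize_args(do_resize=True, size={"height": 224, "width": 224}, resample=2, interpolation="bilinear")
except ValueError as err:
    print(err)
```

The second condition is the De Morgan expansion of the `not (...)` expression in the diff, which some readers may find easier to follow.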

src/transformers/models/aria/image_processing_aria.py

Lines changed: 1 addition & 0 deletions
@@ -228,6 +228,7 @@ def preprocess(
228228
if max_image_size not in [490, 980]:
229229
raise ValueError("max_image_size must be either 490 or 980")
230230

231+
images = self.fetch_images(images)
231232
images = make_flat_list_of_images(images)
232233

233234
if not valid_images(images):

src/transformers/models/aria/modular_aria.py

Lines changed: 1 addition & 0 deletions
@@ -614,6 +614,7 @@ def preprocess(
614614
if max_image_size not in [490, 980]:
615615
raise ValueError("max_image_size must be either 490 or 980")
616616

617+
images = self.fetch_images(images)
617618
images = make_flat_list_of_images(images)
618619

619620
if not valid_images(images):

src/transformers/models/aya_vision/processing_aya_vision.py

Lines changed: 1 addition & 0 deletions
@@ -189,6 +189,7 @@ def __call__(
189189
# Process images
190190
image_inputs = {}
191191
if images is not None:
192+
images = self.image_processor.fetch_images(images)
192193
images = make_flat_list_of_images(images)
193194
image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
194195
num_patches = image_inputs.pop("num_patches")

src/transformers/models/blip/image_processing_blip.py

Lines changed: 1 addition & 0 deletions
@@ -231,6 +231,7 @@ def preprocess(
231231

232232
size = size if size is not None else self.size
233233
size = get_size_dict(size, default_to_square=False)
234+
images = self.fetch_images(images)
234235
images = make_flat_list_of_images(images)
235236

236237
if not valid_images(images):
