Commit 8d717d0

Video is now matching

committed
1 parent 8bea22a commit 8d717d0

File tree

1 file changed: +236 -9 lines changed

src/transformers/models/imagebind/image_processing_imagebind.py

Lines changed: 236 additions & 9 deletions
@@ -14,8 +14,9 @@
 """Image processor class for ImageBind."""
 
 import math
+import warnings
 from fractions import Fraction
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -25,6 +26,7 @@
     get_resize_output_image_size,
     resize,
     to_channel_dimension_format,
+    to_pil_image,
 )
 from ...image_utils import (
     OPENAI_CLIP_MEAN,
@@ -33,6 +35,7 @@
     ImageInput,
     PILImageResampling,
     VideoInput,
+    get_image_size,
     infer_channel_dimension_format,
     is_scaled_image,
     is_valid_image,
@@ -42,7 +45,7 @@
     validate_kwargs,
     validate_preprocess_arguments,
 )
-from ...utils import TensorType, is_vision_available, logging
+from ...utils import TensorType, is_torch_available, is_vision_available, logging, requires_backends
 
 
 logger = logging.get_logger(__name__)
@@ -51,6 +54,9 @@
 if is_vision_available():
     import PIL
 
+if is_torch_available():
+    import torch
+
 
 # Copy from models.video_llava.image_processing_video_llava.make_batched_videos
 def make_batched_videos(videos) -> List[VideoInput]:
@@ -119,6 +125,151 @@ def uniform_temporal_subsample(video: VideoInput, num_samples: int) -> VideoInput:
     return [video[i] for i in indices]
 
 
+# Adapted from https://github.com/facebookresearch/pytorchvideo/blob/1fadaef40dd393ca09680f55582399f4679fc9b7/pytorchvideo/transforms/functional.py#L92
+def video_resize(
+    frames: List[np.ndarray],
+    size: Tuple[int, int] = 224,
+    resampling: PILImageResampling = PILImageResampling.BILINEAR,
+    data_format: Optional[Union[str, ChannelDimension]] = None,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> List[np.ndarray]:
+    """
+    Resizes the spatial dims of all video frames to the given size with a single
+    `torch.nn.functional.interpolate` call over the stacked frames. To maintain aspect ratio,
+    pass a precomputed (height, width), as done by `ImageBindImageProcessor.video_resize`.
+
+    Args:
+        frames (`List[np.ndarray]`): The video frames to resize.
+        size (`int` or `Tuple[int, int]`): The size the spatial dims are scaled to.
+        resampling (`PILImageResampling`): Algorithm used for resampling,
+            options: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'
+        data_format (`str` or `ChannelDimension`, *optional*):
+            The channel dimension format of the image. If not provided, it will be the same as the input image.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format of the input image. If not provided, it will be inferred.
+    Returns:
+        A list of frames with scaled spatial dims.
+    """  # noqa
+    requires_backends(video_resize, ["torch"])
+
+    # channel-first
+    frames = [
+        to_channel_dimension_format(frame, ChannelDimension.FIRST, input_channel_dim=input_data_format)
+        for frame in frames
+    ]
+    # stack and convert to torch, shape (num_frames, num_channels, height, width)
+    video = np.stack(frames)
+    video = torch.from_numpy(video).contiguous()
+
+    data_format = input_data_format if data_format is None else data_format
+    video = torch.nn.functional.interpolate(video, size=size, mode=resampling.name.lower(), align_corners=False)
+    frames = list(video.numpy())
+    frames = [
+        to_channel_dimension_format(frame, data_format, input_channel_dim=ChannelDimension.FIRST) for frame in frames
+    ]
+
+    return frames
+
+
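As a quick sanity check of the helper above, here is a minimal sketch (not part of the commit; the dummy frames are our own, and it assumes the torch/vision backends plus the module path shown in the diff): frames are stacked, resized in one `torch.nn.functional.interpolate` call, and returned in their original channel format.

    import numpy as np
    from transformers.models.imagebind.image_processing_imagebind import video_resize

    # Four channels-last RGB frames; `size` here is the exact output (H, W).
    frames = [np.random.rand(256, 320, 3).astype(np.float32) for _ in range(4)]
    resized = video_resize(frames, size=(224, 224), input_data_format="channels_last")
    print(resized[0].shape)  # (224, 224, 3) -- input channel format is preserved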
+# Same as center_crop in image_transforms.py but taking offsets like int(math.ceil((orig_height - crop_height) / 2))
+def modified_center_crop(
+    image: np.ndarray,
+    size: Tuple[int, int],
+    data_format: Optional[Union[str, ChannelDimension]] = None,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    return_numpy: Optional[bool] = None,
+) -> np.ndarray:
+    """
+    Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
+    the size given, it will be padded (so the returned result will always be of size `size`).
+
+    Args:
+        image (`np.ndarray`):
+            The image to crop.
+        size (`Tuple[int, int]`):
+            The target size for the cropped image.
+        data_format (`str` or `ChannelDimension`, *optional*):
+            The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            If unset, will use the inferred format of the input image.
+        input_data_format (`str` or `ChannelDimension`, *optional*):
+            The channel dimension format for the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            If unset, will use the inferred format of the input image.
+        return_numpy (`bool`, *optional*):
+            Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
+            previous ImageFeatureExtractionMixin method.
+                - Unset: will return the same type as the input image.
+                - `True`: will return a numpy array.
+                - `False`: will return a `PIL.Image.Image` object.
+    Returns:
+        `np.ndarray`: The cropped image.
+    """
+    requires_backends(modified_center_crop, ["vision"])
+
+    if return_numpy is not None:
+        warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning)
+
+    return_numpy = True if return_numpy is None else return_numpy
+
+    if not isinstance(image, np.ndarray):
+        raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
+
+    if not isinstance(size, Iterable) or len(size) != 2:
+        raise ValueError("size must have 2 elements representing the height and width of the output image")
+
+    if input_data_format is None:
+        input_data_format = infer_channel_dimension_format(image)
+    output_data_format = data_format if data_format is not None else input_data_format
+
+    # We perform the crop in (C, H, W) format and then convert to the output format
+    image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
+
+    orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
+    crop_height, crop_width = size
+    crop_height, crop_width = int(crop_height), int(crop_width)
+
+    # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
+    top = int(math.ceil((orig_height - crop_height) / 2))
+    bottom = top + crop_height
+    # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
+    left = int(math.ceil((orig_width - crop_width) / 2))
+    right = left + crop_width
+
+    # Check if cropped area is within image boundaries
+    if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
+        image = image[..., top:bottom, left:right]
+        image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
+        return image
+
+    # Otherwise, we may need to pad if the image is too small. Oh joy...
+    new_height = max(crop_height, orig_height)
+    new_width = max(crop_width, orig_width)
+    new_shape = image.shape[:-2] + (new_height, new_width)
+    new_image = np.zeros_like(image, shape=new_shape)
+
+    # If the image is too small, pad it with zeros
+    top_pad = math.ceil((new_height - orig_height) / 2)
+    bottom_pad = top_pad + orig_height
+    left_pad = math.ceil((new_width - orig_width) / 2)
+    right_pad = left_pad + orig_width
+    new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
+
+    top += top_pad
+    bottom += top_pad
+    left += left_pad
+    right += left_pad
+
+    new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
+    new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)
+
+    if not return_numpy:
+        new_image = to_pil_image(new_image)
+
+    return new_image
+
+
 class ImageBindImageProcessor(BaseImageProcessor):
     r"""
     Constructs an ImageBind image processor.
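The only behavioral difference from the stock `center_crop` in `image_transforms.py` is the `ceil` in the offset computation, which shifts the crop window by one pixel when the size difference is odd. A small illustration (the numbers are ours, not from the commit):

    import math

    orig, crop = 5, 2
    print(int(math.ceil((orig - crop) / 2)))  # 2 -> this fork starts the crop window at index 2
    print((orig - crop) // 2)                 # 1 -> the floor-based stock crop starts at index 1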
@@ -242,6 +393,38 @@ def __init__(
         # `shortest_edge` key.
         delattr(self, "use_square_size")
 
+    def video_resize(
+        self,
+        frames: List[np.ndarray],
+        size: Dict[str, int],
+        resampling: PILImageResampling = PILImageResampling.BILINEAR,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> List[np.ndarray]:
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
+        output_size = get_resize_output_image_size(
+            frames[0],
+            size=size,
+            default_to_square=default_to_square,
+            input_data_format=input_data_format,
+        )
+
+        return video_resize(
+            frames=frames,
+            size=output_size,
+            resampling=resampling,
+            data_format=data_format,
+            input_data_format=input_data_format,
+        )
+
     # Copied from models.clip.image_processing_clip.CLIPImageProcessor.resize
     def resize(
         self,
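The method resolves the size dict before delegating to the module-level `video_resize`, so `shortest_edge` keeps the aspect ratio while `height`/`width` forces an exact shape. A hedged usage sketch (the processor construction and frame sizes are our own assumptions):

    import numpy as np
    from transformers.models.imagebind.image_processing_imagebind import ImageBindImageProcessor

    processor = ImageBindImageProcessor()
    frames = [np.zeros((480, 640, 3), dtype=np.float32) for _ in range(2)]
    out = processor.video_resize(frames, size={"shortest_edge": 224}, input_data_format="channels_last")
    print(out[0].shape)  # (224, 298, 3): the shorter side becomes 224, aspect ratio kept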
@@ -327,10 +510,49 @@ def chunk(
 
         return all_clips
 
-    # Copied from models.clip.image_processing_clip.CLIPImageProcessor.preprocess with preprocess->_preprocess_image
+    def center_crop(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
+        any edge, the image is padded with 0's and then center cropped.
+
+        Args:
+            image (`np.ndarray`):
+                Image to center crop.
+            size (`Dict[str, int]`):
+                Size of the output image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format for the output image. If unset, the channel dimension format of the input
+                image is used. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+        """
+        size = get_size_dict(size)
+        if "height" not in size or "width" not in size:
+            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
+        return modified_center_crop(
+            image,
+            size=(size["height"], size["width"]),
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
+
     def _preprocess_image(
         self,
         images: ImageInput,
+        is_video: bool = False,
         do_resize: bool = None,
         size: Dict[str, int] = None,
         resample: PILImageResampling = None,
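Because `center_crop` now routes through `modified_center_crop`, an input smaller than the crop size is zero-padded up to the target before cropping, so the output shape is always exact. A minimal sketch (array contents are ours):

    import numpy as np
    from transformers.models.imagebind.image_processing_imagebind import ImageBindImageProcessor

    processor = ImageBindImageProcessor()
    small = np.ones((3, 100, 100), dtype=np.float32)  # channels-first, smaller than the crop
    cropped = processor.center_crop(small, size={"height": 224, "width": 224})
    print(cropped.shape)  # (3, 224, 224): the 100x100 content sits centered in zeros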
@@ -375,10 +597,15 @@ def _preprocess_image(
             input_data_format = infer_channel_dimension_format(images[0])
 
         if do_resize:
-            images = [
-                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
-                for image in images
-            ]
+            if is_video:
+                images = self.video_resize(
+                    frames=images, size=size, resampling=resample, input_data_format=input_data_format
+                )
+            else:
+                images = [
+                    self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+                    for image in images
+                ]
 
         if do_center_crop:
             images = [
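With `is_video=True` the whole clip is resized by one torch call rather than frame-by-frame PIL resizes; both branches target the same output size, though interpolation results can differ slightly. A side-by-side sketch under the same assumptions as above:

    import numpy as np
    from transformers.image_utils import PILImageResampling
    from transformers.models.imagebind.image_processing_imagebind import ImageBindImageProcessor

    processor = ImageBindImageProcessor()
    frames = [np.random.rand(360, 480, 3).astype(np.float32) for _ in range(2)]
    per_image = [
        processor.resize(image=f, size={"shortest_edge": 224}, resample=PILImageResampling.BILINEAR) for f in frames
    ]
    as_video = processor.video_resize(frames, size={"shortest_edge": 224}, input_data_format="channels_last")
    print(per_image[0].shape, as_video[0].shape)  # both (224, 298, 3); pixel values may differ slightly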
@@ -403,7 +630,6 @@ def _preprocess_image(
 
         return images
 
-    # Ignore copy
     def preprocess(
         self,
         images: Optional[ImageInput] = None,
@@ -565,6 +791,7 @@ def preprocess(
             _pixel_values = [
                 self._preprocess_image(
                     images=clip,
+                    is_video=True,
                     do_resize=do_resize,
                     size=size,
                     resample=PILImageResampling.BILINEAR,
@@ -601,7 +828,7 @@ def preprocess(
                 )
             ]
 
-            # Avoid List[List[List[np.ndarray]]]
+            # Avoid List[List[List[np.ndarray]]] for performance reasons
             _pixel_values = np.stack(_pixel_values)
             # Make it shape (num_chunks, num_channels, num_frames_per_chunk, height, width)
             _pixel_values = np.swapaxes(_pixel_values, 1, 2)
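For the final stack-and-swap, the shape bookkeeping looks like this (illustrative dimensions are ours):

    import numpy as np

    # Each preprocessed clip is (num_frames, num_channels, H, W); stacking adds num_chunks in front.
    pixel_values = np.zeros((3, 2, 3, 224, 224))    # (num_chunks, num_frames, C, H, W)
    pixel_values = np.swapaxes(pixel_values, 1, 2)  # (num_chunks, C, num_frames_per_chunk, H, W)
    print(pixel_values.shape)                       # (3, 3, 2, 224, 224)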
