1414"""Image processor class for ImageBind."""
1515
1616import math
17+ import warnings
1718from fractions import Fraction
18- from typing import Dict , List , Optional , Tuple , Union
19+ from typing import Dict , Iterable , List , Optional , Tuple , Union
1920
2021import numpy as np
2122
2526 get_resize_output_image_size ,
2627 resize ,
2728 to_channel_dimension_format ,
29+ to_pil_image ,
2830)
2931from ...image_utils import (
3032 OPENAI_CLIP_MEAN ,
3335 ImageInput ,
3436 PILImageResampling ,
3537 VideoInput ,
38+ get_image_size ,
3639 infer_channel_dimension_format ,
3740 is_scaled_image ,
3841 is_valid_image ,
4245 validate_kwargs ,
4346 validate_preprocess_arguments ,
4447)
45- from ...utils import TensorType , is_vision_available , logging
48+ from ...utils import TensorType , is_torch_available , is_vision_available , logging , requires_backends
4649
4750
4851logger = logging .get_logger (__name__ )
5154if is_vision_available ():
5255 import PIL
5356
57+ if is_torch_available ():
58+ import torch
59+
5460
 # Copied from models.video_llava.image_processing_video_llava.make_batched_videos
 def make_batched_videos(videos) -> List[VideoInput]:
@@ -119,6 +125,151 @@ def uniform_temporal_subsample(video: VideoInput, num_samples: int) -> VideoInput:
     return [video[i] for i in indices]


+# Adapted from https://github.com/facebookresearch/pytorchvideo/blob/1fadaef40dd393ca09680f55582399f4679fc9b7/pytorchvideo/transforms/functional.py#L92
+def video_resize(
+    frames: List[np.ndarray],
+    size: Union[int, Tuple[int, int]] = 224,
+    resampling: PILImageResampling = PILImageResampling.BILINEAR,
+    data_format: Optional[Union[str, ChannelDimension]] = None,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> List[np.ndarray]:
+    """
+    Determines the shorter spatial dim of the video (i.e. width or height) and scales
+    it to the given size. To maintain aspect ratio, the longer side is then scaled
+    accordingly.
+
+    Args:
+        frames (`List[np.ndarray]`):
+            The video frames to resize.
+        size (`int` or `Tuple[int, int]`):
+            The size the shorter side is scaled to, or the exact (height, width) to resize to.
+        resampling (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+            The resampling filter. Its lowercased name is passed as the `torch.nn.functional.interpolate`
+            mode, e.g. 'nearest', 'bilinear' or 'bicubic'.
+        data_format (`str` or `ChannelDimension`, *optional*):
+            The channel dimension format of the output frames. If not provided, it will be the same as the input frames.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format of the input frames. If not provided, it will be inferred.
+    Returns:
+        `List[np.ndarray]`: The resized frames, with scaled spatial dims.
+    """  # noqa
+    requires_backends(video_resize, ["torch"])
+
+    # channel-first
+    frames = [
+        to_channel_dimension_format(frame, ChannelDimension.FIRST, input_channel_dim=input_data_format)
+        for frame in frames
+    ]
+    # stack the frames and convert to a torch tensor of shape (num_frames, num_channels, height, width)
+    video = np.stack(frames)
+    video = torch.from_numpy(video).contiguous()
+
+    data_format = input_data_format if data_format is None else data_format
+    mode = resampling.name.lower()
+    # `align_corners` may only be set for the linear/bilinear/bicubic/trilinear interpolation modes
+    align_corners = False if mode in ("linear", "bilinear", "bicubic", "trilinear") else None
+    video = torch.nn.functional.interpolate(video, size=size, mode=mode, align_corners=align_corners)
+    frames = list(video.numpy())
+    frames = [
+        to_channel_dimension_format(frame, data_format, input_channel_dim=ChannelDimension.FIRST) for frame in frames
+    ]
+
+    return frames
+
+
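+# Illustrative usage of `video_resize` (a sketch; the shapes and values are assumptions,
+# not taken from the PR):
+#
+#     frames = [np.zeros((3, 240, 320), dtype=np.float32) for _ in range(8)]
+#     resized = video_resize(frames, size=(224, 298), input_data_format=ChannelDimension.FIRST)
+#     resized[0].shape  # (3, 224, 298)
+
+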
+# Same as `center_crop` in image_transforms.py, but computing offsets like int(math.ceil((orig_height - crop_height) / 2))
+def modified_center_crop(
+    image: np.ndarray,
+    size: Tuple[int, int],
+    data_format: Optional[Union[str, ChannelDimension]] = None,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    return_numpy: Optional[bool] = None,
+) -> np.ndarray:
+    """
+    Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
+    the size given, it will be padded (so the returned result will always be of size `size`).
+
+    Args:
+        image (`np.ndarray`):
+            The image to crop.
+        size (`Tuple[int, int]`):
+            The target size for the cropped image.
+        data_format (`str` or `ChannelDimension`, *optional*):
+            The channel dimension format for the output image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            If unset, will use the inferred format of the input image.
+        input_data_format (`str` or `ChannelDimension`, *optional*):
+            The channel dimension format for the input image. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            If unset, will use the inferred format of the input image.
+        return_numpy (`bool`, *optional*):
+            Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
+            previous ImageFeatureExtractionMixin method.
+                - Unset: will return the same type as the input image.
+                - `True`: will return a numpy array.
+                - `False`: will return a `PIL.Image.Image` object.
+    Returns:
+        `np.ndarray`: The cropped image.
+    """
+    requires_backends(modified_center_crop, ["vision"])
+
+    if return_numpy is not None:
+        warnings.warn("return_numpy is deprecated and will be removed in v4.33", FutureWarning)
+
+    return_numpy = True if return_numpy is None else return_numpy
+
+    if not isinstance(image, np.ndarray):
+        raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
+
+    if not isinstance(size, Iterable) or len(size) != 2:
+        raise ValueError("size must have 2 elements representing the height and width of the output image")
+
+    if input_data_format is None:
+        input_data_format = infer_channel_dimension_format(image)
+    output_data_format = data_format if data_format is not None else input_data_format
+
+    # We perform the crop in (C, H, W) format and then convert to the output format
+    image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
+
+    orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
+    crop_height, crop_width = size
+    crop_height, crop_width = int(crop_height), int(crop_width)
+
+    # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
+    top = int(math.ceil((orig_height - crop_height) / 2))
+    bottom = top + crop_height
+    # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
+    left = int(math.ceil((orig_width - crop_width) / 2))
+    right = left + crop_width
+
+    # Check if cropped area is within image boundaries
+    if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
+        image = image[..., top:bottom, left:right]
+        image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
+        return image
+
+    # Otherwise, we may need to pad if the image is too small. Oh joy...
+    new_height = max(crop_height, orig_height)
+    new_width = max(crop_width, orig_width)
+    new_shape = image.shape[:-2] + (new_height, new_width)
+    new_image = np.zeros_like(image, shape=new_shape)
+
+    # If the image is too small, pad it with zeros
+    top_pad = math.ceil((new_height - orig_height) / 2)
+    bottom_pad = top_pad + orig_height
+    left_pad = math.ceil((new_width - orig_width) / 2)
+    right_pad = left_pad + orig_width
+    new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
+
+    top += top_pad
+    bottom += top_pad
+    left += left_pad
+    right += left_pad
+
+    new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
+    new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)
+
+    if not return_numpy:
+        new_image = to_pil_image(new_image)
+
+    return new_image
+
+
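+# Illustrative behaviour of `modified_center_crop` (a sketch; the shapes are assumptions):
+# a (3, 200, 300) image cropped to (224, 224) is zero-padded in height and cropped in
+# width, so the result is always (3, 224, 224):
+#
+#     cropped = modified_center_crop(np.zeros((3, 200, 300)), size=(224, 224))
+#     cropped.shape  # (3, 224, 224)
+
+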
 class ImageBindImageProcessor(BaseImageProcessor):
     r"""
     Constructs an ImageBind image processor.
@@ -242,6 +393,38 @@ def __init__(
         # `shortest_edge` key.
         delattr(self, "use_square_size")

+    def video_resize(
+        self,
+        frames: List[np.ndarray],
+        size: Dict[str, int],
+        resampling: PILImageResampling = PILImageResampling.BILINEAR,
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+    ) -> List[np.ndarray]:
+        default_to_square = True
+        if "shortest_edge" in size:
+            size = size["shortest_edge"]
+            default_to_square = False
+        elif "height" in size and "width" in size:
+            size = (size["height"], size["width"])
+        else:
+            raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
+
+        output_size = get_resize_output_image_size(
+            frames[0],
+            size=size,
+            default_to_square=default_to_square,
+            input_data_format=input_data_format,
+        )
+
+        return video_resize(
+            frames=frames,
+            size=output_size,
+            resampling=resampling,
+            data_format=data_format,
+            input_data_format=input_data_format,
+        )
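+
+    # Note on `video_resize` above: size={"shortest_edge": 224} scales the shorter side of every
+    # frame to 224 and keeps the aspect ratio, while size={"height": 224, "width": 224} resizes
+    # frames to that exact shape.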
+
     # Copied from models.clip.image_processing_clip.CLIPImageProcessor.resize
     def resize(
         self,
@@ -327,10 +510,49 @@ def chunk(

         return all_clips

-    # Copied from models.clip.image_processing_clip.CLIPImageProcessor.preprocess with preprocess->_preprocess_image
+    def center_crop(
+        self,
+        image: np.ndarray,
+        size: Dict[str, int],
+        data_format: Optional[Union[str, ChannelDimension]] = None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
+        any edge, the image is padded with 0's and then center cropped.
+
+        Args:
+            image (`np.ndarray`):
+                Image to center crop.
+            size (`Dict[str, int]`):
+                Size of the output image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format for the output image. If unset, the channel dimension format of the input
+                image is used. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format for the input image. If unset, the channel dimension format is inferred
+                from the input image. Can be one of:
+                    - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
+                    - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
+        """
+        size = get_size_dict(size)
+        if "height" not in size or "width" not in size:
+            raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
+        return modified_center_crop(
+            image,
+            size=(size["height"], size["width"]),
+            data_format=data_format,
+            input_data_format=input_data_format,
+            **kwargs,
+        )
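+
+    # Note: `get_size_dict` in `center_crop` above normalizes legacy size specs, so an int like
+    # 224 is treated as {"height": 224, "width": 224}.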
551+
     def _preprocess_image(
         self,
         images: ImageInput,
+        is_video: bool = False,
         do_resize: bool = None,
         size: Dict[str, int] = None,
         resample: PILImageResampling = None,
@@ -375,10 +597,15 @@ def _preprocess_image(
             input_data_format = infer_channel_dimension_format(images[0])

         if do_resize:
-            images = [
-                self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
-                for image in images
-            ]
+            if is_video:
+                images = self.video_resize(
+                    frames=images, size=size, resampling=resample, input_data_format=input_data_format
+                )
+            else:
+                images = [
+                    self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
+                    for image in images
+                ]

         if do_center_crop:
             images = [
@@ -403,7 +630,6 @@ def _preprocess_image(

         return images

-    # Ignore copy
     def preprocess(
         self,
         images: Optional[ImageInput] = None,
@@ -565,6 +791,7 @@ def preprocess(
             _pixel_values = [
                 self._preprocess_image(
                     images=clip,
+                    is_video=True,
                     do_resize=do_resize,
                     size=size,
                     resample=PILImageResampling.BILINEAR,
@@ -601,7 +828,7 @@ def preprocess(
                 )
             ]

-            # Avoid List[List[List[np.ndarray]]]
+            # Avoid List[List[List[np.ndarray]]] for performance reasons
             _pixel_values = np.stack(_pixel_values)
             # Make it shape (num_chunks, num_channels, num_frames_per_chunk, height, width)
             _pixel_values = np.swapaxes(_pixel_values, 1, 2)
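+            # e.g. 3 chunks of 2 frames at 224x224: (3, 2, 3, 224, 224) -> (3, 3, 2, 224, 224)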