11import io
2- from typing import Any
2+ from collections .abc import Iterable
3+ from typing import Any , Callable
34
45import imageio .v3 as iio
56from PIL import Image
@@ -68,7 +69,7 @@ def preprocess_data(
6869 frame = iio .imread (io .BytesIO (video_bytes ), index = frame_index , plugin = "pyav" )
6970
7071 buff = io .BytesIO ()
71- Image .fromarray (frame ).convert ('RGB' ).save (buff , format = 'JPEG' , quality = 95 )
72+ Image .fromarray (frame ).convert ('RGB' ).save (buff , format = 'JPEG' , quality = 95 ) # type: ignore
7273 modality2data ['image' ] = buff .getvalue ()
7374 metadata [self .image_filter .key_column ] = ''
7475 return key , self .image_filter .preprocess_data (modality2data , metadata )
@@ -82,3 +83,97 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
8283 for colname in self .schema [1 :]:
8384 df_batch_labels [colname ].extend (df_batch_labels_images [colname ])
8485 return df_batch_labels
86+
87+
def chunks(lst: list[Any], n: int) -> Iterable[list[Any]]:
    """
    Split a list into consecutive slices of at most ``n`` items

    Parameters
    ----------
    lst: list[Any]
        List to split
    n: int
        Maximum length of every slice, must be positive

    Returns
    -------
    Iterable[list[Any]]
        Slices of ``lst`` in their original order; the last slice may be shorter than ``n``

    Raises
    ------
    ValueError
        If ``n`` is not positive
    """
    # A zero step makes range() fail with an unrelated message, and a negative
    # step silently yields nothing (dropping every item), so reject both up front
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}")
    return (lst[i:i + n] for i in range(0, len(lst), n))
91+
92+
class MultiFrameImageFilterAdapter(VideoFilter):
    """
    Runs an ImageFilter on several frames from video

    Parameters
    ----------
    image_filter: ImageFilter
        Image filter to apply
    video_frames: list[float]
        List of positions of frames to use
        For example 0 means first frame, 0.5 means central frame and 1 means last frame
    reduce_results_fn: Callable[[str, list[Any]], Any]
        Function that receives a result column name and the list of values the image filter
        produced for this column on the selected frames, and returns a single value for the video
    batch_size: int = 8
        Maximum number of preprocessed frames passed to the image filter in one call
    workers: int = 8
        Number of pytorch dataloader workers
    pbar: bool = True
        Whether to show progress bar
    _pbar_position: int = 0
        Forwarded to the VideoFilter constructor together with pbar
    """

    def __init__(
        self,
        image_filter: ImageFilter,
        video_frames: list[float],
        reduce_results_fn: Callable[[str, list[Any]], Any],
        batch_size: int = 8,
        workers: int = 8,
        pbar: bool = True,
        _pbar_position: int = 0
    ):
        super().__init__(pbar, _pbar_position)
        self.image_filter = image_filter
        self.video_frames = video_frames
        self.reduce_results_fn = reduce_results_fn
        self.batch_size = batch_size
        self.num_workers = workers

    @property
    def result_columns(self) -> list[str]:
        """Result columns of this filter, taken unchanged from the wrapped image filter"""
        return self.image_filter.result_columns

    @property
    def dataloader_kwargs(self) -> dict[str, Any]:
        """Keyword arguments for the dataloader"""
        return {
            "num_workers": self.num_workers,
            # One video per batch: process_batch reads only batch[0] and splits
            # the frames of that video by self.batch_size itself
            "batch_size": 1,
            "drop_last": False,
        }

    def preprocess_data(
        self,
        modality2data: ModalityToDataMapping,
        metadata: dict[str, Any]
    ) -> Any:
        """
        Extract the configured frames from a video and preprocess each one with the image filter

        Parameters
        ----------
        modality2data: ModalityToDataMapping
            Mapping with the encoded video under the 'video' key;
            its 'image' entry is overwritten in place for every extracted frame
        metadata: dict[str, Any]
            Sample metadata, must contain the key column of this filter;
            the key column of the image filter is overwritten in place with an empty string

        Returns
        -------
        Any
            Tuple of the sample key and the list of image filter preprocessing
            results, one per entry of video_frames
        """
        key = metadata[self.key_column]

        video_bytes = modality2data['video']
        meta = iio.immeta(io.BytesIO(video_bytes), plugin="pyav")
        fps = meta['fps']
        duration = meta['duration']
        total_frames = int(fps * duration)
        # total_frames can come out as 0 (e.g. very short clip), which would
        # otherwise produce an index of -1; never go below the first frame
        last_frame_index = max(total_frames - 1, 0)

        preprocessed_data = []
        for video_frame_pos in self.video_frames:
            # Clamp into [0, last_frame_index] so positions of 1 (or out-of-range
            # positions) still map to an existing frame
            frame_index = min(max(int(total_frames * video_frame_pos), 0), last_frame_index)
            frame = iio.imread(io.BytesIO(video_bytes), index=frame_index, plugin="pyav")

            # Re-encode the frame as JPEG bytes and hand it over under the 'image' key
            buff = io.BytesIO()
            Image.fromarray(frame).convert('RGB').save(buff, format='JPEG', quality=95)  # type: ignore
            modality2data['image'] = buff.getvalue()
            metadata[self.image_filter.key_column] = ''

            preprocessed_data.append(self.image_filter.preprocess_data(modality2data, metadata))

        return key, preprocessed_data

    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
        """
        Run the image filter on the frames of one video and reduce the results per column

        Parameters
        ----------
        batch: list[Any]
            Batch of preprocess_data outputs; only the first element is used

        Returns
        -------
        dict[str, list[Any]]
            Dict with the sample key under the key column and, for every result column,
            the value returned by reduce_results_fn
        """
        df_batch_labels = self._get_dict_from_schema()

        key, data = batch[0]
        # Collect per-frame results of the image filter, feeding it at most
        # self.batch_size frames at a time
        df_batch_labels_images = self._get_dict_from_schema()
        for batched_preprocessed_data in chunks(data, self.batch_size):
            df_batch_labels_images_batch = self.image_filter.process_batch(batched_preprocessed_data)
            for colname in self.result_columns:
                df_batch_labels_images[colname].extend(df_batch_labels_images_batch[colname])

        df_batch_labels[self.key_column].append(key)
        for colname in self.result_columns:
            # Collapse the per-frame values of the column into one value for the video
            df_batch_labels[colname].append(
                self.reduce_results_fn(colname, df_batch_labels_images[colname])
            )
        return df_batch_labels
0 commit comments