
Commit c4c0679

Merge pull request #61 from ai-forever/better_filters
upgrade multigpu and image adapter for video filters
2 parents a5bd1e3 + 0a15bc5 commit c4c0679

File tree: 5 files changed (+126 additions, -17 deletions)


DPF/filters/multigpu_filter.py

Lines changed: 23 additions & 9 deletions
@@ -1,6 +1,6 @@
 import multiprocessing
 from multiprocessing import Manager
-from typing import Any, Union
+from typing import Any, Callable, Optional, Union

 import numpy as np
 import pandas as pd
@@ -21,14 +21,19 @@ def run_one_process(
     i: int,
     index: pd.Series,
     results: list[pd.DataFrame],
-    filter_class: type[DataFilter],
-    filter_kwargs: dict[str, Any],
+    filter_class: Optional[type[DataFilter]],
+    filter_kwargs: Optional[dict[str, Any]],
+    datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]],
     device: Union[str, torch.device],
     filter_run_kwargs: dict[str, Any]
 ) -> None:
     reader = DatasetReader(connector=connector)
     processor = reader.from_df(config, df)
-    datafilter = filter_class(**filter_kwargs, _pbar_position=i, device=device)  # type: ignore
+    if datafilter_init_fn:
+        datafilter = datafilter_init_fn(i, device)
+    else:
+        datafilter = filter_class(**filter_kwargs, _pbar_position=i, device=device)  # type: ignore
+
     datafilter._created_by_multigpu_data_filter = True
     processor.apply_data_filter(datafilter, **filter_run_kwargs)
     res = processor.df
@@ -44,26 +49,34 @@ class MultiGPUDataFilter:
     def __init__(
         self,
         devices: list[Union[torch.device, str]],
-        datafilter_class: type[DataFilter],
-        datafilter_params: dict[str, Any]
+        datafilter_class: Optional[type[DataFilter]] = None,
+        datafilter_params: Optional[dict[str, Any]] = None,
+        datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]] = None
     ):
         """
         Parameters
         ----------
         devices: list[Union[torch.device, str]]
             List of devices to run datafilter on
-        datafilter_class: type[DataFilter]
+        datafilter_class: Optional[type[DataFilter]] = None
             Class of datafilter to use
-        datafilter_params: dict[str, Any]
+        datafilter_params: Optional[dict[str, Any]] = None
             Parameters for datafilter_class initialization
+        datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]] = None
+            Initialization function for a datafilter. Takes _pbar_position as first arg and device as a second arg
         """
         self.filter_class = datafilter_class
         self.filter_params = datafilter_params
+        self.datafilter_init_fn = datafilter_init_fn
+        assert self.datafilter_init_fn or self.filter_class, "One method of filter initialization should be specified"
         self.devices = devices
        self.num_parts = len(devices)

         # getting result columns names
-        datafilter = self.filter_class(**self.filter_params, device=devices[0])  # type: ignore
+        if self.datafilter_init_fn:
+            datafilter = self.datafilter_init_fn(0, devices[0])
+        else:
+            datafilter = self.filter_class(**self.filter_params, device=devices[0])  # type: ignore
         self._result_columns = datafilter.result_columns
         del datafilter
         torch.cuda.empty_cache()
@@ -113,6 +126,7 @@ def run(
                 shared_results,
                 self.filter_class,
                 self.filter_params,
+                self.datafilter_init_fn,
                 self.devices[i],
                 filter_run_kwargs
             )
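
For reference, a minimal sketch of how the new datafilter_init_fn hook could be used instead of datafilter_class plus datafilter_params. Only the MultiGPUDataFilter constructor arguments and the callback signature (worker index, device) come from this diff; MyGPUFilter, its threshold parameter, and the import path my_filters are hypothetical placeholders.

from DPF.filters.multigpu_filter import MultiGPUDataFilter
from my_filters import MyGPUFilter  # hypothetical DataFilter subclass, not part of this commit


def init_datafilter(pbar_position: int, device: str) -> MyGPUFilter:
    # Called in each worker process: receives the worker index
    # (forwarded as _pbar_position) and the device assigned to that worker.
    return MyGPUFilter(threshold=0.5, device=device, _pbar_position=pbar_position)


multigpu_filter = MultiGPUDataFilter(
    devices=["cuda:0", "cuda:1"],
    datafilter_init_fn=init_datafilter,  # alternative to datafilter_class + datafilter_params
)

The callback gives callers full control over per-worker construction, for example building a nested adapter around another filter, instead of being limited to a single class and a fixed kwargs dict.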

DPF/filters/videos/image_filter_adapter.py

Lines changed: 97 additions & 2 deletions
@@ -1,5 +1,6 @@
 import io
-from typing import Any
+from collections.abc import Iterable
+from typing import Any, Callable

 import imageio.v3 as iio
 from PIL import Image
@@ -68,7 +69,7 @@ def preprocess_data(
         frame = iio.imread(io.BytesIO(video_bytes), index=frame_index, plugin="pyav")

         buff = io.BytesIO()
-        Image.fromarray(frame).convert('RGB').save(buff, format='JPEG', quality=95)
+        Image.fromarray(frame).convert('RGB').save(buff, format='JPEG', quality=95)  # type: ignore
         modality2data['image'] = buff.getvalue()
         metadata[self.image_filter.key_column] = ''
         return key, self.image_filter.preprocess_data(modality2data, metadata)
@@ -82,3 +83,97 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
         for colname in self.schema[1:]:
             df_batch_labels[colname].extend(df_batch_labels_images[colname])
         return df_batch_labels
+
+
+def chunks(lst: list[Any], n: int) -> Iterable[list[Any]]:
+    for i in range(0, len(lst), n):
+        yield lst[i:i + n]
+
+
+class MultiFrameImageFilterAdapter(VideoFilter):
+    """
+    Runs an ImageFilter on several frames from video
+
+    Parameters
+    ----------
+    image_filter: ImageFilter
+        Image filter to apply
+    video_frames: list[float]
+        List of positions of frames to use
+        For example 0 means first frame, 0.5 means central frame and 1 means last frame
+    workers: int = 8
+        Number of pytorch dataloader workers
+    pbar: bool = True
+        Whether to show progress bar
+    """
+
+    def __init__(
+        self,
+        image_filter: ImageFilter,
+        video_frames: list[float],
+        reduce_results_fn: Callable[[str, list[Any]], Any],
+        batch_size: int = 8,
+        workers: int = 8,
+        pbar: bool = True,
+        _pbar_position: int = 0
+    ):
+        super().__init__(pbar, _pbar_position)
+        self.image_filter = image_filter
+        self.video_frames = video_frames
+        self.reduce_results_fn = reduce_results_fn
+        self.batch_size = batch_size
+        self.num_workers = workers
+
+    @property
+    def result_columns(self) -> list[str]:
+        return self.image_filter.result_columns
+
+    @property
+    def dataloader_kwargs(self) -> dict[str, Any]:
+        return {
+            "num_workers": self.num_workers,
+            "batch_size": 1,
+            "drop_last": False,
+        }
+
+    def preprocess_data(
+        self,
+        modality2data: ModalityToDataMapping,
+        metadata: dict[str, Any]
+    ) -> Any:
+        key = metadata[self.key_column]

+        video_bytes = modality2data['video']
+        meta = iio.immeta(io.BytesIO(video_bytes), plugin="pyav")
+        fps = meta['fps']
+        duration = meta['duration']
+        total_frames = int(fps*duration)
+
+        preprocessed_data = []
+        for video_frame_pos in self.video_frames:
+            frame_index = min(int(total_frames*video_frame_pos), total_frames-1)
+            frame = iio.imread(io.BytesIO(video_bytes), index=frame_index, plugin="pyav")
+
+            buff = io.BytesIO()
+            Image.fromarray(frame).convert('RGB').save(buff, format='JPEG', quality=95)  # type: ignore
+            modality2data['image'] = buff.getvalue()
+            metadata[self.image_filter.key_column] = ''
+
+            preprocessed_data.append(self.image_filter.preprocess_data(modality2data, metadata))
+
+        return key, preprocessed_data
+
+    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
+        df_batch_labels = self._get_dict_from_schema()
+
+        key, data = batch[0]
+        df_batch_labels_images = self._get_dict_from_schema()
+        for batched_preprocessed_data in chunks(data, self.batch_size):
+            df_batch_labels_images_batch = self.image_filter.process_batch(batched_preprocessed_data)
+            for colname in self.result_columns:
+                df_batch_labels_images[colname].extend(df_batch_labels_images_batch[colname])
+
+        df_batch_labels[self.key_column].append(key)
+        for colname in self.result_columns:
+            df_batch_labels[colname].extend([self.reduce_results_fn(colname, df_batch_labels_images[colname])])
+        return df_batch_labels
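
As a usage note, here is a minimal sketch of wiring the new MultiFrameImageFilterAdapter to an image filter. The constructor arguments and the reduce_results_fn signature (column name, list of per-frame values) come from the code above; MyImageFilter, the import path my_filters, and the numeric-score assumption behind reduce_mean are illustrative placeholders.

from DPF.filters.videos.image_filter_adapter import MultiFrameImageFilterAdapter
from my_filters import MyImageFilter  # hypothetical ImageFilter subclass, not part of this commit


def reduce_mean(colname: str, frame_values: list) -> float:
    # Collapse per-frame results into a single value per video.
    # Assumes every result column holds numeric scores; adapt per column if not.
    return sum(frame_values) / len(frame_values)


adapter = MultiFrameImageFilterAdapter(
    image_filter=MyImageFilter(device="cuda:0"),
    video_frames=[0.0, 0.5, 1.0],  # first, middle and last frame of each video
    reduce_results_fn=reduce_mean,
    batch_size=8,
)

The adapter is then applied like any other video filter, e.g. processor.apply_data_filter(adapter), as seen in multigpu_filter.py above.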

DPF/filters/videos/pllava_filter.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def load_video(video_bytes: BytesIO, num_segments: int = 8, return_msg: bool = F
     frame_indices = get_index(num_frames, num_segments)
     images_group = []
     for frame_index in frame_indices:
-        img = Image.fromarray(vr[frame_index].asnumpy())
+        img = Image.fromarray(vr[frame_index].asnumpy())  # type: ignore
         images_group.append(transforms(img))
     if return_msg:
         fps = float(vr.get_avg_fps())

DPF/transforms/image_resize_transforms.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def _process_filepath(self, data: TransformsFileData) -> TransformsFileData:
         width, height = self.resizer.get_new_size(img.width, img.height)

         if (width, height) != (img.width, img.height):
-            img = img.resize((width, height))  # type: ignore
+            img = img.resize((width, height))
         img.save(filepath, format=self.img_format)

         return TransformsFileData(filepath, {'width': width, 'height': height})

DPF/utils/image_utils.py

Lines changed: 4 additions & 4 deletions
@@ -7,17 +7,17 @@ def read_image_rgb(path: str, force_rgb: bool = True) -> Image.Image:
     pil_img = Image.open(path)
     pil_img.load()  # type: ignore
     if pil_img.format == "PNG" and pil_img.mode != "RGBA":
-        pil_img = pil_img.convert("RGBA")  # type: ignore
+        pil_img = pil_img.convert("RGBA")
     if force_rgb:
-        pil_img = pil_img.convert("RGB")  # type: ignore
+        pil_img = pil_img.convert("RGB")
     return pil_img


 def read_image_rgb_from_bytes(img_bytes: bytes, force_rgb: bool = True) -> Image.Image:
     pil_img = Image.open(BytesIO(img_bytes))
     pil_img.load()  # type: ignore
     if pil_img.format == "PNG" and pil_img.mode != "RGBA":
-        pil_img = pil_img.convert("RGBA")  # type: ignore
+        pil_img = pil_img.convert("RGBA")
     if force_rgb:
-        pil_img = pil_img.convert("RGB")  # type: ignore
+        pil_img = pil_img.convert("RGB")
     return pil_img
