Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions DPF/filters/multigpu_filter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import multiprocessing
from multiprocessing import Manager
from typing import Any, Union
from typing import Any, Callable, Optional, Union

import numpy as np
import pandas as pd
Expand All @@ -21,14 +21,19 @@ def run_one_process(
i: int,
index: pd.Series,
results: list[pd.DataFrame],
filter_class: type[DataFilter],
filter_kwargs: dict[str, Any],
filter_class: Optional[type[DataFilter]],
filter_kwargs: Optional[dict[str, Any]],
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]],
device: Union[str, torch.device],
filter_run_kwargs: dict[str, Any]
) -> None:
reader = DatasetReader(connector=connector)
processor = reader.from_df(config, df)
datafilter = filter_class(**filter_kwargs, _pbar_position=i, device=device) # type: ignore
if datafilter_init_fn:
datafilter = datafilter_init_fn(i, device)
else:
datafilter = filter_class(**filter_kwargs, _pbar_position=i, device=device) # type: ignore

datafilter._created_by_multigpu_data_filter = True
processor.apply_data_filter(datafilter, **filter_run_kwargs)
res = processor.df
Expand All @@ -44,26 +49,34 @@ class MultiGPUDataFilter:
def __init__(
self,
devices: list[Union[torch.device, str]],
datafilter_class: type[DataFilter],
datafilter_params: dict[str, Any]
datafilter_class: Optional[type[DataFilter]] = None,
datafilter_params: Optional[dict[str, Any]] = None,
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]] = None
):
"""
Parameters
----------
devices: list[Union[torch.device, str]]
List of devices to run datafilter on
datafilter_class: type[DataFilter]
datafilter_class: Optional[type[DataFilter]] = None
Class of datafilter to use
datafilter_params: dict[str, Any]
datafilter_params: Optional[dict[str, Any]] = None
Parameters for datafilter_class initialization
datafilter_init_fn: Optional[Callable[[int, Union[str, torch.device]], DataFilter]] = None
Initialization function for a datafilter. Takes _pbar_position as first arg and device as a second arg
"""
self.filter_class = datafilter_class
self.filter_params = datafilter_params
self.datafilter_init_fn = datafilter_init_fn
assert self.datafilter_init_fn or self.filter_class, "One method of filter initialization should be specified"
self.devices = devices
self.num_parts = len(devices)

# getting result columns names
datafilter = self.filter_class(**self.filter_params, device=devices[0]) # type: ignore
if self.datafilter_init_fn:
datafilter = self.datafilter_init_fn(0, devices[0])
else:
datafilter = self.filter_class(**self.filter_params, device=devices[0]) # type: ignore
self._result_columns = datafilter.result_columns
del datafilter
torch.cuda.empty_cache()
Expand Down Expand Up @@ -113,6 +126,7 @@ def run(
shared_results,
self.filter_class,
self.filter_params,
self.datafilter_init_fn,
self.devices[i],
filter_run_kwargs
)
Expand Down
99 changes: 97 additions & 2 deletions DPF/filters/videos/image_filter_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
from typing import Any
from collections.abc import Iterable
from typing import Any, Callable

import imageio.v3 as iio
from PIL import Image
Expand Down Expand Up @@ -68,7 +69,7 @@ def preprocess_data(
frame = iio.imread(io.BytesIO(video_bytes), index=frame_index, plugin="pyav")

buff = io.BytesIO()
Image.fromarray(frame).convert('RGB').save(buff, format='JPEG', quality=95)
Image.fromarray(frame).convert('RGB').save(buff, format='JPEG', quality=95) # type: ignore
modality2data['image'] = buff.getvalue()
metadata[self.image_filter.key_column] = ''
return key, self.image_filter.preprocess_data(modality2data, metadata)
Expand All @@ -82,3 +83,97 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
for colname in self.schema[1:]:
df_batch_labels[colname].extend(df_batch_labels_images[colname])
return df_batch_labels


def chunks(lst: list[Any], n: int) -> Iterable[list[Any]]:
for i in range(0, len(lst), n):
yield lst[i:i + n]


class MultiFrameImageFilterAdapter(VideoFilter):
"""
Runs an ImageFilter on several frames from video

Parameters
----------
image_filter: ImageFilter
Image filter to apply
video_frames: list[float]
List of positions of frames to use
For example 0 means first frame, 0.5 means central frame and 1 means last frame
workers: int = 8
Number of pytorch dataloader workers
pbar: bool = True
Whether to show progress bar
"""

def __init__(
self,
image_filter: ImageFilter,
video_frames: list[float],
reduce_results_fn: Callable[[str, list[Any]], Any],
batch_size: int = 8,
workers: int = 8,
pbar: bool = True,
_pbar_position: int = 0
):
super().__init__(pbar, _pbar_position)
self.image_filter = image_filter
self.video_frames = video_frames
self.reduce_results_fn = reduce_results_fn
self.batch_size = batch_size
self.num_workers = workers

@property
def result_columns(self) -> list[str]:
return self.image_filter.result_columns

@property
def dataloader_kwargs(self) -> dict[str, Any]:
return {
"num_workers": self.num_workers,
"batch_size": 1,
"drop_last": False,
}

def preprocess_data(
self,
modality2data: ModalityToDataMapping,
metadata: dict[str, Any]
) -> Any:
key = metadata[self.key_column]

video_bytes = modality2data['video']
meta = iio.immeta(io.BytesIO(video_bytes), plugin="pyav")
fps = meta['fps']
duration = meta['duration']
total_frames = int(fps*duration)

preprocessed_data = []
for video_frame_pos in self.video_frames:
frame_index = min(int(total_frames*video_frame_pos), total_frames-1)
frame = iio.imread(io.BytesIO(video_bytes), index=frame_index, plugin="pyav")

buff = io.BytesIO()
Image.fromarray(frame).convert('RGB').save(buff, format='JPEG', quality=95) # type: ignore
modality2data['image'] = buff.getvalue()
metadata[self.image_filter.key_column] = ''

preprocessed_data.append(self.image_filter.preprocess_data(modality2data, metadata))

return key, preprocessed_data

def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
df_batch_labels = self._get_dict_from_schema()

key, data = batch[0]
df_batch_labels_images = self._get_dict_from_schema()
for batched_preprocessed_data in chunks(data, self.batch_size):
df_batch_labels_images_batch = self.image_filter.process_batch(batched_preprocessed_data)
for colname in self.result_columns:
df_batch_labels_images[colname].extend(df_batch_labels_images_batch[colname])

df_batch_labels[self.key_column].append(key)
for colname in self.result_columns:
df_batch_labels[colname].extend([self.reduce_results_fn(colname, df_batch_labels_images[colname])])
return df_batch_labels
2 changes: 1 addition & 1 deletion DPF/filters/videos/pllava_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def load_video(video_bytes: BytesIO, num_segments: int = 8, return_msg: bool = F
frame_indices = get_index(num_frames, num_segments)
images_group = []
for frame_index in frame_indices:
img = Image.fromarray(vr[frame_index].asnumpy())
img = Image.fromarray(vr[frame_index].asnumpy()) # type: ignore
images_group.append(transforms(img))
if return_msg:
fps = float(vr.get_avg_fps())
Expand Down
2 changes: 1 addition & 1 deletion DPF/transforms/image_resize_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _process_filepath(self, data: TransformsFileData) -> TransformsFileData:
width, height = self.resizer.get_new_size(img.width, img.height)

if (width, height) != (img.width, img.height):
img = img.resize((width, height)) # type: ignore
img = img.resize((width, height))
img.save(filepath, format=self.img_format)

return TransformsFileData(filepath, {'width': width, 'height': height})
8 changes: 4 additions & 4 deletions DPF/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@ def read_image_rgb(path: str, force_rgb: bool = True) -> Image.Image:
pil_img = Image.open(path)
pil_img.load() # type: ignore
if pil_img.format == "PNG" and pil_img.mode != "RGBA":
pil_img = pil_img.convert("RGBA") # type: ignore
pil_img = pil_img.convert("RGBA")
if force_rgb:
pil_img = pil_img.convert("RGB") # type: ignore
pil_img = pil_img.convert("RGB")
return pil_img


def read_image_rgb_from_bytes(img_bytes: bytes, force_rgb: bool = True) -> Image.Image:
pil_img = Image.open(BytesIO(img_bytes))
pil_img.load() # type: ignore
if pil_img.format == "PNG" and pil_img.mode != "RGBA":
pil_img = pil_img.convert("RGBA") # type: ignore
pil_img = pil_img.convert("RGBA")
if force_rgb:
pil_img = pil_img.convert("RGB") # type: ignore
pil_img = pil_img.convert("RGB")
return pil_img