Skip to content

Commit 2b56db2

Browse files
authored
Merge pull request #60 from ai-forever/complex_filter
feat: add complex filter
2 parents 5b00279 + b360c4a commit 2b56db2

File tree

4 files changed

+117
-1
lines changed

4 files changed

+117
-1
lines changed

DPF/filters/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .column_filter import ColumnFilter
2+
from .complex_filter import ComplexDataFilter
23
from .data_filter import DataFilter

DPF/filters/complex_filter.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from dataclasses import dataclass
2+
from typing import Any
3+
4+
from DPF.filters.data_filter import DataFilter
5+
from DPF.modalities import ModalityName
6+
from DPF.types import ModalityToDataMapping
7+
8+
9+
@dataclass
class ComplexFilterPreprocessedData:
    """Container produced by ``ComplexDataFilter.preprocess_data``.

    Bundles a sample's key together with the per-subfilter preprocessed
    payloads so they can travel through the dataloader as a single object.
    """

    # Value of the dataset's key column for this sample.
    key: str
    # Maps a subfilter's index (its position in the filter list) to that
    # subfilter's own preprocess_data() output.
    preprocessed_values: dict[int, Any]
13+
14+
15+
class ComplexDataFilter(DataFilter):
    """Composite filter that runs several ``DataFilter`` instances in one pass.

    Each sample is preprocessed by every subfilter; ``process_batch`` then
    fans the stored per-subfilter payloads back out to their owners and merges
    all result columns into a single row. All subfilters must share the same
    key column.
    """

    def __init__(
        self,
        datafilters: list[DataFilter],
        workers: int,
        pbar: bool = True,
        _pbar_position: int = 0
    ):
        """
        Parameters
        ----------
        datafilters:
            Non-empty list of filters to combine; all must agree on key_column.
        workers:
            Number of dataloader worker processes.
        pbar:
            Whether to show a progress bar.
        _pbar_position:
            Row offset for the progress bar (internal).
        """
        super().__init__(pbar, _pbar_position)
        self.datafilters = datafilters
        self.workers = workers

        assert len(self.datafilters) > 0
        # Every subfilter must use the same key column so rows can be merged.
        assert all(
            i.key_column == self.datafilters[0].key_column for i in self.datafilters
        )  # check all filters have same key col

    @property
    def modalities(self) -> list[ModalityName]:
        # Union of all subfilters' modalities (deduplicated; order unspecified).
        return list({m for f in self.datafilters for m in f.modalities})

    @property
    def key_column(self) -> str:
        # Identical across subfilters (enforced in __init__); take the first.
        return self.datafilters[0].key_column

    @property
    def metadata_columns(self) -> list[str]:
        # Union of all subfilters' required metadata columns.
        return list({c for f in self.datafilters for c in f.metadata_columns})

    @property
    def result_columns(self) -> list[str]:
        # Union of all subfilters' output columns.
        return list({c for f in self.datafilters for c in f.result_columns})

    @property
    def dataloader_kwargs(self) -> dict[str, Any]:
        # batch_size is fixed to 1: process_batch unpacks exactly one sample.
        return {
            "num_workers": self.workers,
            "batch_size": 1,
            "drop_last": False,
        }

    def preprocess_data(
        self,
        modality2data: ModalityToDataMapping,
        metadata: dict[str, Any]
    ) -> Any:
        """Run every subfilter's preprocessing and bundle the results."""
        return ComplexFilterPreprocessedData(
            metadata[self.key_column],
            {
                idx: subfilter.preprocess_data(modality2data, metadata)
                for idx, subfilter in enumerate(self.datafilters)
            },
        )

    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
        """Dispatch each stored payload to its subfilter and merge the outputs.

        NOTE: if two subfilters emit the same column name, the later
        subfilter's value wins (plain dict overwrite).
        """
        merged: dict[str, list[Any]] = {}
        sample: ComplexFilterPreprocessedData = batch[0]  # batch_size == 1
        for idx, subfilter in enumerate(self.datafilters):
            merged.update(
                subfilter.process_batch([sample.preprocessed_values[idx]])
            )
        merged[self.key_column] = [sample.key]

        return merged

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ dependencies = [
1818
"numpy",
1919
"soundfile",
2020
"scipy",
21-
"pillow",
21+
"pillow==10.3.0",
2222
"tqdm",
2323
"pandas",
2424
"pandarallel",

tests/test_complex_filter.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from DPF import DatasetReader
2+
from DPF.configs import ShardsDatasetConfig
3+
from DPF.filters import ComplexDataFilter
4+
from DPF.filters.images.hash_filters import PHashFilter
5+
from DPF.filters.images.info_filter import ImageInfoFilter
6+
7+
8+
def test_shards_complex_filter():
    """End-to-end check: a ComplexDataFilter combining PHash + image-info
    filters populates every expected result column on a shards dataset."""
    dataset_path = 'tests/datasets/shards_correct'
    cfg = ShardsDatasetConfig.from_path_and_columns(
        dataset_path,
        image_name_col="image_name",
        text_col="caption"
    )

    ds = DatasetReader().read_from_config(cfg)

    # Combine two image filters into a single pass over the data.
    combined = ComplexDataFilter(
        [PHashFilter(workers=1), ImageInfoFilter(workers=1)],
        workers=2,
    )
    ds.apply_data_filter(combined)

    df = ds.df
    # PHashFilter output: every row got a perceptual hash.
    assert not df['image_phash_8'].isna().any()
    # ImageInfoFilter outputs: all images readable, no errors, dims filled.
    assert df['is_correct'].all()
    assert df['error'].isna().all()
    for dim_col in ('width', 'height', 'channels'):
        assert not df[dim_col].isna().any()

0 commit comments

Comments
 (0)