
Commit 0125cbe

Merge pull request #41 from ai-forever/v1.0
docs: update filters documentation
2 parents 7dbbf81 + f5ca4ff commit 0125cbe

File tree

8 files changed: +1020, -7 lines


DPF/filters/images/hash_filters.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 
 def get_phash(pil_img: Image.Image, hash_size: int = 8, highfreq_factor: int = 4) -> str:
     img_size = hash_size * highfreq_factor
-    image_array = np.array(pil_img.resize((img_size, img_size), Image.LANCZOS))
+    image_array = np.array(pil_img.resize((img_size, img_size), Image.Resampling.LANCZOS))
 
     dct_coef = dct(dct(image_array, axis=0), axis=1)
     dct_reduced_coef = dct_coef[:hash_size, :hash_size]
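
The `Image.LANCZOS` to `Image.Resampling.LANCZOS` change follows Pillow's API: Pillow 9.1 moved the resampling filters into the `Image.Resampling` enum, and the old top-level aliases were removed in Pillow 10. For code that must run on both old and new Pillow versions, a minimal compatibility sketch (the `_LANCZOS` name is illustrative, not part of DPF):

```python
from PIL import Image

# Prefer the Image.Resampling enum (Pillow >= 9.1); fall back to the
# legacy module-level constant on older Pillow versions.
try:
    _LANCZOS = Image.Resampling.LANCZOS
except AttributeError:  # Pillow < 9.1
    _LANCZOS = Image.LANCZOS

img = Image.new("RGB", (64, 64))
resized = img.resize((32, 32), _LANCZOS)
```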

DPF/filters/images/info_filter.py

Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@ def get_image_info(img_bytes: bytes, data: dict[str, Any], key_column: str) -> I
 
     try:
         pil_img = Image.open(BytesIO(img_bytes))
-        pil_img.load()
+        pil_img.load()  # type: ignore
 
         arr = np.array(pil_img)

DPF/filters/videos/image_filter_adapter.py

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ def preprocess_data(
         frame = iio.imread(io.BytesIO(video_bytes), index=frame_index, plugin="pyav")
 
         buff = io.BytesIO()
-        Image.fromarray(frame).convert('RGB').save(buff, format='JPEG', quality=95)
+        Image.fromarray(frame).convert('RGB').save(buff, format='JPEG', quality=95)  # type: ignore
         modality2data['image'] = buff.getvalue()
         metadata[self.image_filter.key_column] = ''
         return key, self.image_filter.preprocess_data(modality2data, metadata)

DPF/utils/image_utils.py

Lines changed: 2 additions & 2 deletions

@@ -5,7 +5,7 @@
 
 def read_image_rgb(path: str, force_rgb: bool = True) -> Image.Image:
     pil_img = Image.open(path)
-    pil_img.load()
+    pil_img.load()  # type: ignore
     if pil_img.format == "PNG" and pil_img.mode != "RGBA":
         pil_img = pil_img.convert("RGBA")
     if force_rgb:
@@ -15,7 +15,7 @@ def read_image_rgb(path: str, force_rgb: bool = True) -> Image.Image:
 
 def read_image_rgb_from_bytes(img_bytes: bytes, force_rgb: bool = True) -> Image.Image:
     pil_img = Image.open(BytesIO(img_bytes))
-    pil_img.load()
+    pil_img.load()  # type: ignore
     if pil_img.format == "PNG" and pil_img.mode != "RGBA":
         pil_img = pil_img.convert("RGBA")
     if force_rgb:

README.md

Lines changed: 7 additions & 0 deletions

@@ -19,6 +19,10 @@ cd DataProcessingFramework
 pip install .
 ```
 
+Extra requirements: `filters`, `dev`, `llava`, `video_llava`
+
+To install extra requirements, run: `pip install .[filters]`
+
 ## Overview
 
 Framework supports following features:
@@ -31,6 +35,9 @@ Framework supports following features:
 
 DPF allows you to easily filter datasets and add new metadata.
 For example, the code below generates synthetic captions for images in shards on remote s3 storage and updates dataset metadata without downloading shards:
+
+Before running the example below, install the extra requirements: `pip install DPF[filters,llava]`
+
 ```python
 from DPF import S3Connector, DatasetReader, ShardsDatasetConfig

docs/filters.md

Lines changed: 132 additions & 1 deletion

@@ -86,5 +86,136 @@ You can find usage examples [there](../examples).
 
 ### Creating new filter
 
-TODO
+To add your filter, create a new filter class.
+If your filter uses only data from columns (e.g. the _text_ modality), inherit your class from the [ColumnFilter class](../DPF/filters/column_filter.py).
+If your filter uses data from files, inherit your class from the [DataFilter class](../DPF/filters/data_filter.py).
 
+#### Creating DataFilter
+
+To create a new datafilter, add a new file in the folder matching the modality your filter uses.
+For example, if your filter uses the _images_ modality, create a file in the [DPF/filters/images/](../DPF/filters/images) folder.
+If your filter uses the _texts_ and _images_ modalities, create a file in [DPF/filters/text2image/](../DPF/filters/text2image), and so on.
+
+Inherit your filter from the corresponding `DataFilter` class in the modality folder:
+- [DPF/filters/images/img_filter.py](../DPF/filters/images/img_filter.py) for _images_
+- [DPF/filters/text2image/t2i_filter.py](../DPF/filters/text2image/t2i_filter.py) for _texts_ and _images_
+- [DPF/filters/videos/video_filter.py](../DPF/filters/videos/video_filter.py) for _videos_
+
+Then implement the `result_columns` and `dataloader_kwargs` properties and the `preprocess_data` and `process_batch` methods:
+- `result_columns` - list of result columns that the filter adds to a DataFrame
+- `dataloader_kwargs` - parameters for the PyTorch dataloader
+- `preprocess_data` - method where data preprocessing is implemented. This method is passed to the dataloader, and preprocessing runs in multiple processes. Do not use CUDA operations in this method.
+- `process_batch` - method where a batch is processed with the model
+
+For more information, run:
+```python
+from DPF.filters import DataFilter
+help(DataFilter)
+```
+
+**Example of custom DataFilter:**
+```python
+from typing import Any
+
+from DPF.filters.images.img_filter import ImageFilter
+from DPF.types import ModalityToDataMapping
+
+class PHashFilter(ImageFilter):
+    def __init__(
+        self,
+        sim_hash_size: int = 8,
+        workers: int = 16,
+        pbar: bool = True,
+        _pbar_position: int = 0
+    ):
+        super().__init__(pbar, _pbar_position)
+        self.num_workers = workers
+        self.sim_hash_size = sim_hash_size
+
+    @property
+    def result_columns(self) -> list[str]:
+        return [f"image_phash_{self.sim_hash_size}"]
+
+    @property
+    def dataloader_kwargs(self) -> dict[str, Any]:
+        return {"num_workers": self.num_workers, "batch_size": 1, "drop_last": False}
+
+    def preprocess_data(
+        self,
+        modality2data: ModalityToDataMapping,
+        metadata: dict[str, Any]
+    ) -> Any:
+        key = metadata[self.key_column]
+        img_simhash = get_phash(
+            read_image_rgb_from_bytes(modality2data['image']),
+            hash_size=self.sim_hash_size
+        )
+        return key, img_simhash
+
+    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
+        df_batch_labels = self._get_dict_from_schema()
+
+        keys, img_simhashes = list(zip(*batch))
+        df_batch_labels[self.key_column].extend(keys)
+        df_batch_labels[f"image_phash_{self.sim_hash_size}"].extend(img_simhashes)
+
+        return df_batch_labels
+```
+
+This filter reads images and calculates PHash **in the dataloader**.
+The dataloader then returns PHash strings, and these strings are added to the result dataframe.
+
+#### Creating ColumnFilter
+
+To create a new columnfilter, add a new file in the folder matching the modality your filter uses.
+Inherit your class from the [ColumnFilter](../DPF/filters/column_filter.py) class.
+
+Then implement the `result_columns` and `columns_to_process` properties and the `process_sample` method:
+- `result_columns` - list of result columns that the filter adds to a DataFrame
+- `columns_to_process` - columns in the original dataframe used for processing. These columns are passed to the method.
+- `process_sample` - method that processes one sample of data
+
+For more information, run:
+```python
+from DPF.filters import ColumnFilter
+help(ColumnFilter)
+```
+
+**Example of custom ColumnFilter:**
+```python
+from typing import Any
+from py3langid.langid import MODEL_FILE, LanguageIdentifier
+from DPF.filters import ColumnFilter
+
+class LangFilter(ColumnFilter):
+    """
+    LangFilter class
+    """
+
+    def __init__(
+        self,
+        text_column_name: str = "text",
+        workers: int = 16,
+        pbar: bool = True
+    ):
+        super().__init__(workers, pbar)
+        self.lang_identifier = LanguageIdentifier.from_pickled_model(
+            MODEL_FILE, norm_probs=True
+        )
+        self.text_column_name = text_column_name
+
+    @property
+    def columns_to_process(self) -> list[str]:
+        return [self.text_column_name]
+
+    @property
+    def result_columns(self) -> list[str]:
+        return ["lang", "lang_score"]
+
+    def process_sample(self, sample: dict[str, Any]) -> list[Any]:
+        lg, score = self.lang_identifier.classify(sample[self.text_column_name])
+        return [lg, round(score, 2)]
+```
+
+This filter creates two new columns: `lang` and `lang_score`.
+It uses the text column to identify the language of each text.
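
A usage note on the `DataFilter` example above: the snippet calls `get_phash` and `read_image_rgb_from_bytes` without importing them; both live in this repository, in `DPF/filters/images/hash_filters.py` and `DPF/utils/image_utils.py` (the files touched earlier in this commit). Below is a hedged sketch of running the filter through the dataset processor, assuming the config/reader API shown in docs/processor.md and an `apply_data_filter` method on the processor; the method name, dataset path, and worker count are assumptions for illustration:

```python
from DPF import DatasetReader, ShardsDatasetConfig
# Imports the PHashFilter example above relies on:
from DPF.filters.images.hash_filters import get_phash
from DPF.utils.image_utils import read_image_rgb_from_bytes

# Assumes the PHashFilter class from the docs example above is defined here.
config = ShardsDatasetConfig.from_path_and_columns(
    'examples/example_dataset',   # example path, as in docs/processor.md
    image_name_col='image_name',
    text_col='caption'
)
processor = DatasetReader().read_from_config(config)

# Assumed method name: apply the datafilter; new columns land in processor.df
processor.apply_data_filter(PHashFilter(sim_hash_size=8, workers=4))
print(processor.df['image_phash_8'].head())
```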

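Likewise for the `ColumnFilter` example: a hypothetical application sketch, assuming an `apply_column_filter` counterpart to `apply_data_filter` (the method name is an assumption, not confirmed by this diff). Column filters read only DataFrame columns, so no file data needs to be downloaded:

```python
# Assumes LangFilter from the example above is defined and `processor`
# was created as in the previous sketch; py3langid must be installed.
processor.apply_column_filter(LangFilter(text_column_name='caption', workers=8))
print(processor.df[['lang', 'lang_score']].head())
```
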
docs/processor.md

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ Dataset processor supports following features:
 from DPF import ShardsDatasetConfig, DatasetReader
 
 config = ShardsDatasetConfig.from_path_and_columns(
-    'examples/example_dataset/',
+    'examples/example_dataset',
     image_name_col='image_name',
     text_col='caption'
 )
