
Commit b468d82

Merge pull request #57 from ai-forever/dev

1.1.0

2 parents 5abeab0 + 5b00279

91 files changed: +15153 −3 lines


DPF/configs/files_config.py

Lines changed: 5 additions & 0 deletions
@@ -58,6 +58,7 @@ def from_path_and_columns(
         path: str,
         image_path_col: Optional[str] = None,
         video_path_col: Optional[str] = None,
+        audio_path_col: Optional[str] = None,
         text_col: Optional[str] = None,
     ) -> "FilesDatasetConfig":
         """
@@ -69,6 +70,8 @@ def from_path_and_columns(
            Name of column with image paths
        video_path_col: Optional[str] = None
            Name of column with video paths
+       audio_path_col: Optional[str] = None
+           Name of column with audio paths
        text_col: Optional[str] = None
            Name of column with text
@@ -82,6 +85,8 @@ def from_path_and_columns(
             datatypes.append(FileDataType(MODALITIES['image'], image_path_col))
         if video_path_col:
             datatypes.append(FileDataType(MODALITIES['video'], video_path_col))
+        if audio_path_col:
+            datatypes.append(FileDataType(MODALITIES['audio'], audio_path_col))
         if text_col:
             datatypes.append(ColumnDataType(MODALITIES['text'], text_col))
         assert len(datatypes) > 0, "At least one modality should be provided"
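With this change a files-style config can carry an audio column. A minimal sketch (the table path and column names are hypothetical, and the import assumes the config classes are re-exported from the top-level DPF package as in the project README):

from DPF import FilesDatasetConfig

config = FilesDatasetConfig.from_path_and_columns(
    "dataset/data.csv",           # hypothetical table with one row per sample
    audio_path_col="audio_path",  # new in 1.1.0: column with paths to audio files
    text_col="caption",
)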

DPF/configs/sharded_files_config.py

Lines changed: 5 additions & 0 deletions
@@ -33,6 +33,7 @@ def from_path_and_columns(
         path: str,
         image_name_col: Optional[str] = None,
         video_name_col: Optional[str] = None,
+        audio_name_col: Optional[str] = None,
         text_col: Optional[str] = None,
         datafiles_ext: str = "csv",
     ) -> "ShardedFilesDatasetConfig":
@@ -45,6 +46,8 @@ def from_path_and_columns(
            Name of column with image filenames in shard
        video_name_col: Optional[str] = None
            Name of column with video filenames in shard
+       audio_name_col: Optional[str] = None
+           Name of column with audio filenames in shard
        text_col: Optional[str] = None
            Name of column with text
        datafiles_ext: str = "csv"
@@ -60,6 +63,8 @@ def from_path_and_columns(
             datatypes.append(ShardedDataType(MODALITIES['image'], image_name_col))
         if video_name_col:
             datatypes.append(ShardedDataType(MODALITIES['video'], video_name_col))
+        if audio_name_col:
+            datatypes.append(ShardedDataType(MODALITIES['audio'], audio_name_col))
         if text_col:
             datatypes.append(ColumnDataType(MODALITIES['text'], text_col))
         assert len(datatypes) > 0, "At least one modality should be provided"
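The sharded-files layout gets the same parameter; here the column holds filenames relative to each shard folder rather than full paths. A minimal sketch (path and column names hypothetical):

from DPF import ShardedFilesDatasetConfig

config = ShardedFilesDatasetConfig.from_path_and_columns(
    "dataset/shards/",            # hypothetical directory of shard folders plus csv datafiles
    audio_name_col="audio_name",  # audio filenames inside each shard
    text_col="caption",
)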

DPF/configs/shards_config.py

Lines changed: 5 additions & 0 deletions
@@ -37,6 +37,7 @@ def from_path_and_columns(
         path: str,
         image_name_col: Optional[str] = None,
         video_name_col: Optional[str] = None,
+        audio_name_col: Optional[str] = None,
         text_col: Optional[str] = None,
         archives_ext: str = "tar",
         datafiles_ext: str = "csv",
@@ -50,6 +51,8 @@ def from_path_and_columns(
            Name of column with image filenames in shard
        video_name_col: Optional[str] = None
            Name of column with video filenames in shard
+       audio_name_col: Optional[str] = None
+           Name of column with audio filenames in shard
        text_col: Optional[str] = None
            Name of column with text
        archives_ext: str = "tar"
@@ -67,6 +70,8 @@ def from_path_and_columns(
             datatypes.append(ShardedDataType(MODALITIES['image'], image_name_col))
         if video_name_col:
             datatypes.append(ShardedDataType(MODALITIES['video'], video_name_col))
+        if audio_name_col:
+            datatypes.append(ShardedDataType(MODALITIES['audio'], audio_name_col))
         if text_col:
             datatypes.append(ColumnDataType(MODALITIES['text'], text_col))
         assert len(datatypes) > 0, "At least one modality should be provided"
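And likewise for tar-archive shards, which differ only in the archives_ext parameter. A minimal sketch (path and column names hypothetical):

from DPF import ShardsDatasetConfig

config = ShardsDatasetConfig.from_path_and_columns(
    "dataset/shards/",            # hypothetical directory of .tar shards plus .csv datafiles
    audio_name_col="audio_name",  # audio filenames inside each tar archive
    text_col="caption",
)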

DPF/filters/audios/audio_filter.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+from abc import ABC
+
+from DPF.filters.data_filter import DataFilter
+from DPF.modalities import MODALITIES, ModalityName
+
+
+class AudioFilter(DataFilter, ABC):
+    """
+    Abstract class for all audio filters.
+    """
+
+    @property
+    def modalities(self) -> list[ModalityName]:
+        return ['audio']
+
+    @property
+    def key_column(self) -> str:
+        return MODALITIES['audio'].path_column
+
+    @property
+    def metadata_columns(self) -> list[str]:
+        return []
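A concrete audio filter then only has to supply the remaining DataFilter hooks. A minimal sketch of that contract, mirroring the shape of the filters in this commit (the filter itself, its column, and its logic are hypothetical):

from typing import Any

from DPF.filters.audios.audio_filter import AudioFilter
from DPF.types import ModalityToDataMapping


class AudioSizeFilter(AudioFilter):
    """Hypothetical example: records the raw byte size of every audio."""

    @property
    def result_columns(self) -> list[str]:
        return ["audio_bytes"]

    @property
    def dataloader_kwargs(self) -> dict[str, Any]:
        return {"num_workers": 8, "batch_size": 1, "drop_last": False}

    def preprocess_data(
        self,
        modality2data: ModalityToDataMapping,
        metadata: dict[str, Any]
    ) -> Any:
        # the key comes from the audio path column defined by AudioFilter.key_column
        return metadata[self.key_column], len(modality2data['audio'])

    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
        df_batch_labels = self._get_dict_from_schema()
        for key, size in batch:
            df_batch_labels[self.key_column].append(key)
            df_batch_labels["audio_bytes"].append(size)
        return df_batch_labels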

DPF/filters/audios/info_filter.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+from dataclasses import dataclass
+from io import BytesIO
+from typing import Any, Optional
+
+import soundfile as sf
+
+from DPF.types import ModalityToDataMapping
+
+from .audio_filter import AudioFilter
+
+
+@dataclass
+class AudioInfo:
+    key: str
+    is_correct: bool
+    duration: Optional[float]
+    sample_rate: Optional[int]
+    error: Optional[str]
+
+
+def get_audio_info(audio_bytes: bytes, data: dict[str, Any], key_column: str) -> AudioInfo:
+    """
+    Get info about audio
+    """
+    key = data[key_column]
+
+    is_correct = True
+    sample_rate, duration = None, None
+    err_str = None
+
+    try:
+        file = sf.SoundFile(BytesIO(audio_bytes))
+
+        sample_rate = file.samplerate
+        duration = len(file) / sample_rate
+    except Exception as err:
+        is_correct = False
+        err_str = str(err)
+
+    return AudioInfo(key, is_correct, duration, sample_rate, err_str)
+
+
+class AudioInfoFilter(AudioFilter):
+    """
+    Filter for gathering basic info about audios (duration, sample rate)
+
+    Parameters
+    ----------
+    workers: int = 16
+        Number of parallel dataloader workers
+    pbar: bool = True
+        Whether to show progress bar
+    """
+
+    def __init__(self, workers: int = 16, pbar: bool = True, _pbar_position: int = 0):
+        super().__init__(pbar, _pbar_position)
+        self.num_workers = workers
+
+    @property
+    def result_columns(self) -> list[str]:
+        return [
+            "is_correct", "duration", "sample_rate", "error",
+        ]
+
+    @property
+    def dataloader_kwargs(self) -> dict[str, Any]:
+        return {
+            "num_workers": self.num_workers,
+            "batch_size": 1,
+            "drop_last": False,
+        }
+
+    def preprocess_data(
+        self,
+        modality2data: ModalityToDataMapping,
+        metadata: dict[str, Any]
+    ) -> Any:
+        return get_audio_info(modality2data['audio'], metadata, self.key_column)
+
+    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
+        df_batch_labels = self._get_dict_from_schema()
+
+        for audio_info in batch:
+            df_batch_labels[self.key_column].append(audio_info.key)
+            df_batch_labels["is_correct"].append(audio_info.is_correct)
+            df_batch_labels["duration"].append(audio_info.duration)
+            df_batch_labels["sample_rate"].append(audio_info.sample_rate)
+            df_batch_labels["error"].append(audio_info.error)
+        return df_batch_labels
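Applied through a processor, the filter appends its result_columns to the dataset's dataframe. A minimal sketch (reader and apply_data_filter usage as in the DPF README; config is assumed to be one of the audio-aware configs above):

from DPF import DatasetReader

reader = DatasetReader()
processor = reader.read_from_config(config)        # config with an audio column
processor.apply_data_filter(AudioInfoFilter(workers=8))

# rows whose audio failed to decode end up with is_correct == False
bad = processor.df[~processor.df['is_correct']]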
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+import re
+from typing import Any
+
+import torch
+from torchvision import transforms as T
+from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
+
+from DPF.filters.images.img_filter import ImageFilter
+from DPF.types import ModalityToDataMapping
+from DPF.utils import read_image_rgb_from_bytes
+
+
+class Llava34b_Filter(ImageFilter):
+    """
+    Filter that captions the input images with the llava-v1.6-34b-hf model.
+    """
+
+    def __init__(
+        self,
+        model_path: str = 'llava-hf/llava-v1.6-34b-hf',
+        workers: int = 16,
+        batch_size: int = 8,
+        prompt: str = 'detailed-long',
+        device: str = "cuda:0",
+        pbar: bool = True,
+        crop_size_x: int = 336,
+        crop_size_y: int = 336,
+        resize: int = 336,
+        _pbar_position: int = 0
+    ):
+        super().__init__(pbar, _pbar_position)
+        self.batch_size = batch_size
+        self.num_workers = workers
+        self.device = device
+        self.crop_size_x = crop_size_x
+        self.crop_size_y = crop_size_y
+        self.resize = resize
+        self.model_path = model_path
+        self.prompt_to_use = prompt
+        prompts = {
+            'detailed-long': 'Please provide a caption for this image. Speak confidently and describe everything clearly. Do not lie and describe only what you can see',
+            'pixart': 'Describe this image and its style in a very detailed manner',
+            'short': 'Describe this image very shortly in 1-2 short sentences',
+            'short-video': 'Describe this video very shortly in 1-2 short sentences. Describe what is happening in this video.'
+        }
+        self.input_ids = prompts[self.prompt_to_use]
+        print(self.input_ids)
+        self.prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\n" + f"{self.input_ids}" + "<|im_end|><|im_start|>assistant\n"
+        self.processor = LlavaNextProcessor.from_pretrained(model_path)
+        self.model = LlavaNextForConditionalGeneration.from_pretrained(
+            model_path,
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+            attn_implementation="flash_attention_2",
+            device_map=self.device
+        )
+
+    @property
+    def result_columns(self) -> list[str]:
+        return [f"caption {self.model_path}"]
+
+    @property
+    def dataloader_kwargs(self) -> dict[str, Any]:
+        return {
+            "num_workers": self.num_workers,
+            "batch_size": self.batch_size,
+            "drop_last": False,
+        }
+
+    def preprocess_data(
+        self,
+        modality2data: ModalityToDataMapping,
+        metadata: dict[str, Any]
+    ) -> Any:
+        key = metadata[self.key_column]
+        pil_img = read_image_rgb_from_bytes(modality2data['image']).convert('RGB')
+        transform = T.Compose([
+            T.Resize(self.resize),
+            T.CenterCrop((self.crop_size_x, self.crop_size_y))
+        ])
+        cropped_image = transform(pil_img)
+        return key, cropped_image
+
+    def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
+        df_batch_labels = self._get_dict_from_schema()
+        keys, images = list(zip(*batch))
+        # use len(images): with drop_last=False the final batch can be smaller than batch_size
+        prompts = [self.prompt for _ in range(len(images))]
+        inputs = self.processor(prompts, list(images), return_tensors="pt").to(self.device)
+        with torch.inference_mode():
+            output_ids = self.model.generate(**inputs, max_new_tokens=512, use_cache=True)
+
+        all_outputs = []
+        for i in range(output_ids.shape[0]):
+            output = self.processor.decode(output_ids[i], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+            # drop everything up to and including the assistant tag, then the first newline
+            output = re.sub(r'.*?assistant', '', output, flags=re.DOTALL)
+            output = re.sub(r'\n', '', output, count=1)
+            all_outputs.append(output)
+
+        df_batch_labels[self.schema[1]].extend(all_outputs)
+        df_batch_labels[self.key_column].extend(keys)
+
+        return df_batch_labels
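Usage follows the same processor pattern as the other image filters; a hedged sketch (the module path is assumed, since the file name is not shown in this view, and processor is a DPF processor as above):

# module path assumed; only the class name Llava34b_Filter appears in the diff
from DPF.filters.images.llava_captioning import Llava34b_Filter

llava_filter = Llava34b_Filter(prompt='short', batch_size=8, device="cuda:0")
processor.apply_data_filter(llava_filter)
# adds a "caption llava-hf/llava-v1.6-34b-hf" column to processor.df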

DPF/filters/videos/__init__.py

Whitespace-only changes.
