Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions DPF/configs/files_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def from_path_and_columns(
path: str,
image_path_col: Optional[str] = None,
video_path_col: Optional[str] = None,
audio_path_col: Optional[str] = None,
text_col: Optional[str] = None,
) -> "FilesDatasetConfig":
"""
Expand All @@ -69,6 +70,8 @@ def from_path_and_columns(
Name of column with image paths
video_path_col: Optional[str] = None
Name of column with video paths
audio_path_col: Optional[str] = None
Name of column with audio paths
text_col: Optional[str] = None
Name of column with text

Expand All @@ -82,6 +85,8 @@ def from_path_and_columns(
datatypes.append(FileDataType(MODALITIES['image'], image_path_col))
if video_path_col:
datatypes.append(FileDataType(MODALITIES['video'], video_path_col))
if audio_path_col:
datatypes.append(FileDataType(MODALITIES['audio'], audio_path_col))
if text_col:
datatypes.append(ColumnDataType(MODALITIES['text'], text_col))
assert len(datatypes) > 0, "At least one modality should be provided"
Expand Down
5 changes: 5 additions & 0 deletions DPF/configs/sharded_files_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def from_path_and_columns(
path: str,
image_name_col: Optional[str] = None,
video_name_col: Optional[str] = None,
audio_name_col: Optional[str] = None,
text_col: Optional[str] = None,
datafiles_ext: str = "csv",
) -> "ShardedFilesDatasetConfig":
Expand All @@ -45,6 +46,8 @@ def from_path_and_columns(
Name of column with image filenames in shard
video_name_col: Optional[str] = None
Name of column with video filenames in shard
audio_name_col: Optional[str] = None
Name of column with audio filenames in shard
text_col: Optional[str] = None
Name of column with text
datafiles_ext: str = "csv"
Expand All @@ -60,6 +63,8 @@ def from_path_and_columns(
datatypes.append(ShardedDataType(MODALITIES['image'], image_name_col))
if video_name_col:
datatypes.append(ShardedDataType(MODALITIES['video'], video_name_col))
if audio_name_col:
datatypes.append(ShardedDataType(MODALITIES['audio'], audio_name_col))
if text_col:
datatypes.append(ColumnDataType(MODALITIES['text'], text_col))
assert len(datatypes) > 0, "At least one modality should be provided"
Expand Down
5 changes: 5 additions & 0 deletions DPF/configs/shards_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def from_path_and_columns(
path: str,
image_name_col: Optional[str] = None,
video_name_col: Optional[str] = None,
audio_name_col: Optional[str] = None,
text_col: Optional[str] = None,
archives_ext: str = "tar",
datafiles_ext: str = "csv",
Expand All @@ -50,6 +51,8 @@ def from_path_and_columns(
Name of column with image filenames in shard
video_name_col: Optional[str] = None
Name of column with video filenames in shard
audio_name_col: Optional[str] = None
Name of column with audio filenames in shard
text_col: Optional[str] = None
Name of column with text
archives_ext: str = "tar"
Expand All @@ -67,6 +70,8 @@ def from_path_and_columns(
datatypes.append(ShardedDataType(MODALITIES['image'], image_name_col))
if video_name_col:
datatypes.append(ShardedDataType(MODALITIES['video'], video_name_col))
if audio_name_col:
datatypes.append(ShardedDataType(MODALITIES['audio'], audio_name_col))
if text_col:
datatypes.append(ColumnDataType(MODALITIES['text'], text_col))
assert len(datatypes) > 0, "At least one modality should be provided"
Expand Down
22 changes: 22 additions & 0 deletions DPF/filters/audios/audio_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from abc import ABC

from DPF.filters.data_filter import DataFilter
from DPF.modalities import MODALITIES, ModalityName


class AudioFilter(DataFilter, ABC):
"""
Abstract class for all audio filters.
"""

@property
def modalities(self) -> list[ModalityName]:
return ['audio']

@property
def key_column(self) -> str:
return MODALITIES['audio'].path_column

@property
def metadata_columns(self) -> list[str]:
return []
89 changes: 89 additions & 0 deletions DPF/filters/audios/info_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Optional

import soundfile as sf

from DPF.types import ModalityToDataMapping

from .audio_filter import AudioFilter


@dataclass
class AudioInfo:
key: str
is_correct: bool
duration: Optional[float]
sample_rate: Optional[int]
error: Optional[str]


def get_audio_info(audio_bytes: bytes, data: dict[str, Any], key_column: str) -> AudioInfo:
"""
Get info about audio
"""
key = data[key_column]

is_correct = True
sample_rate, duration = None, None
err_str = None

try:
file = sf.SoundFile(BytesIO(audio_bytes))

sample_rate = file.samplerate
duration = len(file) / sample_rate
except Exception as err:
is_correct = False
err_str = str(err)

return AudioInfo(key, is_correct, duration, sample_rate, err_str)


class AudioInfoFilter(AudioFilter):
"""
Filter for gathering basic info about audios (width, height, number of channels)

Parameters
----------
workers: int = 16
Number of parallel dataloader workers
pbar: bool = True
Whether to show progress bar
"""

def __init__(self, workers: int = 16, pbar: bool = True, _pbar_position: int = 0):
super().__init__(pbar, _pbar_position)
self.num_workers = workers

@property
def result_columns(self) -> list[str]:
return [
"is_correct", "duration", "sample_rate", "error",
]

@property
def dataloader_kwargs(self) -> dict[str, Any]:
return {
"num_workers": self.num_workers,
"batch_size": 1,
"drop_last": False,
}

def preprocess_data(
self,
modality2data: ModalityToDataMapping,
metadata: dict[str, Any]
) -> Any:
return get_audio_info(modality2data['audio'], metadata, self.key_column)

def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
df_batch_labels = self._get_dict_from_schema()

for image_info in batch:
df_batch_labels[self.key_column].append(image_info.key)
df_batch_labels["is_correct"].append(image_info.is_correct)
df_batch_labels["duration"].append(image_info.duration)
df_batch_labels["sample_rate"].append(image_info.sample_rate)
df_batch_labels["error"].append(image_info.error)
return df_batch_labels
6 changes: 5 additions & 1 deletion DPF/modalities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Literal, Optional

ModalityName = Literal["image", "video", "text"]
ModalityName = Literal["image", "video", "text", "audio"]


@dataclass
Expand Down Expand Up @@ -47,6 +47,10 @@ def __repr__(self) -> str:
'video', 'video_path',
'video_name', None
),
'audio': DataModality(
'audio', 'audio_path',
'audio_name', None
),
'text': DataModality(
'text', 'text_path',
'text_name', 'text'
Expand Down
Loading