Skip to content

Commit

Permalink
Add check for fail_on_decode_error and move extension lists to config (
Browse files Browse the repository at this point in the history
…#65)

* refactor can cant read extensions to config

* Add check for config.fail_on_decode_error
  • Loading branch information
benfmiller authored Sep 30, 2024
1 parent a601e8f commit 095dc53
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 43 deletions.
18 changes: 17 additions & 1 deletion audalign/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pydub.utils import mediainfo

import audalign.align as aligner
from audalign.config.fingerprint import FingerprintConfig
import audalign.datalign as datalign
import audalign.filehandler as filehandler
from audalign.config import BaseConfig
Expand Down Expand Up @@ -311,6 +312,7 @@ def write_processed_file(
start_end: tuple = None,
sample_rate: int = BaseConfig.sample_rate,
normalize: bool = BaseConfig.normalize,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,
) -> None:
"""
writes given file to the destination file after processing for fingerprinting
Expand All @@ -329,6 +331,7 @@ def write_processed_file(
start_end=start_end,
sample_rate=sample_rate,
normalize=normalize,
cant_read_extensions=cant_read_extensions,
)


Expand Down Expand Up @@ -695,6 +698,7 @@ def write_shifts_from_results(
write_multi_channel: bool = False,
unprocessed: bool = False,
normalize: bool = BaseConfig.normalize,
config: BaseConfig = None,
):
"""
For writing the results of an alignment with alternate source files or unprocessed files
Expand All @@ -715,10 +719,12 @@ def write_shifts_from_results(
unprocessed (bool): If true, writes files without processing. For total files, only doesn't normalize
normalize (bool): if true, normalizes file when read
"""
if config is None:
config = FingerprintConfig()
if isinstance(read_from_dir, str):
print("Finding audio files")
read_from_dir = filehandler.get_audio_files_directory(
read_from_dir, full_path=True
read_from_dir, full_path=True, can_read_extensions=config.can_read_extensions, cant_read_extensions=config.cant_read_extensions
)
if read_from_dir is not None:
results_files = {}
Expand Down Expand Up @@ -767,6 +773,7 @@ def convert_audio_file(
start_end: tuple = None,
sample_rate: int = None,
normalize: bool = BaseConfig.normalize,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,
):
"""
Convert audio file to type specified in destination path
Expand All @@ -785,6 +792,7 @@ def convert_audio_file(
start_end=start_end,
sample_rate=sample_rate,
normalize=normalize,
cant_read_extensions=cant_read_extensions,
)


Expand All @@ -799,6 +807,7 @@ def uniform_level_file(
width: float = 5,
overlap_ratio: float = 0.5,
exclude_min_db: float = -70,
config: BaseConfig = FingerprintConfig(),
) -> None:
"""
Levels the file using either of two methods: normalize or average.
Expand Down Expand Up @@ -830,6 +839,7 @@ def uniform_level_file(
width=width,
overlap_ratio=overlap_ratio,
exclude_min_db=exclude_min_db,
base_config=config,
)


Expand All @@ -843,6 +853,7 @@ def uniform_level_directory(
exclude_min_db: float = -70,
multiprocessing: bool = True,
num_processors: int = None,
config: BaseConfig = FingerprintConfig(),
) -> None:
"""
Levels the file using either of two methods: normalize or average.
Expand Down Expand Up @@ -878,6 +889,7 @@ def uniform_level_directory(
exclude_min_db=exclude_min_db,
use_multiprocessing=multiprocessing,
num_processes=num_processors,
config=config,
)


Expand All @@ -889,6 +901,7 @@ def remove_noise_file(
write_extension: str = None,
alt_noise_filepath: str = None,
prop_decrease: float = 1,
config: BaseConfig = FingerprintConfig(),
**kwargs,
):
"""Remove noise from audio file by specifying start and end seconds of representative sound sections. Writes file to destination
Expand All @@ -912,6 +925,7 @@ def remove_noise_file(
write_extension=write_extension,
alt_noise_filepath=alt_noise_filepath,
prop_decrease=prop_decrease,
config=config,
**kwargs,
)

Expand All @@ -926,6 +940,7 @@ def remove_noise_directory(
prop_decrease: float = 1,
multiprocessing: bool = True,
num_processors: int = None,
config: BaseConfig = FingerprintConfig(),
**kwargs,
):
"""Remove noise from audio files in directory by specifying start and end seconds of
Expand Down Expand Up @@ -954,6 +969,7 @@ def remove_noise_directory(
prop_decrease=prop_decrease,
use_multiprocessing=multiprocessing,
num_processes=num_processors,
config=config,
**kwargs,
)

Expand Down
26 changes: 26 additions & 0 deletions audalign/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class BaseConfig(ABC):
LOCALITY_SECS = "locality_seconds"

######################################################################
# rankings settings

# Add to ranking if second match is close
rankings_second_is_close_add: int = 1
Expand All @@ -71,3 +72,28 @@ class BaseConfig(ABC):
# used if rankings_get_top_num_match is not None. (used in visual)
# subtracts second value from ranking if num matches is above first value
rankings_num_matches_tups: typing.Optional[tuple] = None

######################################################################
# filehandling settings

# file types that can't be read and not explicitly filtered out by
# below extention lists will cause a crash
fail_on_decode_error = True

#
cant_write_extensions = [".mov", ".mp4", ".m4a"]
cant_read_extensions = [".txt", ".md", ".pkf", ".py", ".pyc"]
can_read_extensions = [
".mov",
".mp4",
".m4a",
".wav",
".WAV",
".mp3",
".MOV",
".ogg",
".aiff",
".aac",
".wma",
".flac",
]
56 changes: 29 additions & 27 deletions audalign/filehandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,14 @@
from pydub.exceptions import CouldntDecodeError

from audalign.config import BaseConfig
from audalign.config.fingerprint import FingerprintConfig

try:
import noisereduce
except ImportError:
# Optional dependency
...


cant_write_ext = [".mov", ".mp4", ".m4a"]
cant_read_ext = [".txt", ".md", ".pkf", ".py", ".pyc"]
can_read_ext = [
".mov",
".mp4",
".m4a",
".wav",
".WAV",
".mp3",
".MOV",
".ogg",
".aiff",
".aac",
".wma",
".flac",
]

def _import_optional_dependencies(func):
@wraps(func)
def wrapper_decorator(*args, **kwargs):
Expand Down Expand Up @@ -135,7 +118,11 @@ def create_audiosegment(
return audiofile


def get_audio_files_directory(directory_path: str, full_path: bool = False) -> list:
def get_audio_files_directory(directory_path: str, full_path: bool = False,
can_read_extensions: list[str] = BaseConfig.can_read_extensions,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,

) -> list:
"""returns a list of the file paths in directory that are audio
Args:
Expand All @@ -146,20 +133,24 @@ def get_audio_files_directory(directory_path: str, full_path: bool = False) -> l
"""
aud_list = []
for file_path, ext in find_files(directory_path):
if check_is_audio_file(file_path=file_path):
if check_is_audio_file(file_path=file_path, can_read_extensions=can_read_extensions, cant_read_extensions=cant_read_extensions):
if full_path is False:
aud_list += [os.path.basename(file_path)]
else:
aud_list += [file_path]
return aud_list


def check_is_audio_file(file_path: str) -> bool:
def check_is_audio_file(
file_path: str,
can_read_extensions: list[str] = BaseConfig.can_read_extensions,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,
) -> bool:
ext = os.path.splitext(file_path)[1]
try:
if ext in [".txt", ".json"] or ext in cant_read_ext:
if ext in [".txt", ".json"] or ext in cant_read_extensions:
return False
elif ext.lower() not in can_read_ext:
elif ext.lower() not in can_read_extensions:
AudioSegment.from_file(file_path)
except CouldntDecodeError:
return False
Expand All @@ -172,6 +163,7 @@ def read(
start_end: tuple = None,
sample_rate=BaseConfig.sample_rate,
normalize: bool = BaseConfig.normalize,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,
):
"""
Reads any file supported by pydub (ffmpeg) and returns a numpy array and the bit depth
Expand All @@ -186,7 +178,7 @@ def read(
frame_rate (int): returns the bit depth
"""

if os.path.splitext(filename)[1] in cant_read_ext:
if os.path.splitext(filename)[1] in cant_read_extensions:
raise CouldntDecodeError
audiofile = create_audiosegment(
filename, start_end=start_end, sample_rate=sample_rate, normalize=normalize
Expand Down Expand Up @@ -220,6 +212,7 @@ def noise_remove(
write_extension: str = None,
alt_noise_filepath=None,
prop_decrease=1,
config: BaseConfig = FingerprintConfig(),
**kwargs,
):
audiofile = create_audiosegment(filepath)
Expand Down Expand Up @@ -278,6 +271,7 @@ def noise_remove_directory(
prop_decrease=1,
use_multiprocessing=False,
num_processes=None,
config: BaseConfig = FingerprintConfig(),
**kwargs,
):
noise_data = _floatify_data(create_audiosegment(noise_filepath))[
Expand All @@ -293,6 +287,7 @@ def noise_remove_directory(
destination_directory=destination_directory,
prop_decrease=prop_decrease,
write_extension=write_extension,
base_config=config,
**kwargs,
)

Expand Down Expand Up @@ -323,6 +318,7 @@ def _remove_noise(
write_extension: str = None,
destination_directory="",
prop_decrease=1,
base_config: BaseConfig = FingerprintConfig(),
**kwargs,
):

Expand All @@ -343,7 +339,7 @@ def _remove_noise(

file_name = os.path.basename(file_path)
destination_name = os.path.join(destination_directory, file_name)
if os.path.splitext(destination_name)[1].lower() in cant_write_ext:
if os.path.splitext(destination_name)[1].lower() in base_config.cant_write_extensions:
destination_name = os.path.splitext(destination_name)[0] + ".wav"

if write_extension is not None:
Expand Down Expand Up @@ -402,6 +398,7 @@ def uniform_level_directory(
exclude_min_db=-70,
use_multiprocessing=False,
num_processes=None,
config: BaseConfig = FingerprintConfig(),
):
_uniform_level_ = partial(
_uniform_level,
Expand All @@ -411,6 +408,7 @@ def uniform_level_directory(
width=width,
overlap_ratio=overlap_ratio,
exclude_min_db=exclude_min_db,
base_config=config,
)

if use_multiprocessing == True:
Expand Down Expand Up @@ -441,6 +439,7 @@ def _uniform_level(
width: float = 5,
overlap_ratio=0.5,
exclude_min_db=-70,
base_config: BaseConfig = FingerprintConfig(),
):
assert overlap_ratio < 1 and overlap_ratio >= 0
try:
Expand Down Expand Up @@ -481,7 +480,7 @@ def _uniform_level(
file_name = os.path.basename(file_path)
if len(os.path.splitext(destination_name)[1]) == 0:
destination_name = os.path.join(destination_name, file_name)
if os.path.splitext(destination_name)[1].lower() in cant_write_ext:
if os.path.splitext(destination_name)[1].lower() in base_config.cant_write_extensions:
destination_name = os.path.splitext(destination_name)[0] + ".wav"

if write_extension is not None:
Expand Down Expand Up @@ -656,6 +655,7 @@ def _shift_write_separate(
return_files: bool = False,
unprocessed: bool = False,
normalize: bool = BaseConfig.normalize,
base_config: BaseConfig = FingerprintConfig(),
):
audsegs = _shift_prepend_space_audsegs(
files_shifts=files_shifts,
Expand All @@ -673,6 +673,7 @@ def _shift_write_separate(
file_path=file_path,
destination_path=destination_path,
write_extension=write_extension,
base_config=base_config,
)

audsegs = list(audsegs.values())
Expand Down Expand Up @@ -754,12 +755,13 @@ def _write_single_shift(
file_path: str,
destination_path: str,
write_extension: str,
base_config: BaseConfig = FingerprintConfig(),
):

file_name = os.path.basename(file_path)
destination_name = os.path.join(destination_path, file_name) # type: ignore

if os.path.splitext(destination_name)[1] in cant_write_ext:
if os.path.splitext(destination_name)[1] in base_config.cant_write_extensions:
destination_name = os.path.splitext(destination_name)[0] + ".wav"

if write_extension:
Expand Down
12 changes: 11 additions & 1 deletion audalign/recognizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import audalign.filehandler as filehandler
from audalign.config import BaseConfig
from pydub.exceptions import CouldntDecodeError


class BaseRecognizer(ABC):
Expand Down Expand Up @@ -46,10 +47,19 @@ def align_get_file_names(
if target_aligning:
file_names = [os.path.basename(x) for x in file_list]
elif file_dir:
file_names = filehandler.get_audio_files_directory(file_dir)
file_names = filehandler.get_audio_files_directory(
file_dir,
False,
self.config.can_read_extensions,
self.config.cant_read_extensions,
)
elif fine_aud_file_dict:
if fine_aud_file_dict == None or len(fine_aud_file_dict.keys()) == 0:
raise CouldntDecodeError("No files found", fine_aud_file_dict)
file_names = [os.path.basename(x) for x in fine_aud_file_dict.keys()]
else:
if file_list == None or len(file_list) == 0:
raise CouldntDecodeError("No files found", file_list)
file_names = [os.path.basename(x) for x in file_list]
return file_names

Expand Down
Loading

0 comments on commit 095dc53

Please sign in to comment.