Skip to content

Commit

Permalink
formatting with black
Browse files Browse the repository at this point in the history
  • Loading branch information
benfmiller committed Sep 22, 2024
1 parent e5e30b2 commit d42152d
Show file tree
Hide file tree
Showing 14 changed files with 112 additions and 46 deletions.
15 changes: 12 additions & 3 deletions audalign/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ def wrapper_decorator(*args, **kwargs):
results = func(*args, **kwargs)
if results is None:
return results
assert results.get("rankings") is not None #This should run after rankings are added
assert (
results.get("rankings") is not None
) # This should run after rankings are added
close_seconds_filter = BaseConfig.close_seconds_filter
if kwargs.get("recognizer") is not None:
close_seconds_filter = kwargs.get("recognizer").config.close_seconds_filter
Expand All @@ -73,8 +75,10 @@ def wrapper_decorator(*args, **kwargs):
return __filter_close_seconds_alignment(results, close_seconds_filter)
else:
return __filter_close_seconds(results, close_seconds_filter)

return wrapper_decorator


@filter_close_seconds
@add_rankings
def recognize(
Expand Down Expand Up @@ -542,6 +546,7 @@ def pretty_print_alignment(results, match_keys="both"):
print("No Matches Found")
print()


def __filter_close_seconds(results: dict, close_seconds_filter: float):
results_iterable_keys = []
# all list items in against_filename dictionary values
Expand All @@ -559,7 +564,7 @@ def __filter_close_seconds(results: dict, close_seconds_filter: float):
match = False
for unfiltered_val in unfiltered_offset_seconds:
if abs(abs(val) - abs(unfiltered_val)) <= close_seconds_filter:
iter_index_pop.append(i+1)
iter_index_pop.append(i + 1)
match = True
break
if not match:
Expand All @@ -571,6 +576,7 @@ def __filter_close_seconds(results: dict, close_seconds_filter: float):
against_dict[key] = temp_list
return results


def __filter_close_seconds_alignment(results: dict, close_seconds_filter: float):
match_keys = ["match_info"]
if results.get("fine_match_info") is not None:
Expand Down Expand Up @@ -724,7 +730,10 @@ def write_shifts_from_results(
if isinstance(read_from_dir, str):
print("Finding audio files")
read_from_dir = filehandler.get_audio_files_directory(
read_from_dir, full_path=True, can_read_extensions=config.can_read_extensions, cant_read_extensions=config.cant_read_extensions
read_from_dir,
full_path=True,
can_read_extensions=config.can_read_extensions,
cant_read_extensions=config.cant_read_extensions,
)
if read_from_dir is not None:
results_files = {}
Expand Down
1 change: 0 additions & 1 deletion audalign/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ class BaseConfig(ABC):
# below extention lists will cause a crash
fail_on_decode_error = True

#
cant_write_extensions = [".mov", ".mp4", ".m4a"]
cant_read_extensions = [".txt", ".md", ".pkf", ".py", ".pyc"]
can_read_extensions = [
Expand Down
2 changes: 1 addition & 1 deletion audalign/config/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ class CorrelationConfig(BaseConfig):
(0.65, -3),
(0.1, -4),
(0.0, 0),
)
)
2 changes: 1 addition & 1 deletion audalign/config/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,4 @@ class VisualConfig(BaseConfig):
(20, 2),
(30, 1),
(99999999999, 0),
)
)
9 changes: 3 additions & 6 deletions audalign/datalign.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,10 @@ def event_a_is_closer(offset_a: int, offset_b: int) -> bool:
return offset_a > offset_b


def distance_from_event():
...
def distance_from_event(): ...


def angle_two_events():
...
def angle_two_events(): ...


def which_is_first():
...
def which_is_first(): ...
36 changes: 24 additions & 12 deletions audalign/filehandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# Optional dependency
...


def _import_optional_dependencies(func):
@wraps(func)
def wrapper_decorator(*args, **kwargs):
Expand Down Expand Up @@ -118,11 +119,12 @@ def create_audiosegment(
return audiofile


def get_audio_files_directory(directory_path: str, full_path: bool = False,
can_read_extensions: list[str] = BaseConfig.can_read_extensions,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,

) -> list:
def get_audio_files_directory(
directory_path: str,
full_path: bool = False,
can_read_extensions: list[str] = BaseConfig.can_read_extensions,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,
) -> list:
"""returns a list of the file paths in directory that are audio
Args:
Expand All @@ -133,7 +135,11 @@ def get_audio_files_directory(directory_path: str, full_path: bool = False,
"""
aud_list = []
for file_path, ext in find_files(directory_path):
if check_is_audio_file(file_path=file_path, can_read_extensions=can_read_extensions, cant_read_extensions=cant_read_extensions):
if check_is_audio_file(
file_path=file_path,
can_read_extensions=can_read_extensions,
cant_read_extensions=cant_read_extensions,
):
if full_path is False:
aud_list += [os.path.basename(file_path)]
else:
Expand All @@ -142,10 +148,10 @@ def get_audio_files_directory(directory_path: str, full_path: bool = False,


def check_is_audio_file(
file_path: str,
can_read_extensions: list[str] = BaseConfig.can_read_extensions,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,
) -> bool:
file_path: str,
can_read_extensions: list[str] = BaseConfig.can_read_extensions,
cant_read_extensions: list[str] = BaseConfig.cant_read_extensions,
) -> bool:
ext = os.path.splitext(file_path)[1]
try:
if ext in [".txt", ".json"] or ext in cant_read_extensions:
Expand Down Expand Up @@ -336,7 +342,10 @@ def _remove_noise(

file_name = os.path.basename(file_path)
destination_name = os.path.join(destination_directory, file_name)
if os.path.splitext(destination_name)[1].lower() in base_config.cant_write_extensions:
if (
os.path.splitext(destination_name)[1].lower()
in base_config.cant_write_extensions
):
destination_name = os.path.splitext(destination_name)[0] + ".wav"

if write_extension is not None:
Expand Down Expand Up @@ -475,7 +484,10 @@ def _uniform_level(
file_name = os.path.basename(file_path)
if len(os.path.splitext(destination_name)[1]) == 0:
destination_name = os.path.join(destination_name, file_name)
if os.path.splitext(destination_name)[1].lower() in base_config.cant_write_extensions:
if (
os.path.splitext(destination_name)[1].lower()
in base_config.cant_write_extensions
):
destination_name = os.path.splitext(destination_name)[0] + ".wav"

if write_extension is not None:
Expand Down
6 changes: 5 additions & 1 deletion audalign/recognizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def align_get_file_names(
file_names = [os.path.basename(x) for x in file_list]
elif file_dir:
file_names = filehandler.get_audio_files_directory(
file_dir, False, self.config.can_read_extensions, self.config.cant_read_extensions)
file_dir,
False,
self.config.can_read_extensions,
self.config.cant_read_extensions,
)
elif fine_aud_file_dict:
file_names = [os.path.basename(x) for x in fine_aud_file_dict.keys()]
else:
Expand Down
26 changes: 19 additions & 7 deletions audalign/recognizers/fingerprint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,12 @@ def align_get_file_names(
fine_aud_file_dict: typing.Optional[dict],
) -> list:
if target_aligning or file_dir:
file_names = filehandler.get_audio_files_directory(file_dir, full_path=True, can_read_extensions=self.config.can_read_extensions, cant_read_extensions=self.config.cant_read_extensions)
file_names = filehandler.get_audio_files_directory(
file_dir,
full_path=True,
can_read_extensions=self.config.can_read_extensions,
cant_read_extensions=self.config.cant_read_extensions,
)
elif fine_aud_file_dict:
file_names = fine_aud_file_dict.keys()
for name, fingerprints in zip(self.file_names, self.fingerprinted_files):
Expand Down Expand Up @@ -172,14 +177,19 @@ def recognize(
if against_path is not None:
if os.path.isdir(against_path):
for path in filehandler.get_audio_files_directory(
against_path, full_path=True, can_read_extensions=self.config.can_read_extensions, cant_read_extensions=self.config.cant_read_extensions
against_path,
full_path=True,
can_read_extensions=self.config.can_read_extensions,
cant_read_extensions=self.config.cant_read_extensions,
):
if path not in self.file_names and path not in to_fingerprint:
to_fingerprint += [path]
elif os.path.isfile(against_path):
if filehandler.check_is_audio_file(against_path,
self.config.can_read_extensions,
self.config.cant_read_extensions):
if filehandler.check_is_audio_file(
against_path,
self.config.can_read_extensions,
self.config.cant_read_extensions,
):
to_fingerprint += [against_path]
if len(to_fingerprint) > 0:
self.fingerprint_directory(to_fingerprint)
Expand Down Expand Up @@ -268,7 +278,9 @@ def _fingerprint_directory(
else:
print("Directory contains 0 files or could not be found")
if self.config.fail_on_decode_error:
raise CouldntDecodeError("Directory contains 0 files or could not be found")
raise CouldntDecodeError(
"Directory contains 0 files or could not be found"
)
return

if _file_audsegs is not None:
Expand Down Expand Up @@ -306,7 +318,7 @@ def _fingerprint_directory(
result = []

for filename in filenames_to_fingerprint:
if isinstance(filename, str): # fine alignments are tuples with offsets
if isinstance(filename, str): # fine alignments are tuples with offsets
file_name = os.path.basename(filename)
if file_name in self.file_names:
print(f"{file_name} already fingerprinted, continuing...")
Expand Down
4 changes: 3 additions & 1 deletion audalign/recognizers/fingerprint/fingerprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def _fingerprint_worker(
if config.fail_on_decode_error:
raise e
return None, None
except IndexError: # Pydub throws IndexErrors for some files on Ubuntu (json, txt, others?)
except (
IndexError
): # Pydub throws IndexErrors for some files on Ubuntu (json, txt, others?)
print(f'File "{file_name}" could not be decoded')
return None, None
elif type(file_path) == tuple:
Expand Down
2 changes: 2 additions & 0 deletions audalign/recognizers/visrecognize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from functools import partial
from functools import wraps


def _import_optional_dependencies(func):
@wraps(func)
def wrapper_decorator(*args, **kwargs):
Expand All @@ -23,6 +24,7 @@ def wrapper_decorator(*args, **kwargs):

return wrapper_decorator


class VisualRecognizer(BaseRecognizer):
config: VisualConfig

Expand Down
1 change: 1 addition & 0 deletions audalign/recognizers/visrecognize/visrecognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from audalign.filehandler import find_files, get_shifted_file, read
from PIL import Image
from pydub.exceptions import CouldntDecodeError

try:
from skimage.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
Expand Down
30 changes: 23 additions & 7 deletions tests/test_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,20 @@
test_folder_eig = "test_audio/test_shifts/"


def ensure_close_seconds_filter(result, close_seconds_filter, initial_filter="match_info"):
def ensure_close_seconds_filter(
result, close_seconds_filter, initial_filter="match_info"
):
for target_file in list(result.get(initial_filter).values()):
for against_file in list(target_file["match_info"].values()):
offset_list = sorted(against_file["offset_seconds"])
start = offset_list[0]
for i in offset_list[1:]:
assert i - start > close_seconds_filter # results within close_seconds_filter
assert (
i - start > close_seconds_filter
) # results within close_seconds_filter
start = i


class TestAlign:
fingerprint_recognizer = ad.FingerprintRecognizer(
load_fingerprints_file="tests/test_fingerprints.json"
Expand Down Expand Up @@ -103,7 +108,9 @@ def test_align_cor_spec_options(self, tmpdir):
assert result is not None
ad.pretty_print_alignment(result)

@pytest.mark.skipif(skimage is None, reason="visrecognize optional dependencies not installed")
@pytest.mark.skipif(
skimage is None, reason="visrecognize optional dependencies not installed"
)
def test_align_vis(self, tmpdir):
recognizer = ad.VisualRecognizer()
recognizer.config.volume_threshold = 214
Expand Down Expand Up @@ -153,7 +160,9 @@ def test_align_files_load_fingerprints(self):
)
assert result

@pytest.mark.skipif(skimage is None, reason="visrecognize optional dependencies not installed")
@pytest.mark.skipif(
skimage is None, reason="visrecognize optional dependencies not installed"
)
def test_align_files_vis(self, tmpdir):
recognizer = ad.VisualRecognizer()
recognizer.config.volume_threshold = 214
Expand Down Expand Up @@ -220,8 +229,11 @@ def test_align_close_seconds_filter(self, tmpdir):
assert result
ensure_close_seconds_filter(result, close_seconds_filter)


class TestTargetAlign:
@pytest.mark.skipif(skimage is None, reason="visrecognize optional dependencies not installed")
@pytest.mark.skipif(
skimage is None, reason="visrecognize optional dependencies not installed"
)
def test_target_align_vis(self, tmpdir):
recognizer = ad.VisualRecognizer()
recognizer.config.volume_threshold = 214
Expand All @@ -242,7 +254,9 @@ def test_target_align_vis(self, tmpdir):
)
assert result is not None

@pytest.mark.skipif(skimage is None, reason="visrecognize optional dependencies not installed")
@pytest.mark.skipif(
skimage is None, reason="visrecognize optional dependencies not installed"
)
def test_target_align_vis_mse(self, tmpdir):
recognizer = ad.VisualRecognizer()
recognizer.config.volume_threshold = 214
Expand Down Expand Up @@ -361,7 +375,9 @@ def test_fine_align_load_fingerprints(self):
assert result is not None
ad.pretty_print_alignment(result, match_keys="match_info")

@pytest.mark.skipif(skimage is None, reason="visrecognize optional dependencies not installed")
@pytest.mark.skipif(
skimage is None, reason="visrecognize optional dependencies not installed"
)
def test_fine_align_visual(self, tmpdir):
recognizer = ad.VisualRecognizer()
recognizer.config.volume_threshold = 210
Expand Down
5 changes: 4 additions & 1 deletion tests/test_audalign.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
except ImportError:
noisereduce = None


def test_always_true():
assert True

Expand Down Expand Up @@ -166,7 +167,9 @@ def test_uniform_level_file_average(self, tmpdir):
)


@pytest.mark.skipif(noisereduce is None, reason="noisereduce optional dependencies not installed")
@pytest.mark.skipif(
noisereduce is None, reason="noisereduce optional dependencies not installed"
)
class TestRemoveNoise:
test_file = "test_audio/testers/test.mp3"

Expand Down
Loading

0 comments on commit d42152d

Please sign in to comment.