diff --git a/.gitignore b/.gitignore index 3660fc02..a8ba96bc 100644 --- a/.gitignore +++ b/.gitignore @@ -110,4 +110,5 @@ __pycache__ pretrained_models docs/build .vscode -spleeter-feedstock/ \ No newline at end of file +spleeter-feedstock/ +*FAKE_MUSDB_DIR \ No newline at end of file diff --git a/spleeter/commands/evaluate.py b/spleeter/commands/evaluate.py index 9de330cb..1e1e3f17 100644 --- a/spleeter/commands/evaluate.py +++ b/spleeter/commands/evaluate.py @@ -163,3 +163,5 @@ def entrypoint(arguments, params): get_logger().info('%s:', instrument) for metric, value in metric.items(): get_logger().info('%s: %s', metric, f'{np.median(value):.3f}') + + return metrics diff --git a/tests/test_eval.py b/tests/test_eval.py new file mode 100644 index 00000000..257f75fb --- /dev/null +++ b/tests/test_eval.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# coding: utf8 + +""" Unit testing for Separator class. """ + +__email__ = 'research@deezer.com' +__author__ = 'Deezer Research' +__license__ = 'MIT License' + +import filecmp +import itertools +from os import makedirs +from os.path import splitext, basename, exists, join +from tempfile import TemporaryDirectory + +import pytest +import numpy as np + +import tensorflow as tf + +from spleeter.audio.adapter import get_default_audio_adapter +from spleeter.commands import create_argument_parser + +from spleeter.commands import evaluate + +from spleeter.utils.configuration import load_configuration + +res_4stems = { "vocals": { + "SDR": -0.009, + "SAR": -19.044, + "SIR": -4.072, + "ISR": -0.000 + }, + "drums": { + "SDR": -0.066, + "SAR": -14.294, + "SIR": -4.908, + "ISR": 0.002 + }, + "bass":{ + "SDR": -0.000, + "SAR": -6.364, + "SIR": -9.396, + "ISR": -0.001 + }, + "other":{ + "SDR": -1.464, + "SAR": -14.893, + "SIR": -4.762, + "ISR": -0.027 + } + } + + +def generate_fake_eval_dataset(path): + aa = get_default_audio_adapter() + n_songs = 2 + fs = 44100 + duration = 3 + n_channels = 2 + for song in range(n_songs): + song_path = join(path, "test", f"song{song}") + makedirs(song_path, exist_ok=True) + rng = np.random.RandomState(seed=0) + for instr in ["mixture", "vocals", "bass", "drums", "other"]: + filename = join(song_path, f"{instr}.wav") + data = rng.rand(duration*fs, n_channels)-0.5 + aa.save(filename, data, fs) + + +def test_evaluate(path="FAKE_MUSDB_DIR"): + generate_fake_eval_dataset(path) + p = create_argument_parser() + arguments = p.parse_args(["evaluate", "-p", "spleeter:4stems", "--mus_dir", path]) + params = load_configuration(arguments.configuration) + metrics = evaluate.entrypoint(arguments, params) + for instrument, metric in metrics.items(): + for metric, value in metric.items(): + assert np.allclose(np.median(value), res_4stems[instrument][metric], atol=1e-3)