forked from deezer/spleeter
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_train.py
118 lines (103 loc) · 3.62 KB
/
test_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python
# coding: utf8
""" Unit testing for Separator class. """
__email__ = "research@deezer.com"
__author__ = "Deezer Research"
__license__ = "MIT License"
import json
import os
from os import makedirs
from os.path import join
from tempfile import TemporaryDirectory
import numpy as np
import pandas as pd # type: ignore
from typer.testing import CliRunner
from spleeter.__main__ import spleeter
from spleeter.audio.adapter import AudioAdapter
# Template training configuration used by the tests below.
# NOTE(review): test_train mutates this dict in place (n_channels, paths) —
# treat it as a per-test template, not a shared constant.
TRAIN_CONFIG = {
    "mix_name": "mix",
    # Stems the model separates the mix into.
    "instrument_list": ["vocals", "other"],
    "sample_rate": 44100,
    # STFT parameters (samples).
    "frame_length": 4096,
    "frame_step": 1024,
    # Spectrogram patch size fed to the network (time x frequency bins).
    "T": 128,
    "F": 128,
    "n_channels": 2,
    # Dataset sampling: seconds per chunk and chunks drawn per song.
    "chunk_duration": 4,
    "n_chunks_per_song": 1,
    "separation_exponent": 2,
    "mask_extension": "zeros",
    "learning_rate": 1e-4,
    "batch_size": 2,
    # Keep training tiny so the test runs fast (10 optimizer steps).
    "train_max_steps": 10,
    "throttle_secs": 20,
    "save_checkpoints_steps": 100,
    "save_summary_steps": 5,
    "random_seed": 0,
    "model": {
        "type": "unet.unet",
        "params": {"conv_activation": "ELU", "deconv_activation": "ELU"},
    },
}
def generate_fake_training_dataset(
    path,
    instrument_list=None,
    n_channels=2,
    n_songs=2,
    fs=44100,
    duration=6,
):
    """
    Generate a fake training dataset under ``path``:

    - writes random-noise audio files for each song/instrument pair in
      ``path/train/song<i>/``
    - writes ``path/train/train.csv`` describing the dataset (one row per
      song with dataset-relative stem paths and the song duration).

    Parameters:
        path (str): root directory in which the dataset is created.
        instrument_list (list | None): stems to generate; defaults to
            ``["vocals", "other"]``. ``None`` sentinel avoids the
            mutable-default-argument pitfall (a list default would be
            shared across calls).
        n_channels (int): number of audio channels per file.
        n_songs (int): number of songs to generate.
        fs (int): sample rate in Hz.
        duration (int): duration of each song in seconds.
    """
    if instrument_list is None:
        instrument_list = ["vocals", "other"]
    aa = AudioAdapter.default()
    # Fixed seed keeps the generated dataset deterministic across runs.
    rng = np.random.RandomState(seed=0)
    dataset_df = pd.DataFrame(
        columns=["mix_path"]
        + [f"{instr}_path" for instr in instrument_list]
        + ["duration"]
    )
    for song in range(n_songs):
        song_path = join(path, "train", f"song{song}")
        makedirs(song_path, exist_ok=True)
        dataset_df.loc[song, "duration"] = duration
        for instr in instrument_list + ["mix"]:
            filename = join(song_path, f"{instr}.wav")
            # Uniform noise in [-0.5, 0.5) stands in for real audio.
            data = rng.rand(duration * fs, n_channels) - 0.5
            aa.save(filename, data, fs)
            # CSV stores paths relative to the dataset root, as the
            # training CLI joins them with its -d/--data argument.
            dataset_df.loc[song, f"{instr}_path"] = join(
                "train", f"song{song}", f"{instr}.wav"
            )
    dataset_df.to_csv(join(path, "train", "train.csv"), index=False)
def test_train():
    """
    End-to-end training smoke test.

    For mono and stereo input: generate a tiny fake dataset, run
    ``spleeter train`` for a few steps via the CLI runner, and verify
    that model checkpoint files were written.
    """
    with TemporaryDirectory() as path:
        for n_channels in [1, 2]:
            # Work on a copy so the module-level TRAIN_CONFIG template is
            # not mutated across iterations or leaked into other tests.
            config = dict(TRAIN_CONFIG)
            config["n_channels"] = n_channels
            # Generate training dataset.
            generate_fake_training_dataset(
                path, n_channels=n_channels, fs=config["sample_rate"]
            )
            # Set training command arguments.
            runner = CliRunner()
            model_dir = join(path, f"model_{n_channels}")
            train_dir = join(path, "train")
            cache_dir = join(path, f"cache_{n_channels}")
            config["train_csv"] = join(train_dir, "train.csv")
            config["validation_csv"] = join(train_dir, "train.csv")
            config["model_dir"] = model_dir
            config["training_cache"] = join(cache_dir, "training")
            config["validation_cache"] = join(cache_dir, "validation")
            # Write the config inside the temporary directory so the test
            # leaves nothing behind in the current working directory
            # (the original dumped "useless_config.json" into CWD).
            config_path = join(path, f"config_{n_channels}.json")
            with open(config_path, "w") as stream:
                json.dump(config, stream)
            # Execute training.
            result = runner.invoke(
                spleeter,
                ["train", "-p", config_path, "-d", path, "--verbose"],
            )
            # Check the exit code first: on failure it carries the useful
            # diagnostics, whereas a missing-file assert would not.
            assert result.exit_code == 0
            # Assert that model checkpoints were created.
            assert os.path.exists(join(model_dir, "model.ckpt-10.index"))
            assert os.path.exists(join(model_dir, "checkpoint"))
            assert os.path.exists(join(model_dir, "model.ckpt-0.meta"))