Commit 443fe30

Merge pull request #9 from ericguizzo/master

Add Docker environment & web demo

2 parents: b4d8d01 + ee58477

3 files changed (+170, -2 lines)


README.md (6 additions, 2 deletions)

````diff
@@ -1,9 +1,10 @@
 # Wave-U-Net (Pytorch)
+<a href="https://replicate.ai/f90/wave-u-net-pytorch"><img src="https://img.shields.io/static/v1?label=Replicate&message=Demo and Docker Image&color=darkgreen" height=20></a>
 
 Improved version of the [Wave-U-Net](https://arxiv.org/abs/1806.03185) for audio source separation, implemented in Pytorch.
 
 Click [here](www.github.com/f90/Wave-U-Net) for the original Wave-U-Net implementation in Tensorflow.
-You can find more information about the model and results there as well.
+You can find more information about the model and results there as well.
 
 # Improvements
 
@@ -24,7 +25,9 @@ GPU strongly recommended to avoid very long training times.
 System requirements:
 * Linux-based OS
 * Python 3.6
+
 * [libsndfile](http://mega-nerd.com/libsndfile/)
+
 * [ffmpeg](https://www.ffmpeg.org/)
 * CUDA 10.1 for GPU usage
 
@@ -68,6 +71,7 @@ You can of course use your own datasets for training, but for this you would nee
 # Training the models
 
 To train a Wave-U-Net, the basic command to use is
+
 ```
 python3.6 train.py --dataset_dir /PATH/TO/MUSDB18HQ
 ```
@@ -86,7 +90,7 @@ After training, the model is evaluated on the MUSDB18HQ test set, and SDR/SIR/SA
 
 # <a name="test"></a> Test trained models on songs!
 
-We provide the default model in a pre-trained form as download so you can separate your own songs right away.
+We provide the default model in a pre-trained form as download so you can separate your own songs right away.
 
 ## Downloading our pretrained models
 
````
cog.yaml (20 additions, 0 deletions)

```yaml
build:
  python_version: "3.6"
  gpu: false
  python_packages:
    - future==0.18.2
    - numpy==1.19.5
    - librosa==0.8.1
    - soundfile==0.10.3.post1
    - musdb==0.4.0
    - museval==0.4.0
    - h5py==3.1.0
    - tqdm==4.62.1
    - torch==1.4.0
    - torchvision==0.5.0
    - tensorboard==2.6.0
    - sortedcontainers==2.4.0
  system_packages:
    - libsndfile-dev
    - ffmpeg
predict: "cog_predict.py:waveunetPredictor"
```

cog_predict.py (144 additions, 0 deletions)

```python
import os
import cog
import tempfile
import zipfile
from pathlib import Path
import argparse

import data.utils
import model.utils as model_utils
from test import predict_song
from model.waveunet import Waveunet


class waveunetPredictor(cog.Predictor):
    def setup(self):
        """Init Wave-U-Net model"""
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--instruments",
            type=str,
            nargs="+",
            default=["bass", "drums", "other", "vocals"],
            help='List of instruments to separate (default: "bass drums other vocals")',
        )
        parser.add_argument(
            "--cuda", action="store_true", help="Use CUDA (default: False)"
        )
        parser.add_argument(
            "--features",
            type=int,
            default=32,
            help="Number of feature channels per layer",
        )
        parser.add_argument(
            "--load_model",
            type=str,
            default="checkpoints/waveunet/model",
            help="Reload a previously trained model",
        )
        parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
        parser.add_argument(
            "--levels", type=int, default=6, help="Number of DS/US blocks"
        )
        parser.add_argument(
            "--depth", type=int, default=1, help="Number of convs per block"
        )
        parser.add_argument("--sr", type=int, default=44100, help="Sampling rate")
        parser.add_argument(
            "--channels", type=int, default=2, help="Number of input audio channels"
        )
        parser.add_argument(
            "--kernel_size",
            type=int,
            default=5,
            help="Filter width of kernels. Has to be an odd number",
        )
        parser.add_argument(
            "--output_size", type=float, default=2.0, help="Output duration"
        )
        parser.add_argument(
            "--strides", type=int, default=4, help="Strides in Waveunet"
        )
        parser.add_argument(
            "--conv_type",
            type=str,
            default="gn",
            help="Type of convolution (normal, BN-normalised, GN-normalised): normal/bn/gn",
        )
        parser.add_argument(
            "--res",
            type=str,
            default="fixed",
            help="Resampling strategy: fixed sinc-based lowpass filtering or learned conv layer: fixed/learned",
        )
        parser.add_argument(
            "--separate",
            type=int,
            default=1,
            help="Train separate model for each source (1) or only one (0)",
        )
        parser.add_argument(
            "--feature_growth",
            type=str,
            default="double",
            help="How the features in each layer should grow, either (add) the initial number of features each time, or multiply by 2 (double)",
        )
        # --input/--output flags are not needed under Cog, which supplies the
        # input path and collects the outputs, so they stay disabled here:
        """
        parser.add_argument('--input', type=str, default=str(input),
                            help="Path to input mixture to be separated")
        parser.add_argument('--output', type=str, default=out_path,
                            help="Output path (same folder as input path if not set)")
        """
        # Parse an empty argv so only the defaults above are used.
        args = parser.parse_args([])
        self.args = args

        # Per-level channel counts: "add" grows linearly (32, 64, 96, ...),
        # "double" grows geometrically (32, 64, 128, ...).
        num_features = (
            [args.features * i for i in range(1, args.levels + 1)]
            if args.feature_growth == "add"
            else [args.features * 2 ** i for i in range(0, args.levels)]
        )
        target_outputs = int(args.output_size * args.sr)
        self.model = Waveunet(
            args.channels,
            num_features,
            args.channels,
            args.instruments,
            kernel_size=args.kernel_size,
            target_output_size=target_outputs,
            depth=args.depth,
            strides=args.strides,
            conv_type=args.conv_type,
            res=args.res,
            separate=args.separate,
        )

        if args.cuda:
            self.model = model_utils.DataParallel(self.model)
            print("move model to gpu")
            self.model.cuda()

        print("Loading model from checkpoint " + str(args.load_model))
        state = model_utils.load_model(self.model, None, args.load_model, args.cuda)
        print("Step", state["step"])

    @cog.input("input", type=Path, help="audio mixture path")
    def predict(self, input):
        """Separate tracks from input mixture audio"""
        out_path = Path(tempfile.mkdtemp())
        zip_path = Path(tempfile.mkdtemp()) / "output.zip"

        preds = predict_song(self.args, input, self.model)

        # Write one wav per separated source, then bundle them into a zip.
        out_names = []
        for inst in preds.keys():
            temp_n = os.path.join(
                str(out_path), os.path.basename(str(input)) + "_" + inst + ".wav"
            )
            data.utils.write_wav(temp_n, preds[inst], self.args.sr)
            out_names.append(temp_n)

        with zipfile.ZipFile(str(zip_path), "w") as zf:
            for i in out_names:
                zf.write(str(i))

        return zip_path
```
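Outside of the Replicate runtime, the predictor can also be exercised directly for debugging. A minimal sketch, assuming a checkout of this repo with `cog` installed, the pretrained checkpoint at `checkpoints/waveunet/model`, and a placeholder mixture file `song.wav`:

```python
# Hypothetical local smoke test; not part of this commit.
from pathlib import Path

from cog_predict import waveunetPredictor

predictor = waveunetPredictor()
predictor.setup()  # builds the Wave-U-Net and loads the pretrained weights

# Returns the path to a zip archive containing one wav per separated source.
stems_zip = predictor.predict(Path("song.wav"))
print("Separated stems written to:", stems_zip)
```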
