Commit 4439a45: Initial commit (0 parents)

8 files changed: +624, -0 lines

README.md (+44)
# chainer-WaveGlow

A Chainer implementation of WaveGlow (https://nv-adlr.github.io/WaveGlow).

# Results

I'll upload samples after I finish the training. I'm getting audible results now. Please wait!

# Requirements

I trained and generated with

- python (3.5.2)
- chainer (5.0.0)
- librosa (0.6.2)
- matplotlib (3.0.1)

# Usage

## download dataset

You can download the VCTK Corpus (English, multi-speaker) or LJ Speech (English, single-speaker) very easily via [my repository](https://github.com/dhgrs/download_dataset).

## set parameters

I'll write details later.

## training

You can use the same command in each directory.

```
(without GPU)
python train.py

(with GPU #n)
python train.py -g n
```

You can resume from a snapshot and restart training like below.

```
python train.py -r snapshot_iter_100000
```

The other arguments `-f` and `-p` are parameters for multiprocessing in preprocessing: `-f` sets the number of prefetches and `-p` sets the number of processes. A combined invocation is sketched below.
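For example, a hypothetical combined invocation (the values here are illustrative, not tuned recommendations):

```
python train.py -g 0 -f 8 -p 4
```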
## generating

```
python generate.py -i <input file> -o <output file> -m <trained model>
```

If you don't set `-o`, the default file name `Result.wav` is used. If you don't set `-s`, the speaker stays the same as that of the input file, as inferred from the file path.
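For instance, a hypothetical run reusing the snapshot name from the training example above (`sample.wav` and `out.wav` are placeholder file names):

```
python generate.py -i sample.wav -o out.wav -m snapshot_iter_100000
```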

WaveGlow/__init__.py (+3)

```python
from .model import Glow
from .modules import Invertible1x1Convolution
from .modules import AffineCouplingLayer
```

WaveGlow/model.py (+109)

```python
import chainer
import numpy

from .modules import Flow


def _squeeze(x, squeeze_factor):
    # Trade temporal resolution for channels:
    # (batch, channel, length) -> (batch, channel * factor, length // factor)
    batchsize, channel, length = x.shape
    x = x.reshape(
        (batchsize, channel, length // squeeze_factor, squeeze_factor))
    x = x.transpose((0, 1, 3, 2))
    x = x.reshape(
        (batchsize, channel * squeeze_factor, length // squeeze_factor))
    return x


def _unsqueeze(x, squeeze_factor):
    # Inverse of _squeeze:
    # (batch, channel, length) -> (batch, channel // factor, length * factor)
    batchsize, channel, length = x.shape
    x = x.reshape(
        (batchsize, channel // squeeze_factor, squeeze_factor, length))
    x = x.transpose((0, 1, 3, 2))
    x = x.reshape(
        (batchsize, channel // squeeze_factor, length * squeeze_factor))
    return x


class Glow(chainer.Chain):
    def __init__(
            self, hop_length=256, n_mels=80, input_channel=1,
            squeeze_factor=8, n_flows=12, n_layers=8,
            wn_channel=512, early_every=4, early_size=2, var=0.5):
        super(Glow, self).__init__()
        self.input_channel = input_channel
        self.squeeze_factor = squeeze_factor
        self.n_flows = n_flows
        self.early_every = early_every
        self.early_size = early_size
        self.ln_var = float(numpy.log(var))
        flows = chainer.ChainList()
        for i in range(n_flows):
            # Every `early_every` flows, `early_size` channels are emitted
            # early, so later flows operate on fewer channels.
            flows.add_link(Flow(
                input_channel * squeeze_factor -
                early_size * (i // early_every),
                n_mels * squeeze_factor, n_layers, wn_channel))
        with self.init_scope():
            # Upsample the mel-spectrogram condition to the sample rate.
            self.encoder = chainer.links.Deconvolution1D(
                n_mels, n_mels, hop_length * 4, hop_length,
                pad=hop_length * 3 // 2)
            self.flows = flows

    def __call__(self, x, condition):
        _, gaussian_nll, sum_log_s, sum_log_det_W = self._forward(x, condition)
        # Change-of-variables negative log-likelihood: NLL of the latents
        # minus the log-determinants of the couplings and 1x1 convolutions.
        loss = gaussian_nll - sum_log_s - sum_log_det_W
        # Constant offset log(2 ** 16), presumably accounting for the
        # scaling of 16-bit audio.
        loss += float(numpy.log(2 ** 16))
        chainer.reporter.report(
            {
                'gaussian_nll': gaussian_nll, 'log_s': sum_log_s,
                'log_det_W': sum_log_det_W, 'loss': loss}, self)
        return loss

    def _forward(self, x, condition):
        condition = self.encoder(condition)
        x = _squeeze(x, self.squeeze_factor)
        condition = _squeeze(condition, self.squeeze_factor)
        sum_log_s = 0
        sum_log_det_W = 0
        outputs = []
        for i, flow in enumerate(self.flows.children()):
            x, log_s, log_det_W = flow(x, condition)
            if (i + 1) % self.early_every == 0:
                # Early output: split off `early_size` channels of latents.
                output, x = x[:, :self.early_size], x[:, self.early_size:]
                outputs.append(output)
            sum_log_s += log_s
            sum_log_det_W += log_det_W
        outputs.append(x)
        z = chainer.functions.concat(outputs, axis=1)
        gaussian_nll = chainer.functions.gaussian_nll(
            z,
            mean=self.xp.zeros_like(z, dtype=self.xp.float32),
            ln_var=self.ln_var * self.xp.ones_like(z, dtype=self.xp.float32)
        )
        # Normalize every term per element.
        gaussian_nll /= numpy.prod(z.shape)
        sum_log_s /= numpy.prod(z.shape)
        sum_log_det_W /= numpy.prod(z.shape)
        return z, gaussian_nll, sum_log_s, sum_log_det_W

    def _reverse(self, z, condition, var=0):
        condition = self.encoder(condition)
        condition = _squeeze(condition, self.squeeze_factor)
        batchsize, _, length = condition.shape
        if z is None:
            # Note: numpy.random.normal takes a standard deviation as its
            # second argument; `var` is passed to it directly here.
            z = self.xp.random.normal(
                0, var,
                (batchsize, self.input_channel * self.squeeze_factor, length))
            z = z.astype(self.xp.float32)
        _, channel, _ = z.shape
        # Channels that were never emitted early pass through every flow.
        start_channel = channel - \
            self.early_size * (self.n_flows // self.early_every)
        x, z = z[:, -start_channel:], z[:, :-start_channel]
        for i, flow in enumerate(reversed(list(self.flows.children()))):
            if (self.n_flows - i) % self.early_every == 0:
                # Re-attach the latents that were split off early.
                x, z = chainer.functions.concat((
                    z[:, -self.early_size:], x)), z[:, :-self.early_size]
            x = flow.reverse(x, condition)
        x = _unsqueeze(x, self.squeeze_factor)
        return x

    def generate(self, condition, var=0.6 ** 2):
        return self._reverse(None, condition, var)
```
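As a quick illustration of the squeeze round trip, here is a minimal sketch (assuming the package is importable as `WaveGlow`; the shapes are illustrative):

```python
import numpy
from WaveGlow.model import _squeeze, _unsqueeze

x = numpy.arange(16, dtype=numpy.float32).reshape((1, 1, 16))
y = _squeeze(x, 8)     # length folded into channels: (1, 1, 16) -> (1, 8, 2)
print(y.shape)         # (1, 8, 2)
x2 = _unsqueeze(y, 8)  # inverse mapping restores the original tensor
print(numpy.array_equal(x, x2))  # True
```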

WaveGlow/modules.py (+152)

```python
import chainer
import chainer.functions as F
import chainer.links as L


def _normalize(W):
    # Split a weight into a scalar norm g and a direction v, with W = g * v.
    xp = chainer.cuda.get_array_module(W)
    g = xp.sqrt(xp.sum(W ** 2)).reshape((1,))
    v = W / g
    return g, v


def weight_norm(link):
    # Reparameterize link.W as g * v (weight normalization), recomputing W
    # in a forward pre-hook on every call.
    assert hasattr(link, 'W')

    def _W(self):
        return self.v * self.g

    def _remove(self):
        W = _W(self)
        del self.g
        del self.v
        del self.W
        with self.init_scope():
            self.W = chainer.Parameter(W)

    def _replace(args):
        W = _W(args.link)
        g, v = _normalize(_W(args.link).array)
        args.link.g.array[...] = g
        args.link.v.array[...] = v
        args.link.W = W

    g, v = _normalize(link.W.array)
    del link.W
    with link.init_scope():
        link.g = chainer.Parameter(g)
        link.v = chainer.Parameter(v)

    link.remove = _remove

    hook = chainer.LinkHook()
    hook.forward_preprocess = _replace
    link.add_hook(hook)
    return link


class Invertible1x1Convolution(chainer.link.Link):
    def __init__(self, channel):
        super(Invertible1x1Convolution, self).__init__()
        xp = self.xp

        # Initialize with a random orthogonal matrix (via QR decomposition),
        # so the initial transform is volume-preserving.
        W = xp.linalg.qr(xp.random.normal(
            0, 1, (channel, channel)))[0].astype(xp.float32)
        W = W.reshape(W.shape + (1,))

        with self.init_scope():
            self.W = chainer.Parameter(W)

    @property
    def invW(self):
        return F.expand_dims(F.inv(self.W[..., 0]), axis=2)

    def __call__(self, x):
        # Returns the convolved input and the log-determinant term
        # batchsize * length * log|det W|.
        return F.convolution_1d(x, self.W), \
            x.shape[0] * x.shape[-1] * F.log(F.absolute(F.det(self.W[..., 0])))

    def reverse(self, x):
        return F.convolution_1d(x, self.invW)


class WaveNet(chainer.Chain):
    def __init__(self, out_channel, n_condition, n_layers, n_channel):
        super(WaveNet, self).__init__()
        dilated_convs = chainer.ChainList()
        residual_convs = chainer.ChainList()
        skip_convs = chainer.ChainList()
        condition_convs = chainer.ChainList()
        for i in range(n_layers):
            dilated_convs.add_link(weight_norm(
                L.Convolution1D(
                    n_channel, 2 * n_channel, 3, pad=2 ** i, dilate=2 ** i)))
            residual_convs.add_link(weight_norm(
                L.Convolution1D(n_channel, n_channel, 1)))
            skip_convs.add_link(weight_norm(
                L.Convolution1D(n_channel, n_channel, 1)))
            condition_convs.add_link(weight_norm(
                L.Convolution1D(n_condition, 2 * n_channel, 1)))
        with self.init_scope():
            self.input_conv = weight_norm(
                L.Convolution1D(out_channel // 2, n_channel, 1))
            self.dilated_convs = dilated_convs
            self.residual_convs = residual_convs
            self.skip_convs = skip_convs
            self.condition_convs = condition_convs
            # Zero-initialized, so each coupling layer starts as the identity.
            self.output_conv = L.Convolution1D(
                n_channel, out_channel, 1,
                initialW=chainer.initializers.Zero())

    def __call__(self, x, condition):
        x = self.input_conv(x)
        skip_connection = 0
        for dilated, residual, skip, condition_conv in zip(
                self.dilated_convs, self.residual_convs, self.skip_convs,
                self.condition_convs):
            # Gated activation unit, conditioned on the mel spectrogram.
            z = dilated(x) + condition_conv(condition)
            z_tanh, z_sigmoid = F.split_axis(z, 2, axis=1)
            z = F.tanh(z_tanh) * F.sigmoid(z_sigmoid)
            x = residual(z)
            skip_connection += skip(z)
        y = self.output_conv(skip_connection)
        log_s, t = F.split_axis(y, 2, axis=1)
        return log_s, t


class AffineCouplingLayer(chainer.Chain):
    def __init__(self, *args, **kwargs):
        super(AffineCouplingLayer, self).__init__()
        with self.init_scope():
            self.encoder = WaveNet(*args, **kwargs)

    def __call__(self, x, condition):
        # Half the channels pass through unchanged and parameterize an
        # affine transform of the other half.
        x_a, x_b = F.split_axis(x, 2, axis=1)
        log_s, t = self.encoder(x_a, condition)
        x_b = F.exp(log_s) * (x_b + t)
        return F.concat((x_a, x_b), axis=1), F.sum(log_s)

    def reverse(self, z, condition):
        x_a, x_b = F.split_axis(z, 2, axis=1)
        log_s, t = self.encoder(x_a, condition)
        # Exact inverse of the forward affine transform.
        x_b = x_b * F.exp(-log_s) - t
        return F.concat((x_a, x_b), axis=1)


class Flow(chainer.Chain):
    def __init__(self, channel, n_condition, n_layers, wn_channel):
        super(Flow, self).__init__()
        with self.init_scope():
            self.invertible1x1convolution = Invertible1x1Convolution(
                channel)
            self.affinecouplinglayer = AffineCouplingLayer(
                channel, n_condition, n_layers, wn_channel)

    def __call__(self, x, condition):
        x, log_det_W = self.invertible1x1convolution(x)
        z, log_s = self.affinecouplinglayer(x, condition)
        return z, log_s, log_det_W

    def reverse(self, z, condition):
        z = self.affinecouplinglayer.reverse(z, condition)
        x = self.invertible1x1convolution.reverse(z)
        return x
```
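Since each `Flow` is built from two invertible pieces, a round trip through `__call__` and `reverse` should reconstruct the input up to floating-point error. A minimal sketch of that check, assuming the package is importable as `WaveGlow` (the hyperparameters below are illustrative, much smaller than the training defaults):

```python
import numpy
import chainer
from WaveGlow.modules import Flow

# Illustrative sizes: 8 channels, 4 condition channels,
# 2 WaveNet layers, 16 WaveNet channels.
flow = Flow(8, 4, 2, 16)
x = numpy.random.normal(0, 1, (1, 8, 32)).astype(numpy.float32)
condition = numpy.random.normal(0, 1, (1, 4, 32)).astype(numpy.float32)

with chainer.using_config('enable_backprop', False):
    z, log_s, log_det_W = flow(x, condition)
    x_rec = flow.reverse(z, condition)

print(numpy.allclose(x, x_rec.array, atol=1e-4))  # expect True
```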

generate.py (+63)

```python
import argparse

import numpy
import librosa
import chainer

from WaveGlow import Glow
from utils import Preprocess
import params

parser = argparse.ArgumentParser()
parser.add_argument('--input', '-i', help='Input file')
parser.add_argument('--output', '-o', default='Result.wav', help='output file')
parser.add_argument('--model', '-m', help='Snapshot of trained model')
parser.add_argument('--var', '-v', type=float, default=0.6 ** 2,
                    help='Variance of Gaussian distribution')
parser.add_argument('--gpu', '-g', type=int, default=-1,
                    help='GPU ID (negative value indicates CPU)')
args = parser.parse_args()
if args.gpu != -1:
    chainer.cuda.set_max_workspace_size(2 * 512 * 1024 * 1024)
    chainer.global_config.autotune = True

# set data
path = args.input

# preprocess
n = 1  # batchsize; now supports only 1
inputs = Preprocess(
    params.sr, params.n_fft, params.hop_length, params.n_mels, params.fmin,
    params.fmax, params.top_db, None)(path)

_, condition = inputs
condition = numpy.expand_dims(condition, axis=0)

# make model
glow = Glow(
    params.hop_length, params.n_mels, 1,
    params.squeeze_factor, params.n_flows, params.n_layers,
    params.wn_channel, params.early_every, params.early_size,
    params.var)

# load trained parameter
chainer.serializers.load_npz(args.model, glow, 'updater/model:main/')

if args.gpu >= 0:
    use_gpu = True
    chainer.cuda.get_device_from_id(args.gpu).use()
else:
    use_gpu = False

# forward
if use_gpu:
    condition = chainer.cuda.to_gpu(condition, device=args.gpu)
    glow.to_gpu(device=args.gpu)
condition = chainer.Variable(condition)

with chainer.using_config('enable_backprop', False):
    output = glow.generate(condition)

output = chainer.cuda.to_cpu(output.array)
output = numpy.squeeze(output)
# write the waveform (librosa 0.6 API)
librosa.output.write_wav(args.output, output, params.sr)
```
