
Commit 86a1e5c

Merge pull request #109 from TheButlah/master
Enabled multi-gpu training, buffers, grad clip in vocoder, saving optimizer state, and more fixes
2 parents 173ec17 + 5d4ead9 commit 86a1e5c

13 files changed: +204 / -64 lines

.gitignore

Lines changed: 2 additions & 1 deletion

@@ -1,5 +1,6 @@
-# PyCharm files
+# IDE files
 .idea
+.vscode
 
 # Mac files
 .DS_Store

gen_tacotron.py

Lines changed: 1 addition & 1 deletion

@@ -91,7 +91,7 @@
 
 simple_table([('WaveRNN', str(voc_k) + 'k'),
               ('Tacotron', str(tts_k) + 'k'),
-              ('r', tts_model.r.item()),
+              ('r', tts_model.r),
               ('Generation Mode', 'Batched' if batched else 'Unbatched'),
               ('Target Samples', target if batched else 'N/A'),
               ('Overlap Samples', overlap if batched else 'N/A')])

gen_wavernn.py

Lines changed: 3 additions & 3 deletions

@@ -7,7 +7,7 @@
 import argparse
 
 
-def gen_testset(model, test_set, samples, batched, target, overlap, save_path):
+def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path):
 
     k = model.get_step() // 1000
 
@@ -34,7 +34,7 @@ def gen_testset(model, test_set, samples, batched, target, overlap, save_path):
        _ = model.generate(m, save_str, batched, target, overlap, hp.mu_law)
 
 
-def gen_from_file(model, load_path, save_path, batched, target, overlap):
+def gen_from_file(model: WaveRNN, load_path, save_path, batched, target, overlap):
 
     k = model.get_step() // 1000
     file_name = load_path.split('/')[-1]
@@ -61,7 +61,7 @@ def gen_from_file(model, load_path, save_path, batched, target, overlap):
 parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
 parser.add_argument('--file', '-f', type=str, help='[string/path] for testing a wav outside dataset')
 parser.add_argument('--weights', '-w', type=str, help='[string/path] checkpoint file to load weights from')
-parser.add_argument('--gta', '-g', dest='use_gta', action='store_true', help='Generate from GTA testset')
+parser.add_argument('--gta', '-g', dest='gta', action='store_true', help='Generate from GTA testset')
 parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
 
 parser.set_defaults(batched=hp.voc_gen_batched)
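
The dest rename on the --gta flag is a small but real bug fix: argparse stores a flag's value under its dest name, so the attribute the rest of the script reads has to match. A toy illustration of the mismatch, assuming downstream code reads args.gta (which the rename implies):

import argparse

# Old behaviour: the value lands on args.use_gta, so a later read of
# args.gta raises AttributeError
parser = argparse.ArgumentParser()
parser.add_argument('--gta', '-g', dest='use_gta', action='store_true')
print(parser.parse_args(['--gta']).use_gta)   # True, but args.gta does not exist

# New behaviour: dest matches the attribute the script actually reads
parser = argparse.ArgumentParser()
parser.add_argument('--gta', '-g', dest='gta', action='store_true')
print(parser.parse_args(['--gta']).gta)       # True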

hparams.py

Lines changed: 1 addition & 0 deletions

@@ -52,6 +52,7 @@
 voc_test_samples = 50           # How many unseen samples to put aside for testing
 voc_pad = 2                     # this will pad the input so that the resnet can 'see' wider than input length
 voc_seq_len = hop_length * 5    # must be a multiple of hop_length
+voc_clip_grad_norm = 4          # set to None if no gradient clipping needed
 
 # Generating / Synthesizing
 voc_gen_batched = True          # very fast (realtime+) single utterance batched generation
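
voc_clip_grad_norm only takes effect once the vocoder training loop applies it between the backward pass and the optimizer step. A minimal sketch of that pattern using the standard torch.nn.utils.clip_grad_norm_ call; the step function below is illustrative, not the repo's actual training loop:

import torch
import hparams as hp  # the repo's hparams module

def vocoder_train_step(model, optimizer, criterion, x, mels, y):
    # One illustrative step: forward, backward, clip, update
    optimizer.zero_grad()
    y_hat = model(x, mels)
    loss = criterion(y_hat, y)
    loss.backward()
    if hp.voc_clip_grad_norm is not None:
        # Rescales all gradients in place so their combined L2 norm
        # does not exceed voc_clip_grad_norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm)
    optimizer.step()
    return loss.item()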

models/deepmind_version.py

Lines changed: 5 additions & 2 deletions

@@ -3,6 +3,7 @@
 import torch.nn.functional as F
 from utils.display import *
 from utils.dsp import *
+import numpy as np
 
 class WaveRNN(nn.Module):
     def __init__(self, hidden_size=896, quantisation=256):
@@ -167,7 +168,9 @@ def get_initial_hidden(self, batch_size=1):
         device = next(self.parameters()).device  # use same device as parameters
         return torch.zeros(batch_size, self.hidden_size, device=device)
 
-    def num_params(self):
+    def num_params(self, print_out=True):
         parameters = filter(lambda p: p.requires_grad, self.parameters())
         parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
-        print('Trainable Parameters: %.3f million' % parameters)
+        if print_out:
+            print('Trainable Parameters: %.3f million' % parameters)
+        return parameters

models/fatchord_version.py

Lines changed: 35 additions & 7 deletions

@@ -5,6 +5,7 @@
 from utils.display import *
 from utils.dsp import *
 import os
+import numpy as np
 
 
 class ResBlock(nn.Module):
@@ -100,24 +101,38 @@ def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors,
         else:
             raise RuntimeError("Unknown model mode value - ", self.mode)
 
+        # List of rnns to call `flatten_parameters()` on
+        self._to_flatten = []
+
         self.rnn_dims = rnn_dims
         self.aux_dims = res_out_dims // 4
         self.hop_length = hop_length
         self.sample_rate = sample_rate
 
         self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad)
         self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
+
         self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
         self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
+        self._to_flatten += [self.rnn1, self.rnn2]
+
         self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
         self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
         self.fc3 = nn.Linear(fc_dims, self.n_classes)
 
-        self.step = nn.Parameter(torch.zeros(1).long(), requires_grad=False)
+        self.register_buffer('step', torch.zeros(1, dtype=torch.long))
         self.num_params()
 
+        # Avoid fragmentation of RNN parameters and associated warning
+        self._flatten_parameters()
+
     def forward(self, x, mels):
         device = next(self.parameters()).device  # use same device as parameters
+
+        # Although we call `_flatten_parameters()` on init, when using DataParallel
+        # the model gets replicated, making it no longer guaranteed that the
+        # weights are contiguous in GPU memory. Hence, we must call it again.
+        self._flatten_parameters()
 
         self.step += 1
         bsize = x.size(0)
@@ -226,14 +241,14 @@ def generate(self, mels, save_path, batched, target, overlap, mu_law):
         output = output.cpu().numpy()
         output = output.astype(np.float64)
 
+        if mu_law:
+            output = decode_mu_law(output, self.n_classes, False)
+
         if batched:
             output = self.xfade_and_unfold(output, target, overlap)
         else:
             output = output[0]
 
-        if mu_law:
-            output = decode_mu_law(output, self.n_classes, False)
-
         # Fade-out at the end to avoid signal cutting out suddenly
         fade_out = np.linspace(1, 0, 20 * self.hop_length)
         output = output[:wave_len]
@@ -388,9 +403,12 @@ def xfade_and_unfold(self, y, target, overlap):
     def get_step(self):
         return self.step.data.item()
 
-    def checkpoint(self, path):
+    def checkpoint(self, path, optimizer):
+        # The optimizer is taken as an argument because checkpointing is
+        # only useful in the context of an existing training process.
         k_steps = self.get_step() // 1000
         self.save(f'{path}/checkpoint_{k_steps}k_steps.pyt')
+        torch.save(optimizer.state_dict(), f'{path}/checkpoint_{k_steps}k_steps_optim.pyt')
 
     def log(self, path, msg):
         with open(path, 'a') as f:
@@ -404,15 +422,25 @@ def restore(self, path):
         print(f'\nLoading Weights: "{path}"\n')
         self.load(path)
 
-    def load(self, path, device='cpu'):
-        # because PyTorch places on CPU by default, we follow those semantics by using CPU as default.
+    def load(self, path):
+        # Use the device of the model params as the location for the loaded state
+        device = next(self.parameters()).device
         self.load_state_dict(torch.load(path, map_location=device), strict=False)
 
     def save(self, path):
+        # No optimizer argument because saving a model should not include data
+        # only relevant to the training process - it should only contain properties
+        # of the model itself. Let the caller take care of saving the optimizer state.
         torch.save(self.state_dict(), path)
 
     def num_params(self, print_out=True):
         parameters = filter(lambda p: p.requires_grad, self.parameters())
         parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
         if print_out:
             print('Trainable Parameters: %.3fM' % parameters)
+        return parameters
+
+    def _flatten_parameters(self):
+        """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
+        to improve efficiency and avoid PyTorch yelling at us."""
+        [m.flatten_parameters() for m in self._to_flatten]
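
Replacing nn.Parameter(..., requires_grad=False) with register_buffer is the heart of the "buffers" change in this PR: a buffer is saved in state_dict(), moves with .to(device), and is replicated by DataParallel, yet never appears in model.parameters(), so the optimizer and gradient clipping leave it alone. A self-contained sketch of that behaviour (toy module, not the repo's code):

import torch
import torch.nn as nn

class StepCounter(nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent module state that is not a learnable parameter
        self.register_buffer('step', torch.zeros(1, dtype=torch.long))

    def forward(self, x):
        self.step += 1   # safe: no gradient tracking on buffers
        return x

m = StepCounter()
m(torch.zeros(1))
print(list(m.parameters()))   # [] -> the optimizer never sees `step`
print(m.state_dict())         # contains 'step' -> saved with checkpoints

With the old nn.Parameter approach the counter was handed to the optimizer along with the real weights; as a buffer it is pure state.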

models/tacotron.py

Lines changed: 44 additions & 20 deletions

@@ -54,6 +54,9 @@ class CBHG(nn.Module):
     def __init__(self, K, in_channels, channels, proj_channels, num_highways):
         super().__init__()
 
+        # List of all rnns to call `flatten_parameters()` on
+        self._to_flatten = []
+
         self.bank_kernels = [i for i in range(1, K + 1)]
         self.conv1d_bank = nn.ModuleList()
         for k in self.bank_kernels:
@@ -78,8 +81,16 @@ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
             self.highways.append(hn)
 
         self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
+        self._to_flatten.append(self.rnn)
+
+        # Avoid fragmentation of RNN parameters and associated warning
+        self._flatten_parameters()
 
     def forward(self, x):
+        # Although we call `_flatten_parameters()` on init, when using DataParallel
+        # the model gets replicated, making it no longer guaranteed that the
+        # weights are contiguous in GPU memory. Hence, we must call it again.
+        self._flatten_parameters()
 
         # Save these for later
         residual = x
@@ -114,6 +125,10 @@ def forward(self, x):
         x, _ = self.rnn(x)
         return x
 
+    def _flatten_parameters(self):
+        """Calls `flatten_parameters` on all the rnns used by the CBHG. Used
+        to improve efficiency and avoid PyTorch yelling at us."""
+        [m.flatten_parameters() for m in self._to_flatten]
 
 class PreNet(nn.Module):
     def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
@@ -189,10 +204,12 @@ def forward(self, encoder_seq_proj, query, t):
 
 
 class Decoder(nn.Module):
+    # Class variable because its value doesn't change between instances,
+    # yet it ought to be scoped by the class because it's a property of a Decoder
+    max_r = 20
     def __init__(self, n_mels, decoder_dims, lstm_dims):
         super().__init__()
-        self.max_r = 20
-        self.r = None
+        self.register_buffer('r', torch.tensor(1, dtype=torch.int))
         self.generating = False
         self.n_mels = n_mels
         self.prenet = PreNet(n_mels)
@@ -204,8 +221,7 @@ def __init__(self, n_mels, decoder_dims, lstm_dims):
         self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
 
     def zoneout(self, prev, current, p=0.1):
-        device = prev.device
-        assert prev.device == current.device
+        device = next(self.parameters()).device  # Use same device as parameters
         mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
         return prev * mask + current * (1 - mask)
 
@@ -279,17 +295,15 @@ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, ff
         self.init_model()
         self.num_params()
 
-        # Unfortunately I have to put these settings into params in order to save
-        # if anyone knows a better way of doing this please open an issue in the repo
-        self.step = nn.Parameter(torch.zeros(1).long(), requires_grad=False)
-        self.r = nn.Parameter(torch.tensor(0).long(), requires_grad=False)
-
-    def set_r(self, r):
-        self.r.data = torch.tensor(r)
-        self.decoder.r = r
+        self.register_buffer('step', torch.zeros(1, dtype=torch.long))
+
+    @property
+    def r(self):
+        return self.decoder.r.item()
 
-    def get_r(self):
-        return self.r.item()
+    @r.setter
+    def r(self, value):
+        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
 
     def forward(self, x, m, generate_gta=False):
         device = next(self.parameters()).device  # use same device as parameters
@@ -351,7 +365,7 @@ def forward(self, x, m, generate_gta=False):
 
         # For easy visualisation
         attn_scores = torch.cat(attn_scores, 1)
-        attn_scores = attn_scores.cpu().data.numpy()
+        # attn_scores = attn_scores.cpu().data.numpy()
 
         return mel_outputs, linear, attn_scores
 
@@ -430,11 +444,17 @@ def get_step(self):
         return self.step.data.item()
 
     def reset_step(self):
-        self.step = nn.Parameter(torch.zeros(1).long(), requires_grad=False)
+        assert self.step is not None
+        device = next(self.parameters()).device  # use same device as parameters
+        # Assignment to parameters or buffers is overloaded; it updates the internal dict entry
+        self.step = torch.zeros(1, dtype=torch.long, device=device)
 
-    def checkpoint(self, path):
+    def checkpoint(self, path, optimizer):
+        # The optimizer is taken as an argument because checkpointing is
+        # only useful in the context of an existing training process.
         k_steps = self.get_step() // 1000
         self.save(f'{path}/checkpoint_{k_steps}k_steps.pyt')
+        torch.save(optimizer.state_dict(), f'{path}/checkpoint_{k_steps}k_steps_optim.pyt')
 
     def log(self, path, msg):
         with open(path, 'a') as f:
@@ -447,17 +467,21 @@ def restore(self, path):
         else:
             print(f'\nLoading Weights: "{path}"\n')
         self.load(path)
-        self.decoder.r = self.r.item()
 
-    def load(self, path, device='cpu'):
-        # because PyTorch places on CPU by default, we follow those semantics by using CPU as default.
+    def load(self, path):
+        # Use the device of the model params as the location for the loaded state
+        device = next(self.parameters()).device
         self.load_state_dict(torch.load(path, map_location=device), strict=False)
 
     def save(self, path):
+        # No optimizer argument because saving a model should not include data
+        # only relevant to the training process - it should only contain properties
+        # of the model itself. Let the caller take care of saving the optimizer state.
         torch.save(self.state_dict(), path)
 
     def num_params(self, print_out=True):
         parameters = filter(lambda p: p.requires_grad, self.parameters())
         parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
         if print_out:
             print('Trainable Parameters: %.3fM' % parameters)
+        return parameters
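
Replacing set_r()/get_r() with a property is why gen_tacotron.py above can now print tts_model.r directly: the getter unwraps the decoder's buffer to a plain int, and the setter writes through to the decoder, where nn.Module's overloaded attribute assignment updates the registered buffer in place. A stripped-down sketch of the pattern (toy classes, not the full models):

import torch
import torch.nn as nn

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        # 'r' (mel frames generated per decoder step) lives on the decoder as a
        # buffer, so it is saved/restored with checkpoints and moves with .to()
        self.register_buffer('r', torch.tensor(1, dtype=torch.int))

class Tacotron(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = Decoder()

    @property
    def r(self):
        return self.decoder.r.item()   # plain int for callers

    @r.setter
    def r(self, value):
        # Assigning a tensor to a registered buffer updates the buffer dict
        self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)

model = Tacotron()
model.r = 7
print(model.r)   # 7 - no .item() needed at call sites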

preprocess.py

Lines changed: 19 additions & 7 deletions

@@ -10,9 +10,17 @@
 from utils.files import get_files
 
 
+# Helper functions for argument types
+def valid_n_workers(num):
+    n = int(num)
+    if n < 1:
+        raise argparse.ArgumentTypeError('%r must be an integer greater than 0' % num)
+    return n
+
 parser = argparse.ArgumentParser(description='Preprocessing for WaveRNN and Tacotron')
 parser.add_argument('--path', '-p', default=hp.wav_path, help='directly point to dataset path (overrides hparams.wav_path)')
-parser.add_argument('--extension', '-e', default='.wav', help='file extension to search for in dataset folder')
+parser.add_argument('--extension', '-e', metavar='EXT', default='.wav', help='file extension to search for in dataset folder')
+parser.add_argument('--num_workers', '-w', metavar='N', type=valid_n_workers, default=cpu_count()-1, help='The number of worker threads to use for preprocessing')
 args = parser.parse_args()
 
 extension = args.extension
@@ -60,13 +68,17 @@ def process_wav(path):
 with open(f'{paths.data}text_dict.pkl', 'wb') as f:
     pickle.dump(text_dict, f)
 
-simple_table([('Sample Rate', hp.sample_rate),
-              ('Bit Depth', hp.bits),
-              ('Mu Law', hp.mu_law),
-              ('Hop Length', hp.hop_length),
-              ('CPU Count', cpu_count())])
+n_workers = max(1, args.num_workers)
+
+simple_table([
+    ('Sample Rate', hp.sample_rate),
+    ('Bit Depth', hp.bits),
+    ('Mu Law', hp.mu_law),
+    ('Hop Length', hp.hop_length),
+    ('CPU Usage', f'{n_workers}/{cpu_count()}')
+])
 
-pool = Pool(processes=cpu_count())
+pool = Pool(processes=n_workers)
 dataset = []
 
 for i, (id, length) in enumerate(pool.imap_unordered(process_wav, wav_files), 1):
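
A custom argparse type like valid_n_workers turns a bad value into a clean usage error at parse time, instead of a crash deep inside multiprocessing. A minimal self-contained demonstration of the same pattern:

import argparse

def valid_n_workers(num):
    n = int(num)
    if n < 1:
        raise argparse.ArgumentTypeError('%r must be an integer greater than 0' % num)
    return n

parser = argparse.ArgumentParser()
parser.add_argument('--num_workers', '-w', metavar='N', type=valid_n_workers, default=1)

print(parser.parse_args(['-w', '4']).num_workers)   # 4
# parser.parse_args(['-w', '0']) exits with:
# error: argument --num_workers/-w: '0' must be an integer greater than 0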

quick_start.py

Lines changed: 1 addition & 1 deletion

@@ -119,6 +119,6 @@
 m = torch.tensor(m).unsqueeze(0)
 m = (m + 4) / 8
 
-voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law)
+voc_model.generate(m, save_path, batched, target, overlap, hp.mu_law)
 
 print('\n\nDone.\n')
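
This fix passes the script's own target and overlap values instead of re-reading hp.voc_target and hp.voc_overlap, so any CLI overrides actually reach generation. A sketch of how those names are presumably resolved earlier in quick_start.py (an assumption about the surrounding script, not shown in this diff):

import argparse
import hparams as hp  # the repo's hparams module

# Hypothetical resolution of `target` and `overlap`: CLI flags fall back
# to the hparams defaults when not given
parser = argparse.ArgumentParser()
parser.add_argument('--target', '-t', type=int, help='[int] number of samples in batched mode')
parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
args = parser.parse_args()

target = args.target if args.target is not None else hp.voc_target
overlap = args.overlap if args.overlap is not None else hp.voc_overlap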
