
Commit a46d4aa

Merge pull request fatchord#126 from TheButlah/master
Add Multi-GPU training, Griffin-Lim vocoder, safely restore checkpoints, pathlib, multiple hparams files
2 parents: 6709c7d + e1c8421

20 files changed: 942 additions & 503 deletions

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
-# PyCharm files
+# IDE files
 .idea
+.vscode
 
 # Mac files
 .DS_Store

README.md

Lines changed: 7 additions & 11 deletions
@@ -8,7 +8,7 @@ Pytorch implementation of Deepmind's WaveRNN model from [Efficient Neural Audio
 
 # Installation
 
-Ensure you have:
+Ensure you have:
 
 * Python >= 3.6
 * [Pytorch 1 with CUDA](https://pytorch.org/)
@@ -37,20 +37,20 @@ You can also use that script to generate custom tts sentences and/or use '-u' to
 
 Download the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) Dataset.
 
-Edit **hparams.py**, point **wav_path** to your dataset and run:
+Edit **hparams.py**, point **wav_path** to your dataset and run:
 
 > python preprocess.py
 
 or use preprocess.py --path to point directly to the dataset
 ___
 
-Here's my recommendation on what order to run things:
+Here's my recommendation on what order to run things:
 
 1 - Train Tacotron with:
 
 > python train_tacotron.py
 
-2 - You can leave that finish training or at any point you can use:
+2 - You can leave that finish training or at any point you can use:
 
 > python train_tacotron.py --force_gta
 
@@ -64,11 +64,11 @@ NB: You can always just run train_wavernn.py without --gta if you're not interes
 
 4 - Generate Sentences with both models using:
 
-> python gen_tacotron.py
+> python gen_tacotron.py wavernn
 
 this will generate default sentences. If you want generate custom sentences you can use
 
-> python gen_tacotron.py --input_text "this is whatever you want it to be"
+> python gen_tacotron.py --input_text "this is whatever you want it to be" wavernn
 
 And finally, you can always use --help on any of those scripts to see what options are available :)
 
@@ -84,7 +84,7 @@ Currently there are two pretrained models available in the /pretrained/ folder':
 
 Both are trained on LJSpeech
 
-* WaveRNN (Mixture of Logistics output) trained to 800k steps
+* WaveRNN (Mixture of Logistics output) trained to 800k steps
 * Tacotron trained to 180k steps
 
 ____
@@ -100,7 +100,3 @@ ____
 * [https://github.com/keithito/tacotron](https://github.com/keithito/tacotron)
 * [https://github.com/r9y9/wavenet_vocoder](https://github.com/r9y9/wavenet_vocoder)
 * Special thanks to github users [G-Wang](https://github.com/G-Wang), [geneing](https://github.com/geneing) & [erogol](https://github.com/erogol)
-
-
-
-
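For reference, the `wavernn` argument added to the README commands above is a positional subcommand (see the gen_tacotron.py diff below), and the same diff adds a `griffinlim` subcommand with a `gl` alias and an `--iters` option (default 32). So an equivalent Griffin-Lim run, under the flags this PR introduces, would look like:

> python gen_tacotron.py --input_text "this is whatever you want it to be" griffinlim --iters 32

Note that subcommand-specific flags such as --batched or --iters must come after the subcommand name, per standard argparse subparser behaviour.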

gen_tacotron.py

Lines changed: 102 additions & 50 deletions
@@ -1,39 +1,68 @@
 import torch
 from models.fatchord_version import WaveRNN
-import hparams as hp
+from utils import hparams as hp
 from utils.text.symbols import symbols
 from utils.paths import Paths
 from models.tacotron import Tacotron
 import argparse
 from utils.text import text_to_sequence
 from utils.display import save_attention, simple_table
+from utils.dsp import reconstruct_waveform, save_wav
+import numpy as np
 
 if __name__ == "__main__":
 
     # Parse Arguments
     parser = argparse.ArgumentParser(description='TTS Generator')
     parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
-    parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
-    parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
-    parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
-    parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
-    parser.add_argument('--weights_path', '-w', type=str, help='[string/path] Load in different Tacotron Weights')
+    parser.add_argument('--tts_weights', type=str, help='[string/path] Load in different Tacotron weights')
     parser.add_argument('--save_attention', '-a', dest='save_attn', action='store_true', help='Save Attention Plots')
     parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
-    parser.set_defaults(batched=hp.voc_gen_batched)
-    parser.set_defaults(target=hp.voc_target)
-    parser.set_defaults(overlap=hp.voc_overlap)
+    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
+
     parser.set_defaults(input_text=None)
     parser.set_defaults(weights_path=None)
-    parser.set_defaults(save_attention=False)
+
+    # name of subcommand goes to args.vocoder
+    subparsers = parser.add_subparsers(required=True, dest='vocoder')
+
+    wr_parser = subparsers.add_parser('wavernn', aliases=['wr'])
+    wr_parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
+    wr_parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
+    wr_parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
+    wr_parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
+    wr_parser.add_argument('--voc_weights', type=str, help='[string/path] Load in different WaveRNN weights')
+    wr_parser.set_defaults(batched=None)
+
+    gl_parser = subparsers.add_parser('griffinlim', aliases=['gl'])
+    gl_parser.add_argument('--iters', type=int, default=32, help='[int] number of griffinlim iterations')
+
     args = parser.parse_args()
 
-    batched = args.batched
-    target = args.target
-    overlap = args.overlap
+    if args.vocoder in ['griffinlim', 'gl']:
+        args.vocoder = 'griffinlim'
+    elif args.vocoder in ['wavernn', 'wr']:
+        args.vocoder = 'wavernn'
+    else:
+        raise argparse.ArgumentError('Must provide a valid vocoder type!')
+
+    hp.configure(args.hp_file)  # Load hparams from file
+    # set defaults for any arguments that depend on hparams
+    if args.vocoder == 'wavernn':
+        if args.target is None:
+            args.target = hp.voc_target
+        if args.overlap is None:
+            args.overlap = hp.voc_overlap
+        if args.batched is None:
+            args.batched = hp.voc_gen_batched
+
+    batched = args.batched
+    target = args.target
+    overlap = args.overlap
+
     input_text = args.input_text
-    weights_path = args.weights_path
-    save_attn = args.save_attention
+    tts_weights = args.tts_weights
+    save_attn = args.save_attn
 
     paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
 
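The hunk above defers hparams-dependent defaults: `--target`, `--overlap` and `--batched` default to None and are only filled from `hp` after `hp.configure(args.hp_file)` has loaded the chosen hyperparameters file. A minimal self-contained sketch of that pattern (the `Config` class below is a hypothetical stand-in for `utils.hparams`, not the repo's code):

import argparse

class Config:
    """Hypothetical stand-in for utils.hparams: values exist only after configure()."""
    def configure(self, path):
        # The real module loads the chosen hparams file; hard-coded here for illustration.
        self.voc_target = 11_000

hp = Config()

parser = argparse.ArgumentParser()
parser.add_argument('--hp_file', default='hparams.py')
parser.add_argument('--target', type=int, default=None)  # None means "not given"
args = parser.parse_args([])

hp.configure(args.hp_file)   # load hyperparameters first...
if args.target is None:      # ...then resolve the deferred default
    args.target = hp.voc_target
print(args.target)           # -> 11000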

@@ -43,23 +72,24 @@
         device = torch.device('cpu')
     print('Using device:', device)
 
-    print('\nInitialising WaveRNN Model...\n')
-
-    # Instantiate WaveRNN Model
-    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
-                        fc_dims=hp.voc_fc_dims,
-                        bits=hp.bits,
-                        pad=hp.voc_pad,
-                        upsample_factors=hp.voc_upsample_factors,
-                        feat_dims=hp.num_mels,
-                        compute_dims=hp.voc_compute_dims,
-                        res_out_dims=hp.voc_res_out_dims,
-                        res_blocks=hp.voc_res_blocks,
-                        hop_length=hp.hop_length,
-                        sample_rate=hp.sample_rate,
-                        mode=hp.voc_mode).to(device)
-
-    voc_model.restore(paths.voc_latest_weights)
+    if args.vocoder == 'wavernn':
+        print('\nInitialising WaveRNN Model...\n')
+        # Instantiate WaveRNN Model
+        voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
+                            fc_dims=hp.voc_fc_dims,
+                            bits=hp.bits,
+                            pad=hp.voc_pad,
+                            upsample_factors=hp.voc_upsample_factors,
+                            feat_dims=hp.num_mels,
+                            compute_dims=hp.voc_compute_dims,
+                            res_out_dims=hp.voc_res_out_dims,
+                            res_blocks=hp.voc_res_blocks,
+                            hop_length=hp.hop_length,
+                            sample_rate=hp.sample_rate,
+                            mode=hp.voc_mode).to(device)
+
+        voc_load_path = args.voc_weights if args.voc_weights else paths.voc_latest_weights
+        voc_model.load(voc_load_path)
 
     print('\nInitialising Tacotron Model...\n')
 
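`voc_model.restore(...)` becomes `voc_model.load(...)` here; per the PR title this is part of "safely restore checkpoints". The model's `load` method itself is not in this diff, so the following is only an assumed sketch of a common safe-loading shape in PyTorch (the function name `safe_load` is hypothetical):

import torch

def safe_load(model, path, device='cpu'):
    # map_location keeps a GPU-trained checkpoint loadable on a CPU-only host
    state = torch.load(path, map_location=torch.device(device))
    model.load_state_dict(state)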

@@ -75,42 +105,64 @@
                          lstm_dims=hp.tts_lstm_dims,
                          postnet_K=hp.tts_postnet_K,
                          num_highways=hp.tts_num_highways,
-                         dropout=hp.tts_dropout).to(device)
+                         dropout=hp.tts_dropout,
+                         stop_threshold=hp.tts_stop_threshold).to(device)
 
-    tts_restore_path = weights_path if weights_path else paths.tts_latest_weights
-    tts_model.restore(tts_restore_path)
+    tts_load_path = tts_weights if tts_weights else paths.tts_latest_weights
+    tts_model.load(tts_load_path)
 
     if input_text:
         inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
     else:
         with open('sentences.txt') as f:
             inputs = [text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f]
 
-    voc_k = voc_model.get_step() // 1000
-    tts_k = tts_model.get_step() // 1000
-
-    simple_table([('WaveRNN', str(voc_k) + 'k'),
-                  ('Tacotron', str(tts_k) + 'k'),
-                  ('r', tts_model.r.item()),
-                  ('Generation Mode', 'Batched' if batched else 'Unbatched'),
-                  ('Target Samples', target if batched else 'N/A'),
-                  ('Overlap Samples', overlap if batched else 'N/A')])
+    if args.vocoder == 'wavernn':
+        voc_k = voc_model.get_step() // 1000
+        tts_k = tts_model.get_step() // 1000
+
+        simple_table([('Tacotron', str(tts_k) + 'k'),
+                      ('r', tts_model.r),
+                      ('Vocoder Type', 'WaveRNN'),
+                      ('WaveRNN', str(voc_k) + 'k'),
+                      ('Generation Mode', 'Batched' if batched else 'Unbatched'),
+                      ('Target Samples', target if batched else 'N/A'),
+                      ('Overlap Samples', overlap if batched else 'N/A')])
+
+    elif args.vocoder == 'griffinlim':
+        tts_k = tts_model.get_step() // 1000
+        simple_table([('Tacotron', str(tts_k) + 'k'),
+                      ('r', tts_model.r),
+                      ('Vocoder Type', 'Griffin-Lim'),
+                      ('GL Iters', args.iters)])
 
     for i, x in enumerate(inputs, 1):
 
         print(f'\n| Generating {i}/{len(inputs)}')
         _, m, attention = tts_model.generate(x)
+        # Fix mel spectrogram scaling to be from 0 to 1
+        m = (m + 4) / 8
+        np.clip(m, 0, 1, out=m)
+
+        if args.vocoder == 'griffinlim':
+            v_type = args.vocoder
+        elif args.vocoder == 'wavernn' and args.batched:
+            v_type = 'wavernn_batched'
+        else:
+            v_type = 'wavernn_unbatched'
 
         if input_text:
-            save_path = f'{paths.tts_output}__input_{input_text[:10]}_{tts_k}k.wav'
+            save_path = paths.tts_output/f'__input_{input_text[:10]}_{v_type}_{tts_k}k.wav'
         else:
-            save_path = f'{paths.tts_output}{i}_batched{str(batched)}_{tts_k}k.wav'
+            save_path = paths.tts_output/f'{i}_{v_type}_{tts_k}k.wav'
 
         if save_attn: save_attention(attention, save_path)
 
-        m = torch.tensor(m).unsqueeze(0)
-        m = (m + 4) / 8
-
-        voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law)
+        if args.vocoder == 'wavernn':
+            m = torch.tensor(m).unsqueeze(0)
+            voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law)
+        elif args.vocoder == 'griffinlim':
+            wav = reconstruct_waveform(m, n_iter=args.iters)
+            save_wav(wav, save_path)
 
     print('\n\nDone.\n')
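The Griffin-Lim branch depends on `reconstruct_waveform` from `utils.dsp`, which this diff does not show. As an illustration only, a librosa-based version might look like the sketch below; the de-scaling step, the log10/`power=1.0` assumptions, the function name, and the hyperparameter values are guesses, not the repo's actual implementation:

import librosa
import numpy as np

def reconstruct_waveform_sketch(mel, n_iter=32, sample_rate=22050,
                                n_fft=2048, hop_length=275):
    """Invert a [0, 1]-scaled mel spectrogram to audio via Griffin-Lim."""
    # Undo the (m + 4) / 8 normalisation applied in gen_tacotron.py, then
    # leave log space (assuming log10-amplitude mels, which may not match the repo)
    amp_mel = np.power(10.0, mel * 8.0 - 4.0)
    # Map mel bins back to a linear-frequency magnitude spectrogram
    linear = librosa.feature.inverse.mel_to_stft(
        amp_mel, sr=sample_rate, n_fft=n_fft, power=1.0)
    # Iteratively estimate phase and resynthesise the waveform
    return librosa.griffinlim(linear, n_iter=n_iter, hop_length=hop_length)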
