|
1 | 1 | import torch
|
2 | 2 | from models.fatchord_version import WaveRNN
|
3 |
| -import hparams as hp |
| 3 | +from utils import hparams as hp |
4 | 4 | from utils.text.symbols import symbols
|
5 | 5 | from utils.paths import Paths
|
6 | 6 | from models.tacotron import Tacotron
|
7 | 7 | import argparse
|
8 | 8 | from utils.text import text_to_sequence
|
9 | 9 | from utils.display import save_attention, simple_table
|
| 10 | +from utils.dsp import reconstruct_waveform, save_wav |
| 11 | +import numpy as np |
10 | 12 |
|
11 | 13 | if __name__ == "__main__":
|
12 | 14 |
|
13 | 15 | # Parse Arguments
|
14 | 16 | parser = argparse.ArgumentParser(description='TTS Generator')
|
15 | 17 | parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
|
16 |
| - parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation') |
17 |
| - parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation') |
18 |
| - parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index') |
19 |
| - parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples') |
20 |
| - parser.add_argument('--weights_path', '-w', type=str, help='[string/path] Load in different Tacotron Weights') |
| 18 | + parser.add_argument('--tts_weights', type=str, help='[string/path] Load in different Tacotron weights') |
21 | 19 | parser.add_argument('--save_attention', '-a', dest='save_attn', action='store_true', help='Save Attention Plots')
|
22 | 20 | parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
|
23 |
| - parser.set_defaults(batched=hp.voc_gen_batched) |
24 |
| - parser.set_defaults(target=hp.voc_target) |
25 |
| - parser.set_defaults(overlap=hp.voc_overlap) |
| 21 | + parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters') |
| 22 | + |
26 | 23 | parser.set_defaults(input_text=None)
|
27 | 24 | parser.set_defaults(weights_path=None)
|
28 |
| - parser.set_defaults(save_attention=False) |
| 25 | + |
| 26 | + # name of subcommand goes to args.vocoder |
| 27 | + subparsers = parser.add_subparsers(required=True, dest='vocoder') |
| 28 | + |
| 29 | + wr_parser = subparsers.add_parser('wavernn', aliases=['wr']) |
| 30 | + wr_parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation') |
| 31 | + wr_parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation') |
| 32 | + wr_parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples') |
| 33 | + wr_parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index') |
| 34 | + wr_parser.add_argument('--voc_weights', type=str, help='[string/path] Load in different WaveRNN weights') |
| 35 | + wr_parser.set_defaults(batched=None) |
| 36 | + |
| 37 | + gl_parser = subparsers.add_parser('griffinlim', aliases=['gl']) |
| 38 | + gl_parser.add_argument('--iters', type=int, default=32, help='[int] number of griffinlim iterations') |
| 39 | + |
29 | 40 | args = parser.parse_args()
|
30 | 41 |
|
31 |
| - batched = args.batched |
32 |
| - target = args.target |
33 |
| - overlap = args.overlap |
| 42 | + if args.vocoder in ['griffinlim', 'gl']: |
| 43 | + args.vocoder = 'griffinlim' |
| 44 | + elif args.vocoder in ['wavernn', 'wr']: |
| 45 | + args.vocoder = 'wavernn' |
| 46 | + else: |
| 47 | + raise argparse.ArgumentError('Must provide a valid vocoder type!') |
| 48 | + |
| 49 | + hp.configure(args.hp_file) # Load hparams from file |
| 50 | + # set defaults for any arguments that depend on hparams |
| 51 | + if args.vocoder == 'wavernn': |
| 52 | + if args.target is None: |
| 53 | + args.target = hp.voc_target |
| 54 | + if args.overlap is None: |
| 55 | + args.overlap = hp.voc_overlap |
| 56 | + if args.batched is None: |
| 57 | + args.batched = hp.voc_gen_batched |
| 58 | + |
| 59 | + batched = args.batched |
| 60 | + target = args.target |
| 61 | + overlap = args.overlap |
| 62 | + |
34 | 63 | input_text = args.input_text
|
35 |
| - weights_path = args.weights_path |
36 |
| - save_attn = args.save_attention |
| 64 | + tts_weights = args.tts_weights |
| 65 | + save_attn = args.save_attn |
37 | 66 |
|
38 | 67 | paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
|
39 | 68 |
|
|
43 | 72 | device = torch.device('cpu')
|
44 | 73 | print('Using device:', device)
|
45 | 74 |
|
46 |
| - print('\nInitialising WaveRNN Model...\n') |
47 |
| - |
48 |
| - # Instantiate WaveRNN Model |
49 |
| - voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims, |
50 |
| - fc_dims=hp.voc_fc_dims, |
51 |
| - bits=hp.bits, |
52 |
| - pad=hp.voc_pad, |
53 |
| - upsample_factors=hp.voc_upsample_factors, |
54 |
| - feat_dims=hp.num_mels, |
55 |
| - compute_dims=hp.voc_compute_dims, |
56 |
| - res_out_dims=hp.voc_res_out_dims, |
57 |
| - res_blocks=hp.voc_res_blocks, |
58 |
| - hop_length=hp.hop_length, |
59 |
| - sample_rate=hp.sample_rate, |
60 |
| - mode=hp.voc_mode).to(device) |
61 |
| - |
62 |
| - voc_model.restore(paths.voc_latest_weights) |
| 75 | + if args.vocoder == 'wavernn': |
| 76 | + print('\nInitialising WaveRNN Model...\n') |
| 77 | + # Instantiate WaveRNN Model |
| 78 | + voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims, |
| 79 | + fc_dims=hp.voc_fc_dims, |
| 80 | + bits=hp.bits, |
| 81 | + pad=hp.voc_pad, |
| 82 | + upsample_factors=hp.voc_upsample_factors, |
| 83 | + feat_dims=hp.num_mels, |
| 84 | + compute_dims=hp.voc_compute_dims, |
| 85 | + res_out_dims=hp.voc_res_out_dims, |
| 86 | + res_blocks=hp.voc_res_blocks, |
| 87 | + hop_length=hp.hop_length, |
| 88 | + sample_rate=hp.sample_rate, |
| 89 | + mode=hp.voc_mode).to(device) |
| 90 | + |
| 91 | + voc_load_path = args.voc_weights if args.voc_weights else paths.voc_latest_weights |
| 92 | + voc_model.load(voc_load_path) |
63 | 93 |
|
64 | 94 | print('\nInitialising Tacotron Model...\n')
|
65 | 95 |
|
|
75 | 105 | lstm_dims=hp.tts_lstm_dims,
|
76 | 106 | postnet_K=hp.tts_postnet_K,
|
77 | 107 | num_highways=hp.tts_num_highways,
|
78 |
| - dropout=hp.tts_dropout).to(device) |
| 108 | + dropout=hp.tts_dropout, |
| 109 | + stop_threshold=hp.tts_stop_threshold).to(device) |
79 | 110 |
|
80 |
| - tts_restore_path = weights_path if weights_path else paths.tts_latest_weights |
81 |
| - tts_model.restore(tts_restore_path) |
| 111 | + tts_load_path = tts_weights if tts_weights else paths.tts_latest_weights |
| 112 | + tts_model.load(tts_load_path) |
82 | 113 |
|
83 | 114 | if input_text:
|
84 | 115 | inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
|
85 | 116 | else:
|
86 | 117 | with open('sentences.txt') as f:
|
87 | 118 | inputs = [text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f]
|
88 | 119 |
|
89 |
| - voc_k = voc_model.get_step() // 1000 |
90 |
| - tts_k = tts_model.get_step() // 1000 |
91 |
| - |
92 |
| - simple_table([('WaveRNN', str(voc_k) + 'k'), |
93 |
| - ('Tacotron', str(tts_k) + 'k'), |
94 |
| - ('r', tts_model.r.item()), |
95 |
| - ('Generation Mode', 'Batched' if batched else 'Unbatched'), |
96 |
| - ('Target Samples', target if batched else 'N/A'), |
97 |
| - ('Overlap Samples', overlap if batched else 'N/A')]) |
| 120 | + if args.vocoder == 'wavernn': |
| 121 | + voc_k = voc_model.get_step() // 1000 |
| 122 | + tts_k = tts_model.get_step() // 1000 |
| 123 | + |
| 124 | + simple_table([('Tacotron', str(tts_k) + 'k'), |
| 125 | + ('r', tts_model.r), |
| 126 | + ('Vocoder Type', 'WaveRNN'), |
| 127 | + ('WaveRNN', str(voc_k) + 'k'), |
| 128 | + ('Generation Mode', 'Batched' if batched else 'Unbatched'), |
| 129 | + ('Target Samples', target if batched else 'N/A'), |
| 130 | + ('Overlap Samples', overlap if batched else 'N/A')]) |
| 131 | + |
| 132 | + elif args.vocoder == 'griffinlim': |
| 133 | + tts_k = tts_model.get_step() // 1000 |
| 134 | + simple_table([('Tacotron', str(tts_k) + 'k'), |
| 135 | + ('r', tts_model.r), |
| 136 | + ('Vocoder Type', 'Griffin-Lim'), |
| 137 | + ('GL Iters', args.iters)]) |
98 | 138 |
|
99 | 139 | for i, x in enumerate(inputs, 1):
|
100 | 140 |
|
101 | 141 | print(f'\n| Generating {i}/{len(inputs)}')
|
102 | 142 | _, m, attention = tts_model.generate(x)
|
| 143 | + # Fix mel spectrogram scaling to be from 0 to 1 |
| 144 | + m = (m + 4) / 8 |
| 145 | + np.clip(m, 0, 1, out=m) |
| 146 | + |
| 147 | + if args.vocoder == 'griffinlim': |
| 148 | + v_type = args.vocoder |
| 149 | + elif args.vocoder == 'wavernn' and args.batched: |
| 150 | + v_type = 'wavernn_batched' |
| 151 | + else: |
| 152 | + v_type = 'wavernn_unbatched' |
103 | 153 |
|
104 | 154 | if input_text:
|
105 |
| - save_path = f'{paths.tts_output}__input_{input_text[:10]}_{tts_k}k.wav' |
| 155 | + save_path = paths.tts_output/f'__input_{input_text[:10]}_{v_type}_{tts_k}k.wav' |
106 | 156 | else:
|
107 |
| - save_path = f'{paths.tts_output}{i}_batched{str(batched)}_{tts_k}k.wav' |
| 157 | + save_path = paths.tts_output/f'{i}_{v_type}_{tts_k}k.wav' |
108 | 158 |
|
109 | 159 | if save_attn: save_attention(attention, save_path)
|
110 | 160 |
|
111 |
| - m = torch.tensor(m).unsqueeze(0) |
112 |
| - m = (m + 4) / 8 |
113 |
| - |
114 |
| - voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law) |
| 161 | + if args.vocoder == 'wavernn': |
| 162 | + m = torch.tensor(m).unsqueeze(0) |
| 163 | + voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law) |
| 164 | + elif args.vocoder == 'griffinlim': |
| 165 | + wav = reconstruct_waveform(m, n_iter=args.iters) |
| 166 | + save_wav(wav, save_path) |
115 | 167 |
|
116 | 168 | print('\n\nDone.\n')
|
0 commit comments