Skip to content

Commit 173ec17

Browse files
authored
Merge pull request #107 from TheButlah/master
Added CPU support, other fixes too
2 parents d0cbddf + 1a5b84a commit 173ec17

17 files changed

+316
-221
lines changed

.gitignore

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1 +1,47 @@
1+
# PyCharm files
12
.idea
3+
4+
# Mac files
5+
.DS_Store
6+
7+
# Environments
8+
.env
9+
.venv
10+
env/
11+
venv/
12+
ENV/
13+
env.bak/
14+
venv.bak/
15+
16+
# Byte-compiled / optimized / DLL files
17+
__pycache__/
18+
*.py[cod]
19+
*$py.class
20+
21+
# Distribution / packaging
22+
.Python
23+
build/
24+
develop-eggs/
25+
dist/
26+
downloads/
27+
eggs/
28+
.eggs/
29+
lib/
30+
lib64/
31+
parts/
32+
sdist/
33+
var/
34+
wheels/
35+
pip-wheel-metadata/
36+
share/python-wheels/
37+
*.egg-info/
38+
.installed.cfg
39+
*.egg
40+
MANIFEST
41+
42+
# Installer logs
43+
pip-log.txt
44+
pip-delete-this-directory.txt
45+
46+
# Jupyter Notebook
47+
.ipynb_checkpoints

gen_tacotron.py

Lines changed: 18 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from utils.text import text_to_sequence
99
from utils.display import save_attention, simple_table
1010

11-
if __name__ == "__main__" :
11+
if __name__ == "__main__":
1212

1313
# Parse Arguments
1414
parser = argparse.ArgumentParser(description='TTS Generator')
@@ -19,6 +19,7 @@
1919
parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
2020
parser.add_argument('--weights_path', '-w', type=str, help='[string/path] Load in different Tacotron Weights')
2121
parser.add_argument('--save_attention', '-a', dest='save_attn', action='store_true', help='Save Attention Plots')
22+
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
2223
parser.set_defaults(batched=hp.voc_gen_batched)
2324
parser.set_defaults(target=hp.voc_target)
2425
parser.set_defaults(overlap=hp.voc_overlap)
@@ -36,6 +37,12 @@
3637

3738
paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
3839

40+
if not args.force_cpu and torch.cuda.is_available():
41+
device = torch.device('cuda')
42+
else:
43+
device = torch.device('cpu')
44+
print('Using device:', device)
45+
3946
print('\nInitialising WaveRNN Model...\n')
4047

4148
# Instantiate WaveRNN Model
@@ -50,7 +57,7 @@
5057
res_blocks=hp.voc_res_blocks,
5158
hop_length=hp.hop_length,
5259
sample_rate=hp.sample_rate,
53-
mode=hp.voc_mode).cuda()
60+
mode=hp.voc_mode).to(device)
5461

5562
voc_model.restore(paths.voc_latest_weights)
5663

@@ -68,15 +75,15 @@
6875
lstm_dims=hp.tts_lstm_dims,
6976
postnet_K=hp.tts_postnet_K,
7077
num_highways=hp.tts_num_highways,
71-
dropout=hp.tts_dropout).cuda()
78+
dropout=hp.tts_dropout).to(device)
7279

7380
tts_restore_path = weights_path if weights_path else paths.tts_latest_weights
7481
tts_model.restore(tts_restore_path)
7582

76-
if input_text :
83+
if input_text:
7784
inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
78-
else :
79-
with open('sentences.txt') as f :
85+
else:
86+
with open('sentences.txt') as f:
8087
inputs = [text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f]
8188

8289
voc_k = voc_model.get_step() // 1000
@@ -89,21 +96,21 @@
8996
('Target Samples', target if batched else 'N/A'),
9097
('Overlap Samples', overlap if batched else 'N/A')])
9198

92-
for i, x in enumerate(inputs, 1) :
99+
for i, x in enumerate(inputs, 1):
93100

94101
print(f'\n| Generating {i}/{len(inputs)}')
95102
_, m, attention = tts_model.generate(x)
96103

97-
if input_text :
104+
if input_text:
98105
save_path = f'{paths.tts_output}__input_{input_text[:10]}_{tts_k}k.wav'
99-
else :
106+
else:
100107
save_path = f'{paths.tts_output}{i}_batched{str(batched)}_{tts_k}k.wav'
101108

102-
if save_attn : save_attention(attention, save_path)
109+
if save_attn: save_attention(attention, save_path)
103110

104111
m = torch.tensor(m).unsqueeze(0)
105112
m = (m + 4) / 8
106113

107114
voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law)
108115

109-
print('\n\nDone.\n')
116+
print('\n\nDone.\n')

gen_wavernn.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -7,23 +7,23 @@
77
import argparse
88

99

10-
def gen_testset(model, test_set, samples, batched, target, overlap, save_path) :
10+
def gen_testset(model, test_set, samples, batched, target, overlap, save_path):
1111

1212
k = model.get_step() // 1000
1313

1414
for i, (m, x) in enumerate(test_set, 1):
1515

16-
if i > samples : break
16+
if i > samples: break
1717

1818
print('\n| Generating: %i/%i' % (i, samples))
1919

2020
x = x[0].numpy()
2121

2222
bits = 16 if hp.voc_mode == 'MOL' else hp.bits
2323

24-
if hp.mu_law and hp.voc_mode != 'MOL' :
24+
if hp.mu_law and hp.voc_mode != 'MOL':
2525
x = decode_mu_law(x, 2**bits, from_labels=True)
26-
else :
26+
else:
2727
x = label_2_float(x, bits)
2828

2929
save_wav(x, f'{save_path}{k}k_steps_{i}_target.wav')
@@ -34,7 +34,7 @@ def gen_testset(model, test_set, samples, batched, target, overlap, save_path) :
3434
_ = model.generate(m, save_str, batched, target, overlap, hp.mu_law)
3535

3636

37-
def gen_from_file(model, load_path, save_path, batched, target, overlap) :
37+
def gen_from_file(model, load_path, save_path, batched, target, overlap):
3838

3939
k = model.get_step() // 1000
4040
file_name = load_path.split('/')[-1]
@@ -62,6 +62,7 @@ def gen_from_file(model, load_path, save_path, batched, target, overlap) :
6262
parser.add_argument('--file', '-f', type=str, help='[string/path] for testing a wav outside dataset')
6363
parser.add_argument('--weights', '-w', type=str, help='[string/path] checkpoint file to load weights from')
6464
parser.add_argument('--gta', '-g', dest='use_gta', action='store_true', help='Generate from GTA testset')
65+
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
6566

6667
parser.set_defaults(batched=hp.voc_gen_batched)
6768
parser.set_defaults(samples=hp.voc_gen_at_checkpoint)
@@ -80,6 +81,12 @@ def gen_from_file(model, load_path, save_path, batched, target, overlap) :
8081
file = args.file
8182
gta = args.gta
8283

84+
if not args.force_cpu and torch.cuda.is_available():
85+
device = torch.device('cuda')
86+
else:
87+
device = torch.device('cpu')
88+
print('Using device:', device)
89+
8390
print('\nInitialising Model...\n')
8491

8592
model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
@@ -93,7 +100,7 @@ def gen_from_file(model, load_path, save_path, batched, target, overlap) :
93100
res_blocks=hp.voc_res_blocks,
94101
hop_length=hp.hop_length,
95102
sample_rate=hp.sample_rate,
96-
mode=hp.voc_mode).cuda()
103+
mode=hp.voc_mode).to(device)
97104

98105
paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
99106

@@ -107,9 +114,9 @@ def gen_from_file(model, load_path, save_path, batched, target, overlap) :
107114

108115
_, test_set = get_vocoder_datasets(paths.data, 1, gta)
109116

110-
if file :
117+
if file:
111118
gen_from_file(model, file, paths.voc_output, batched, target, overlap)
112-
else :
119+
else:
113120
gen_testset(model, test_set, samples, batched, target, overlap, paths.voc_output)
114121

115122
print('\n\nExiting...\n')

models/__init__.py

Whitespace-only changes.

models/deepmind_version.py

Lines changed: 16 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,8 @@
44
from utils.display import *
55
from utils.dsp import *
66

7-
class WaveRNN(nn.Module) :
8-
def __init__(self, hidden_size=896, quantisation=256) :
7+
class WaveRNN(nn.Module):
8+
def __init__(self, hidden_size=896, quantisation=256):
99
super(WaveRNN, self).__init__()
1010

1111
self.hidden_size = hidden_size
@@ -33,7 +33,7 @@ def __init__(self, hidden_size=896, quantisation=256) :
3333
self.num_params()
3434

3535

36-
def forward(self, prev_y, prev_hidden, current_coarse) :
36+
def forward(self, prev_y, prev_hidden, current_coarse):
3737

3838
# Main matmul - the projection is split 3 ways
3939
R_hidden = self.R(prev_hidden)
@@ -71,9 +71,10 @@ def forward(self, prev_y, prev_hidden, current_coarse) :
7171
return out_coarse, out_fine, hidden
7272

7373

74-
def generate(self, seq_len) :
75-
76-
with torch.no_grad() :
74+
def generate(self, seq_len):
75+
device = next(self.parameters()).device # use same device as parameters
76+
77+
with torch.no_grad():
7778

7879
# First split up the biases for the gates
7980
b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size)
@@ -84,17 +85,17 @@ def generate(self, seq_len) :
8485
c_outputs, f_outputs = [], []
8586

8687
# Some initial inputs
87-
out_coarse = torch.LongTensor([0]).cuda()
88-
out_fine = torch.LongTensor([0]).cuda()
88+
out_coarse = torch.tensor([0], dtype=torch.long, device=device)
89+
out_fine = torch.tensor([0], dtype=torch.long, device=device)
8990

9091
# We'll need a hidden state
91-
hidden = self.init_hidden()
92+
hidden = self.get_initial_hidden()
9293

9394
# Need a clock for display
9495
start = time.time()
9596

9697
# Loop for generation
97-
for i in range(seq_len) :
98+
for i in range(seq_len):
9899

99100
# Split into two hidden states
100101
hidden_coarse, hidden_fine = \
@@ -162,10 +163,11 @@ def generate(self, seq_len) :
162163

163164
return output, coarse, fine
164165

165-
def init_hidden(self, batch_size=1) :
166-
return torch.zeros(batch_size, self.hidden_size).cuda()
166+
def get_initial_hidden(self, batch_size=1):
167+
device = next(self.parameters()).device # use same device as parameters
168+
return torch.zeros(batch_size, self.hidden_size, device=device)
167169

168-
def num_params(self) :
170+
def num_params(self):
169171
parameters = filter(lambda p: p.requires_grad, self.parameters())
170172
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
171-
print('Trainable Parameters: %.3f million' % parameters)
173+
print('Trainable Parameters: %.3f million' % parameters)

0 commit comments

Comments
 (0)