Skip to content

Commit 94302af

Browse files
author
Ryan Butler
committed
Fixed missing import, added type annotation for IDE suggestions
+ Added type annotation for WaveRNN and Tacotron in train and generate files for each model * Fixed missing import for numpy in `fatchord_version.py` and `deepmind_version.py`
1 parent 2d66a04 commit 94302af

File tree

5 files changed

+8
-5
lines changed

5 files changed

+8
-5
lines changed

gen_wavernn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import argparse
88

99

10-
def gen_testset(model, test_set, samples, batched, target, overlap, save_path):
10+
def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path):
1111

1212
k = model.get_step() // 1000
1313

@@ -34,7 +34,7 @@ def gen_testset(model, test_set, samples, batched, target, overlap, save_path):
3434
_ = model.generate(m, save_str, batched, target, overlap, hp.mu_law)
3535

3636

37-
def gen_from_file(model, load_path, save_path, batched, target, overlap):
37+
def gen_from_file(model: WaveRNN, load_path, save_path, batched, target, overlap):
3838

3939
k = model.get_step() // 1000
4040
file_name = load_path.split('/')[-1]

models/deepmind_version.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch.nn.functional as F
44
from utils.display import *
55
from utils.dsp import *
6+
import numpy as np
67

78
class WaveRNN(nn.Module):
89
def __init__(self, hidden_size=896, quantisation=256):

models/fatchord_version.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from utils.display import *
66
from utils.dsp import *
77
import os
8+
import numpy as np
89

910

1011
class ResBlock(nn.Module):

train_tacotron.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
def np_now(x): return x.detach().cpu().numpy()
1616

1717

18-
def tts_train_loop(model, optimizer, train_set, lr, train_steps, attn_example):
18+
def tts_train_loop(model: Tacotron, optimizer, train_set, lr, train_steps, attn_example):
1919
device = next(model.parameters()).device # use same device as model parameters
2020

2121
for p in optimizer.param_groups: p['lr'] = lr
@@ -81,7 +81,7 @@ def tts_train_loop(model, optimizer, train_set, lr, train_steps, attn_example):
8181
print(' ')
8282

8383

84-
def create_gta_features(model, train_set, save_path):
84+
def create_gta_features(model: Tacotron, train_set, save_path):
8585
device = next(model.parameters()).device # use same device as model parameters
8686

8787
iters = len(train_set)

train_wavernn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import os
1616

1717

18-
def voc_train_loop(model, loss_func, optimizer, train_set, test_set, lr, total_steps):
18+
def voc_train_loop(model: WaveRNN, loss_func, optimizer, train_set, test_set, lr, total_steps):
1919
# Use same device as model parameters
2020
device = next(model.parameters()).device
2121

@@ -56,6 +56,7 @@ def voc_train_loop(model, loss_func, optimizer, train_set, test_set, lr, total_s
5656
if np.isnan(grad_norm):
5757
print('grad_norm was NaN!')
5858
optimizer.step()
59+
5960
running_loss += loss.item()
6061

6162
speed = i / (time.time() - start)

0 commit comments

Comments
 (0)