forked from soobinseo/Transformer-TTS
train_transformer.py
111 lines (73 loc) · 3.79 KB
from preprocess import get_dataset, DataLoader, collate_fn_transformer
from network import *
from tensorboardX import SummaryWriter
import torchvision.utils as vutils
import os
from tqdm import tqdm

# torch (as t), torch.nn (as nn) and hyperparams (as hp) are used throughout this
# script; the wildcard import from network already brings them into scope, but they
# are made explicit here so the script's dependencies are visible.
import torch as t
import torch.nn as nn
import hyperparams as hp


def adjust_learning_rate(optimizer, step_num, warmup_step=4000):
    lr = hp.lr * warmup_step**0.5 * min(step_num * warmup_step**-1.5, step_num**-0.5)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
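
# The schedule above is a Noam-style warmup: for step_num <= warmup_step the rate
# ramps up linearly, lr = hp.lr * step_num / warmup_step, peaking at exactly hp.lr
# when step_num == warmup_step; afterwards it decays as
# lr = hp.lr * sqrt(warmup_step / step_num).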


def main():
    dataset = get_dataset()
    global_step = 0

    m = nn.DataParallel(Model().cuda())
    m.train()

    optimizer = t.optim.Adam(m.parameters(), lr=hp.lr)
    pos_weight = t.FloatTensor([5.]).cuda()  # prepared for a stop-token loss, but unused below
    writer = SummaryWriter()

    for epoch in range(hp.epochs):
        dataloader = DataLoader(dataset, batch_size=hp.batch_size, shuffle=True,
                                collate_fn=collate_fn_transformer, drop_last=True, num_workers=16)
        pbar = tqdm(dataloader)
        for i, data in enumerate(pbar):
            pbar.set_description("Processing at epoch %d" % epoch)
            global_step += 1

            # Warm up / decay the learning rate for the first 400k steps.
            if global_step < 400000:
                adjust_learning_rate(optimizer, global_step)

            character, mel, mel_input, pos_text, pos_mel, _ = data

            # Stop-token targets: 1 where pos_mel is padding, 0 elsewhere (also unused below).
            stop_tokens = t.abs(pos_mel.ne(0).type(t.float) - 1)

            character = character.cuda()
            mel = mel.cuda()
            mel_input = mel_input.cuda()
            pos_text = pos_text.cuda()
            pos_mel = pos_mel.cuda()

            mel_pred, postnet_pred, attn_probs, stop_preds, attns_enc, attns_dec = m.forward(
                character, mel_input, pos_text, pos_mel)

            # L1 reconstruction losses before and after the postnet.
            mel_loss = nn.L1Loss()(mel_pred, mel)
            post_mel_loss = nn.L1Loss()(postnet_pred, mel)
            loss = mel_loss + post_mel_loss
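
            # Note: pos_weight, stop_tokens, and stop_preds are all computed but never
            # enter the loss. If a stop-token objective were wanted, one hedged sketch
            # (shapes may need a squeeze/reshape to line up) would be:
            #     stop_loss = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(
            #         stop_preds.squeeze(-1), stop_tokens.cuda())
            #     loss = mel_loss + post_mel_loss + stop_loss
            # This is not part of the original training objective.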

            writer.add_scalars('training_loss', {
                'mel_loss': mel_loss,
                'post_mel_loss': post_mel_loss,
            }, global_step)

            # encoder_alpha / decoder_alpha are the trainable scales on the positional encodings.
            writer.add_scalars('alphas', {
                'encoder_alpha': m.module.encoder.alpha.data,
                'decoder_alpha': m.module.decoder.alpha.data,
            }, global_step)

            # Periodically log a few attention maps per layer as images.
            if global_step % hp.image_step == 1:
                for i, prob in enumerate(attn_probs):
                    num_h = prob.size(0)
                    for j in range(4):
                        x = vutils.make_grid(prob[j * 16] * 255)
                        writer.add_image('Attention_%d_0' % global_step, x, i * 4 + j)

                for i, prob in enumerate(attns_enc):
                    num_h = prob.size(0)
                    for j in range(4):
                        x = vutils.make_grid(prob[j * 16] * 255)
                        writer.add_image('Attention_enc_%d_0' % global_step, x, i * 4 + j)

                for i, prob in enumerate(attns_dec):
                    num_h = prob.size(0)
                    for j in range(4):
                        x = vutils.make_grid(prob[j * 16] * 255)
                        writer.add_image('Attention_dec_%d_0' % global_step, x, i * 4 + j)

            optimizer.zero_grad()
            # Calculate gradients
            loss.backward()

            # Clip gradients to stabilise training, then update weights.
            nn.utils.clip_grad_norm_(m.parameters(), 1.)
            optimizer.step()

            # Periodically save model and optimizer state.
            if global_step % hp.save_step == 0:
                t.save({'model': m.state_dict(),
                        'optimizer': optimizer.state_dict()},
                       os.path.join(hp.checkpoint_path, 'checkpoint_transformer_%d.pth.tar' % global_step))


if __name__ == '__main__':
    main()
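
A minimal sketch of resuming from one of the checkpoints this script writes, assuming the
same network and hyperparams modules are importable; the helper name and the step argument
are illustrative, not part of the repository. Because the model is saved while wrapped in
nn.DataParallel, its state-dict keys carry a 'module.' prefix, so it is easiest to wrap the
fresh model the same way before loading.

    import os
    import torch as t
    import torch.nn as nn
    import hyperparams as hp
    from network import Model

    def load_transformer_checkpoint(step, optimizer=None):
        # Path mirrors the save call in main(); 'step' selects which checkpoint to load.
        ckpt = t.load(os.path.join(hp.checkpoint_path,
                                   'checkpoint_transformer_%d.pth.tar' % step))
        m = nn.DataParallel(Model().cuda())   # same wrapping as in training, so keys match
        m.load_state_dict(ckpt['model'])
        if optimizer is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
        return m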