@@ -12,11 +12,12 @@
 from utils.paths import Paths
 import argparse
 from utils import data_parallel_workaround
+import os


-def voc_train_loop(model, loss_func, optimiser, train_set, test_set, lr, total_steps, device):
+def voc_train_loop(model, loss_func, optimizer, train_set, test_set, lr, total_steps, device):

-    for p in optimiser.param_groups: p['lr'] = lr
+    for p in optimizer.param_groups: p['lr'] = lr

     total_iters = len(train_set)
     epochs = (total_steps - model.get_step()) // total_iters + 1
@@ -46,13 +47,13 @@ def voc_train_loop(model, loss_func, optimiser, train_set, test_set, lr, total_s

             loss = loss_func(y_hat, y)

-            optimiser.zero_grad()
+            optimizer.zero_grad()
             loss.backward()
             if hp.voc_clip_grad_norm is not None:
                 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm)
                 if np.isnan(grad_norm):
                     print('grad_norm was NaN!')
-            optimiser.step()
+            optimizer.step()
             running_loss += loss.item()

             speed = i / (time.time() - start)
@@ -64,11 +65,14 @@ def voc_train_loop(model, loss_func, optimiser, train_set, test_set, lr, total_s
             if step % hp.voc_checkpoint_every == 0:
                 gen_testset(model, test_set, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                             hp.voc_target, hp.voc_overlap, paths.voc_output)
-                model.checkpoint(paths.voc_checkpoints)
+                model.checkpoint(paths.voc_checkpoints, optimizer)

             msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
             stream(msg)

+        # Must save latest optimizer state to ensure that resuming training
+        # doesn't produce artifacts
+        torch.save(optimizer.state_dict(), paths.voc_latest_optim)
         model.save(paths.voc_latest_weights)
         model.log(paths.voc_log, msg)
         print(' ')
@@ -123,7 +127,10 @@ def voc_train_loop(model, loss_func, optimiser, train_set, test_set, lr, total_s

     voc_model.restore(paths.voc_latest_weights)

-    optimiser = optim.Adam(voc_model.parameters())
+    optimizer = optim.Adam(voc_model.parameters())
+    if os.path.isfile(paths.voc_latest_optim):
+        print(f'Loading Optimizer State: "{paths.voc_latest_optim}"')
+        optimizer.load_state_dict(torch.load(paths.voc_latest_optim))

     train_set, test_set = get_vocoder_datasets(paths.data, batch_size, train_gta)

@@ -137,7 +144,7 @@ def voc_train_loop(model, loss_func, optimiser, train_set, test_set, lr, total_s

     loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss

-    voc_train_loop(voc_model, loss_func, optimiser, train_set, test_set, lr, total_steps, device)
+    voc_train_loop(voc_model, loss_func, optimizer, train_set, test_set, lr, total_steps, device)

     print('Training Complete.')
     print('To continue training increase voc_total_steps in hparams.py or use --force_train')
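
For reference, the change follows the standard PyTorch resume pattern: serialize optimizer.state_dict() next to the model weights at checkpoint time and reload it before training continues, so that a restarted run does not begin with a fresh Adam state (zeroed moment estimates), which is presumably the source of the artifacts the new comment refers to. Below is a minimal, self-contained sketch of that pattern; the WEIGHTS_PATH and OPTIM_PATH constants and the generic model/optimizer arguments are placeholders for illustration, not the repo's actual Paths helper or WaveRNN model.

import os
import torch

# Hypothetical checkpoint locations used only for this sketch; the repo
# resolves the real files (paths.voc_latest_weights, paths.voc_latest_optim)
# through its Paths helper.
WEIGHTS_PATH = 'checkpoints/voc_latest_weights.pyt'
OPTIM_PATH = 'checkpoints/voc_latest_optim.pyt'

def save_checkpoint(model, optimizer):
    # Persist the optimizer state (step counts and moment estimates)
    # alongside the weights so a resumed run picks up where it left off.
    torch.save(model.state_dict(), WEIGHTS_PATH)
    torch.save(optimizer.state_dict(), OPTIM_PATH)

def restore_checkpoint(model, optimizer):
    # Restore the weights first, then the matching optimizer state if present.
    if os.path.isfile(WEIGHTS_PATH):
        model.load_state_dict(torch.load(WEIGHTS_PATH))
    if os.path.isfile(OPTIM_PATH):
        print(f'Loading Optimizer State: "{OPTIM_PATH}"')
        optimizer.load_state_dict(torch.load(OPTIM_PATH))

Passing the optimizer into model.checkpoint(paths.voc_checkpoints, optimizer) appears to extend the same idea to the periodic numbered checkpoints, not just the rolling "latest" files.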