
Commit 6422d25

Fixed missing import in train_tacotron.py, now saving optimizer state
1 parent bcb67c2 commit 6422d25

5 files changed: +43 additions, -13 deletions

models/fatchord_version.py

Lines changed: 9 additions & 2 deletions
@@ -388,9 +388,12 @@ def xfade_and_unfold(self, y, target, overlap):
     def get_step(self):
         return self.step.data.item()
 
-    def checkpoint(self, path):
+    def checkpoint(self, path, optimizer):
+        # The optimizer can be given as an argument because the checkpoint
+        # function is only useful in the context of an existing training process.
         k_steps = self.get_step() // 1000
         self.save(f'{path}/checkpoint_{k_steps}k_steps.pyt')
+        torch.save(optimizer.state_dict(), f'{path}/checkpoint_{k_steps}k_steps_optim.pyt')
 
     def log(self, path, msg):
         with open(path, 'a') as f:
@@ -405,10 +408,14 @@ def restore(self, path):
         self.load(path)
 
     def load(self, path, device='cpu'):
-        # because PyTorch places on CPU by default, we follow those semantics by using CPU as default.
+        # Because PyTorch places tensors on the CPU by default, we follow
+        # those semantics and use CPU as the default device.
         self.load_state_dict(torch.load(path, map_location=device), strict=False)
 
     def save(self, path):
+        # No optimizer argument because saving a model should not include data
+        # only relevant to the training process - it should contain only
+        # properties of the model itself. Let the caller save the optimizer state.
         torch.save(self.state_dict(), path)
 
     def num_params(self, print_out=True):
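The division of labour above is deliberate: save() writes only the model's own state_dict, while checkpoint() additionally persists the optimizer state to a separate *_optim.pyt file. Stripped of the WaveRNN specifics, this is the standard PyTorch pattern sketched below (stand-in model, not code from this commit):

import torch
from torch import nn, optim

# Minimal sketch of the save/restore split introduced above: model weights and
# optimizer state are written to separate files. The nn.Linear is a stand-in
# for the real Tacotron / WaveRNN models.
model = nn.Linear(4, 4)
optimizer = optim.Adam(model.parameters())

torch.save(model.state_dict(), 'latest_weights.pyt')    # what model.save() does
torch.save(optimizer.state_dict(), 'latest_optim.pyt')  # what the caller / checkpoint() does

# Resuming restores both pieces, so the optimizer's internal statistics survive.
model.load_state_dict(torch.load('latest_weights.pyt', map_location='cpu'))
optimizer.load_state_dict(torch.load('latest_optim.pyt'))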

models/tacotron.py

Lines changed: 7 additions & 1 deletion
@@ -432,9 +432,12 @@ def reset_step(self):
         # assignment to parameters or buffers is overloaded, updates internal dict entry
         self.step = torch.zeros(1, dtype=torch.long)
 
-    def checkpoint(self, path):
+    def checkpoint(self, path, optimizer):
+        # The optimizer can be given as an argument because the checkpoint
+        # function is only useful in the context of an existing training process.
         k_steps = self.get_step() // 1000
         self.save(f'{path}/checkpoint_{k_steps}k_steps.pyt')
+        torch.save(optimizer.state_dict(), f'{path}/checkpoint_{k_steps}k_steps_optim.pyt')
 
     def log(self, path, msg):
         with open(path, 'a') as f:
@@ -454,6 +457,9 @@ def load(self, path, device='cpu'):
         self.load_state_dict(torch.load(path, map_location=device), strict=False)
 
     def save(self, path):
+        # No optimizer argument because saving a model should not include data
+        # only relevant to the training process - it should contain only
+        # properties of the model itself. Let the caller save the optimizer state.
         torch.save(self.state_dict(), path)
 
     def num_params(self, print_out=True):

train_tacotron.py

Lines changed: 11 additions & 3 deletions
@@ -8,6 +8,8 @@
 from utils.paths import Paths
 from models.tacotron import Tacotron
 import argparse
+from utils import data_parallel_workaround
+import os
 
 
 def np_now(x): return x.detach().cpu().numpy()
@@ -61,7 +63,7 @@ def tts_train_loop(model, optimizer, train_set, lr, train_steps, attn_example):
             avg_loss = running_loss / i
 
             if step % hp.tts_checkpoint_every == 0:
-                model.checkpoint(paths.tts_checkpoints)
+                model.checkpoint(paths.tts_checkpoints, optimizer)
 
             if attn_example in ids:
                 idx = ids.index(attn_example)
@@ -71,6 +73,9 @@ def tts_train_loop(model, optimizer, train_set, lr, train_steps, attn_example):
             msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:#.4} | {speed:#.2} steps/s | Step: {k}k | '
             stream(msg)
 
+        # Must save latest optimizer state to ensure that resuming training
+        # doesn't produce artifacts
+        torch.save(optimizer.state_dict(), paths.tts_latest_optim)
         model.save(paths.tts_latest_weights)
         model.log(paths.tts_log, msg)
         print(' ')
@@ -146,7 +151,10 @@ def create_gta_features(model, train_set, save_path):
 
     # model.set_r(hp.tts_r)
 
-    optimiser = optim.Adam(model.parameters())
+    optimizer = optim.Adam(model.parameters())
+    if os.path.isfile(paths.tts_latest_optim):
+        print(f'Loading Optimizer State: "{paths.tts_latest_optim}"')
+        optimizer.load_state_dict(torch.load(paths.tts_latest_optim))
 
     current_step = model.get_step()
 
@@ -169,7 +177,7 @@ def create_gta_features(model, train_set, save_path):
                   ('Learning Rate', lr),
                   ('Outputs/Step (r)', model.get_r())])
 
-    tts_train_loop(model, optimiser, train_set, lr, training_steps, attn_example)
+    tts_train_loop(model, optimizer, train_set, lr, training_steps, attn_example)
 
     print('Training Complete.')
     print('To continue training increase tts_total_steps in hparams.py or use --force_train\n')
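The "artifacts" mentioned in the new comment come from Adam's per-parameter statistics: if training resumes with only the weights reloaded, the moment estimates restart from zero and the first updates behave as if the optimizer were brand new. The short, illustrative snippet below (not part of this commit) shows what optimizer.state_dict() actually captures:

import torch
from torch import nn, optim

# Illustrative only: inspect what Adam's state_dict() holds, i.e. what used to
# be lost when only the model weights were written to disk.
model = nn.Linear(2, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 2)).pow(2).mean()
loss.backward()
optimizer.step()

state = optimizer.state_dict()
print(state.keys())              # dict_keys(['state', 'param_groups'])
print(state['state'][0].keys())  # per-parameter 'step', 'exp_avg', 'exp_avg_sq'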

train_wavernn.py

Lines changed: 14 additions & 7 deletions
@@ -12,11 +12,12 @@
 from utils.paths import Paths
 import argparse
 from utils import data_parallel_workaround
+import os
 
 
-def voc_train_loop(model, loss_func, optimiser, train_set, test_set, lr, total_steps, device):
+def voc_train_loop(model, loss_func, optimizer, train_set, test_set, lr, total_steps, device):
 
-    for p in optimiser.param_groups: p['lr'] = lr
+    for p in optimizer.param_groups: p['lr'] = lr
 
     total_iters = len(train_set)
     epochs = (total_steps - model.get_step()) // total_iters + 1
@@ -46,13 +47,13 @@ def voc_train_loop(model, loss_func, optimiser, train_set, test_set, lr, total_steps, device):
 
             loss = loss_func(y_hat, y)
 
-            optimiser.zero_grad()
+            optimizer.zero_grad()
             loss.backward()
             if hp.voc_clip_grad_norm is not None:
                 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm)
                 if np.isnan(grad_norm):
                     print('grad_norm was NaN!')
-            optimiser.step()
+            optimizer.step()
             running_loss += loss.item()
 
             speed = i / (time.time() - start)
@@ -64,11 +65,14 @@ def voc_train_loop(model, loss_func, optimiser, train_set, test_set, lr, total_steps, device):
             if step % hp.voc_checkpoint_every == 0:
                 gen_testset(model, test_set, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                             hp.voc_target, hp.voc_overlap, paths.voc_output)
-                model.checkpoint(paths.voc_checkpoints)
+                model.checkpoint(paths.voc_checkpoints, optimizer)
 
             msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
             stream(msg)
 
+        # Must save latest optimizer state to ensure that resuming training
+        # doesn't produce artifacts
+        torch.save(optimizer.state_dict(), paths.voc_latest_optim)
         model.save(paths.voc_latest_weights)
         model.log(paths.voc_log, msg)
         print(' ')
@@ -123,7 +127,10 @@ def voc_train_loop(model, loss_func, optimiser, train_set, test_set, lr, total_steps, device):
 
     voc_model.restore(paths.voc_latest_weights)
 
-    optimiser = optim.Adam(voc_model.parameters())
+    optimizer = optim.Adam(voc_model.parameters())
+    if os.path.isfile(paths.voc_latest_optim):
+        print(f'Loading Optimizer State: "{paths.voc_latest_optim}"')
+        optimizer.load_state_dict(torch.load(paths.voc_latest_optim))
 
     train_set, test_set = get_vocoder_datasets(paths.data, batch_size, train_gta)
 
@@ -137,7 +144,7 @@ def voc_train_loop(model, loss_func, optimiser, train_set, test_set, lr, total_steps, device):
 
     loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss
 
-    voc_train_loop(voc_model, loss_func, optimiser, train_set, test_set, lr, total_steps, device)
+    voc_train_loop(voc_model, loss_func, optimizer, train_set, test_set, lr, total_steps, device)
 
     print('Training Complete.')
     print('To continue training increase voc_total_steps in hparams.py or use --force_train')
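Both training scripts now guard the optimizer restore behind an os.path.isfile check so that a fresh run still starts cleanly. If the duplication ever becomes annoying, the shared pattern could be factored into a small helper along these lines (hypothetical helper, not part of this commit):

import os
import torch

def maybe_restore_optimizer(optimizer, optim_path):
    # Hypothetical helper capturing the guard both scripts use: restore the
    # optimizer state only if a previous run left a state file behind.
    if os.path.isfile(optim_path):
        print(f'Loading Optimizer State: "{optim_path}"')
        optimizer.load_state_dict(torch.load(optim_path))
    return optimizer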

utils/paths.py

Lines changed: 2 additions & 0 deletions
@@ -11,12 +11,14 @@ def __init__(self, data_path, voc_id, tts_id):
         # WaveRNN/Vocoder Paths
         self.voc_checkpoints = f'checkpoints/{voc_id}.wavernn/'
         self.voc_latest_weights = f'{self.voc_checkpoints}latest_weights.pyt'
+        self.voc_latest_optim = f'{self.voc_checkpoints}latest_optim.pyt'
         self.voc_output = f'model_outputs/{voc_id}.wavernn/'
         self.voc_step = f'{self.voc_checkpoints}/step.npy'
         self.voc_log = f'{self.voc_checkpoints}log.txt'
         # Tacotron/TTS Paths
         self.tts_checkpoints = f'checkpoints/{tts_id}.tacotron/'
         self.tts_latest_weights = f'{self.tts_checkpoints}latest_weights.pyt'
+        self.tts_latest_optim = f'{self.tts_checkpoints}latest_optim.pyt'
         self.tts_output = f'model_outputs/{tts_id}.tts/'
         self.tts_step = f'{self.tts_checkpoints}/step.npy'
         self.tts_log = f'{self.tts_checkpoints}log.txt'
