Skip to content

Commit 11517eb

Browse files
committed
save models on interrupt
1 parent 1ce6753 commit 11517eb

File tree

1 file changed

+30
-21
lines changed

1 file changed

+30
-21
lines changed

train.py

+30-21
Original file line numberDiff line numberDiff line change
@@ -26,35 +26,44 @@ def train(args):
2626
valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
2727
valdataloader.update(**valargs)
2828
device = args.device
29-
3029
model = get_model(args)
3130
if args.load_chkpt is not None:
3231
model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
3332
encoder, decoder = model.encoder, model.decoder
33+
34+
def save_models(e):
35+
torch.save(model.state_dict(), os.path.join(args.out_path, '%s_e%02d.pth' % (args.name, e+1)))
36+
yaml.dump(dict(args), open(os.path.join(args.out_path, 'config.yaml'), 'w+'))
37+
3438
opt = get_optimizer(args.optimizer)(model.parameters(), args.lr, betas=args.betas)
3539
scheduler = get_scheduler(args.scheduler)(opt, max_lr=args.max_lr, steps_per_epoch=len(dataloader)*2, epochs=args.epochs) # scheduler steps are weird.
40+
try:
41+
for e in range(args.epoch, args.epochs):
42+
args.epoch = e
43+
dset = tqdm(iter(dataloader))
44+
for i, (seq, im) in enumerate(dset):
45+
opt.zero_grad()
46+
tgt_seq, tgt_mask = seq['input_ids'].to(device), seq['attention_mask'].bool().to(device)
47+
encoded = encoder(im.to(device))
48+
loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
49+
loss.backward()
50+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
51+
opt.step()
52+
scheduler.step()
3653

37-
for e in range(args.epoch, args.epochs):
38-
args.epoch = e
39-
dset = tqdm(iter(dataloader))
40-
for i, (seq, im) in enumerate(dset):
41-
opt.zero_grad()
42-
tgt_seq, tgt_mask = seq['input_ids'].to(device), seq['attention_mask'].bool().to(device)
43-
encoded = encoder(im.to(device))
44-
loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
45-
loss.backward()
46-
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
47-
opt.step()
48-
scheduler.step()
49-
50-
dset.set_description('Loss: %.4f' % loss.item())
54+
dset.set_description('Loss: %.4f' % loss.item())
55+
if args.wandb:
56+
wandb.log({'train/loss': loss.item()})
57+
if (i+1) % args.sample_freq == 0:
58+
evaluate(model, valdataloader, args, num_batches=args.valbatches, name='val')
59+
if (e+1) % args.save_freq == 0:
60+
save_models(e)
5161
if args.wandb:
52-
wandb.log({'train/loss': loss.item()})
53-
if (i+1) % args.sample_freq == 0:
54-
evaluate(model, valdataloader, args, num_batches=args.valbatches, name='val')
55-
if (e+1) % args.save_freq == 0:
56-
torch.save(model.state_dict(), os.path.join(args.out_path, '%s_e%02d.pth' % (args.name, e+1)))
57-
yaml.dump(dict(args), open(os.path.join(args.out_path, 'config.yaml'), 'w+'))
62+
wandb.log({'train/epoch': e+1})
63+
except KeyboardInterrupt:
64+
if e > 2:
65+
save_models(e)
66+
raise KeyboardInterrupt
5867

5968

6069
if __name__ == '__main__':

0 commit comments

Comments
 (0)