@@ -26,35 +26,44 @@ def train(args):
26
26
# Validation loader mirrors the train loader but uses the (larger) test batch
# size, keeps ragged final batches, and enables test-mode preprocessing.
valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
valdataloader.update(**valargs)
device = args.device
model = get_model(args)
if args.load_chkpt is not None:
    # Resume weights from an earlier run; map onto the configured device.
    state_dict = torch.load(args.load_chkpt, map_location=device)
    model.load_state_dict(state_dict)
encoder, decoder = model.encoder, model.decoder
def save_models(e):
    """Checkpoint after epoch *e* (0-indexed).

    Writes the full model state_dict to ``<out_path>/<name>_eNN.pth`` (NN is
    the 1-indexed epoch number) and dumps the current run configuration to
    ``<out_path>/config.yaml`` so the run can be resumed/reproduced.
    """
    torch.save(model.state_dict(),
               os.path.join(args.out_path, '%s_e%02d.pth' % (args.name, e + 1)))
    # Fix: the original passed a bare open(..., 'w+') handle to yaml.dump and
    # never closed it; use a context manager so the file is flushed and closed.
    with open(os.path.join(args.out_path, 'config.yaml'), 'w+') as f:
        yaml.dump(dict(args), f)
34
38
opt = get_optimizer(args.optimizer)(model.parameters(), args.lr, betas=args.betas)
# NOTE(review): steps_per_epoch is len(dataloader)*2 — the original comment
# says "scheduler steps are weird"; presumably this pads the OneCycle schedule
# so it never finishes early. Confirm against the scheduler implementation.
scheduler = get_scheduler(args.scheduler)(opt, max_lr=args.max_lr,
                                          steps_per_epoch=len(dataloader) * 2,
                                          epochs=args.epochs)
try:
    # Resume from args.epoch so an interrupted run continues where it stopped.
    for e in range(args.epoch, args.epochs):
        args.epoch = e
        dset = tqdm(iter(dataloader))
        for i, (seq, im) in enumerate(dset):
            opt.zero_grad()
            tgt_seq = seq['input_ids'].to(device)
            tgt_mask = seq['attention_mask'].bool().to(device)
            encoded = encoder(im.to(device))
            # Decoder returns the training loss directly for teacher-forced
            # targets conditioned on the encoded image.
            loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            opt.step()
            scheduler.step()
            dset.set_description('Loss: %.4f' % loss.item())
            if args.wandb:
                wandb.log({'train/loss': loss.item()})
            # Periodic validation inside the epoch.
            if (i + 1) % args.sample_freq == 0:
                evaluate(model, valdataloader, args, num_batches=args.valbatches, name='val')
        if (e + 1) % args.save_freq == 0:
            save_models(e)
        if args.wandb:
            wandb.log({'train/epoch': e + 1})
except KeyboardInterrupt:
    # Fix: guard against `e` being unbound when the interrupt arrives before
    # the first epoch starts (the original `if e > 2` would raise NameError).
    if 'e' in locals() and e > 2:
        save_models(e)
    # Fix: bare `raise` re-raises the active exception with its original
    # traceback instead of constructing a new KeyboardInterrupt.
    raise
58
67
59
68
60
69
if __name__ == '__main__' :
0 commit comments