1414# limitations under the License.
1515
1616import os
17- import warnings
1817
1918import torch
20- from torch .nn import DataParallel as DP
21- from torch .nn .parallel import DistributedDataParallel as DDP
2219
20+ from simrec .utils .distributed import is_main_process
2321
def save_ckpt(net, optimizer, scheduler, misc, __C):
    """Save a training checkpoint under ``__C.CKPTs_PATH/__C.VERSION/``.

    The file is named ``<epoch>.pth.tar``, prefixed with ``dist_`` when the
    model is wrapped in DataParallel/DistributedDataParallel (so the loader
    knows the state-dict keys carry a ``module.`` prefix).

    Args:
        net: model whose ``state_dict()`` is saved (may be DP/DDP-wrapped).
        optimizer: optimizer whose state is saved.
        scheduler: LR scheduler whose state is saved.
        misc (dict): must contain at least ``'epoch'`` (int).
        __C: config object exposing ``CKPTs_PATH`` and ``VERSION``.
    """
    assert isinstance(misc, dict)
    # makedirs(exist_ok=True) replaces the race-prone exists()+mkdir() pair
    # and creates intermediate directories in one call.
    ckpt_dir = os.path.join(__C.CKPTs_PATH, __C.VERSION)
    os.makedirs(ckpt_dir, exist_ok=True)

    # Bug fix: the original appended the epoch string directly to the
    # directory path without a separator in the non-distributed case,
    # producing ".../VERSION3.pth.tar" instead of ".../VERSION/3.pth.tar".
    prefix = 'dist_' if isinstance(net, (DP, DDP)) else ''
    path = os.path.join(ckpt_dir, prefix + str(misc['epoch']) + '.pth.tar')

    ckpt = {
        'net_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': misc['epoch'],
        'lr': optimizer.param_groups[0]["lr"],
    }
    torch.save(ckpt, path)
44-
def load_ckpt(net, optimizer, scheduler, path, rank=None):
    """Restore model/optimizer/scheduler state from a checkpoint file.

    Handles the ``module.`` key-prefix mismatch between checkpoints saved
    from a DataParallel/DistributedDataParallel model and a bare model (and
    vice versa) by inspecting the state-dict keys themselves. The original
    checked ``'_dist' in path``, which never matched the ``'dist_'`` prefix
    written by ``save_ckpt``.

    Args:
        net: model to load into (may be DP/DDP-wrapped).
        optimizer: optimizer to restore.
        scheduler: LR scheduler to restore.
        path (str): checkpoint file path.
        rank: optional GPU rank; when given, tensors are mapped to that device.

    Returns:
        dict: the full loaded checkpoint (including ``'epoch'``, ``'lr'``).
    """
    loc = f'cuda:{rank}' if rank is not None else None
    ckpt = torch.load(path, map_location=loc)

    state = ckpt['net_state_dict']
    wrapped = isinstance(net, (DP, DDP))
    # Detect the prefix from the keys instead of the filename — robust
    # regardless of how the checkpoint file was named.
    has_prefix = any(k.startswith('module.') for k in state)
    if has_prefix and not wrapped:
        # Bug fixes vs. original: build a new dict rather than popping while
        # iterating (RuntimeError), and slice the prefix off instead of
        # lstrip('module.'), which strips any leading m/o/d/u/l/e/. chars.
        ckpt['net_state_dict'] = {k[len('module.'):]: v for k, v in state.items()}
    elif wrapped and not has_prefix:
        ckpt['net_state_dict'] = {'module.' + k: v for k, v in state.items()}

    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    scheduler.load_state_dict(ckpt['scheduler'])

    missing, unexpected = net.load_state_dict(ckpt['net_state_dict'], strict=False)
    # Bug fix: the original compared `x.__len__` (a bound method) to 0, which
    # is always truthy, and printed the wrong name list under each warning.
    if len(unexpected) != 0:
        warnings.warn(f'Checkpoint contains {len(unexpected)} parameters '
                      f'that the current model does not have')
        for name in unexpected:
            print('\n' + name + '\n')
    if len(missing) != 0:
        warnings.warn(f'Current model has {len(missing)} parameters '
                      f'that the checkpoint does not contain')
        for name in missing:
            print('\n' + name + '\n')

    return ckpt
def load_checkpoint(cfg, model, optimizer, scheduler, logger):
    """Resume training state from ``cfg.train.resume_path``.

    Restores model weights (non-strict, so partial checkpoints load with a
    logged message), optimizer and scheduler state.

    Args:
        cfg: config exposing ``train.resume_path``.
        model: model to restore (``load_state_dict`` with ``strict=False``).
        optimizer: optimizer to restore.
        scheduler: LR scheduler to restore.
        logger: logger used for progress messages.

    Returns:
        int: the epoch to resume at (checkpoint epoch + 1).
    """
    logger.info(f"==============> Resuming from {cfg.train.resume_path} ....................")
    # Bug fix: unconditionally calling storage.cuda() crashed on CPU-only
    # hosts; only move tensors to GPU when CUDA is actually available.
    if torch.cuda.is_available():
        map_location = lambda storage, loc: storage.cuda()
    else:
        map_location = 'cpu'
    checkpoint = torch.load(cfg.train.resume_path, map_location=map_location)
    msg = model.load_state_dict(checkpoint['state_dict'], strict=False)
    logger.info(msg)
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    start_epoch = checkpoint["epoch"]
    logger.info("==> loaded checkpoint from {}\n".format(cfg.train.resume_path) +
                "==> epoch: {} lr: {} ".format(checkpoint['epoch'], checkpoint['lr']))
    return start_epoch + 1
7534
7635
7736def save_checkpoint (cfg , epoch , model , optimizer , scheduler , logger , det_best = False , seg_best = False ):
@@ -99,4 +58,16 @@ def save_checkpoint(cfg, epoch, model, optimizer, scheduler, logger, det_best=Fa
9958 if seg_best :
10059 seg_best_model_path = os .path .join (cfg .train .output_dir , f'seg_best_model.pth' )
10160 torch .save (save_state , seg_best_model_path )
102- logger .info (f"checkpoints saved !!!" )
61+ logger .info (f"checkpoints saved !!!" )
62+
63+
def auto_resume_helper(output_dir):
    """Return the checkpoint path to auto-resume from, or ``None``.

    Looks in ``output_dir`` for ``*.pth`` files and, when any exist, returns
    the path of ``last_checkpoint.pth`` — provided that file is actually
    present.

    Args:
        output_dir (str): directory where checkpoints are written.

    Returns:
        str | None: path to ``last_checkpoint.pth``, or ``None`` when there
        is nothing to resume from.
    """
    # Bug fix: os.listdir raised FileNotFoundError on a fresh run where the
    # output directory does not exist yet.
    if not os.path.isdir(output_dir):
        return None
    checkpoints = [ckpt for ckpt in os.listdir(output_dir) if ckpt.endswith('pth')]
    print(f"All checkpoints found in {output_dir}: {checkpoints}")
    resume_file = os.path.join(output_dir, "last_checkpoint.pth")
    # Bug fix: the original returned this path whenever *any* .pth file was
    # present, even if last_checkpoint.pth itself did not exist.
    if checkpoints and os.path.isfile(resume_file):
        return resume_file
    return None
0 commit comments