1+ import os
2+ import tools .utils as utils
3+
4+ import torch
5+ import torch .nn as nn
6+
class CheckPoint(object):
    """
    Save and restore model/optimizer state under ``<save_path>/check_point``.

    ``check_point_params`` caches the most recently saved model state_dict,
    optimizer state_dict and epoch: keys ``model``, ``optimizer``, ``epoch``.
    """

    def __init__(self, save_path):
        """
        :params save_path: root directory; files are written to
                           ``<save_path>/check_point`` (created if missing)
        """
        self.save_path = os.path.join(save_path, "check_point")
        self.check_point_params = {'model': None,
                                   'optimizer': None,
                                   'epoch': None}

        # make directory; exist_ok avoids the isdir/makedirs race
        os.makedirs(self.save_path, exist_ok=True)

    def load_state(self, model, state_dict):
        """
        Copy the matching entries of ``state_dict`` into ``model``.

        Keys in ``state_dict`` that the model does not have are silently
        skipped, so partially matching checkpoints can still be loaded.
        :params model: target network
        :params state_dict: source state dictionary
        :return: model, left in evaluation mode
        """
        # set the model in evaluation mode, otherwise the accuracy will change
        model.eval()
        model_dict = model.state_dict()

        for key, value in state_dict.items():
            # only take parameters the model actually declares
            if key in model_dict:
                model_dict[key] = value
        model.load_state_dict(model_dict)
        return model

    def load_model(self, model_path):
        """
        Load a saved model state_dict from disk, mapped to CPU so a
        GPU-trained model can be loaded on any machine.
        :params model_path: path to the model file
        :return: model_state_dict
        :raises FileNotFoundError: if ``model_path`` is not a file
        """
        if not os.path.isfile(model_path):
            # raise, don't assert: asserts are stripped under ``python -O``
            raise FileNotFoundError(
                "file not exists, model path: " + model_path)
        print("|===>Load retrain model from:", model_path)
        model_state_dict = torch.load(model_path, map_location='cpu')
        return model_state_dict

    def load_checkpoint(self, checkpoint_path):
        """
        Load a checkpoint file produced by :meth:`save_checkpoint`.
        :params checkpoint_path: path to the checkpoint file
        :return: (model_state_dict, optimizer_state_dict, epoch)
        :raises FileNotFoundError: if ``checkpoint_path`` is not a file
        """
        if not os.path.isfile(checkpoint_path):
            # raise, don't assert: asserts are stripped under ``python -O``
            raise FileNotFoundError("file not exists: " + checkpoint_path)
        print("|===>Load resume check-point from:", checkpoint_path)
        self.check_point_params = torch.load(checkpoint_path)
        model_state_dict = self.check_point_params['model']
        optimizer_state_dict = self.check_point_params['optimizer']
        epoch = self.check_point_params['epoch']
        return model_state_dict, optimizer_state_dict, epoch

    def save_checkpoint(self, model, optimizer, epoch, index=0):
        """
        Save model + optimizer state and the epoch to
        ``checkpoint_%03d.pth``.
        :params model: model (or list of modules; see list2sequential)
        :params optimizer: optimizer
        :params epoch: training epoch
        :params index: index of saved file, default: 0
        Note: if we add hook to the grad by using register_hook(hook), then
        the hook function can not be saved, so we save state_dict() only.
        Saving the whole model would also bake in GPU/device information,
        which causes issues when the model is used on a different machine.
        """
        # get state_dict from model and optimizer; unwrap DataParallel so
        # keys are not prefixed with "module."
        model = self.list2sequential(model)
        if isinstance(model, nn.DataParallel):
            model = model.module
        model = model.state_dict()
        optimizer = optimizer.state_dict()

        # save information to a dict
        self.check_point_params['model'] = model
        self.check_point_params['optimizer'] = optimizer
        self.check_point_params['epoch'] = epoch

        # save to file
        torch.save(self.check_point_params, os.path.join(
            self.save_path, "checkpoint_%03d.pth" % index))

    def list2sequential(self, model):
        """Wrap a list of modules into an ``nn.Sequential``; pass anything
        else through unchanged."""
        if isinstance(model, list):
            model = nn.Sequential(*model)
        return model

    def save_model(self, model, best_flag=False, index=0, tag=""):
        """
        Save a model state_dict to ``[tag_]best_model.pth`` (best) or
        ``[tag_]model_%03d.pth`` (regular).
        :params model: model to save (or list of modules)
        :params best_flag: if True, the saved model is the one that gets
                           best performance
        :params index: index of saved file when not best, default: 0
        :params tag: optional filename prefix
        """
        # get state dict; unwrap DataParallel so keys are not prefixed
        model = self.list2sequential(model)
        if isinstance(model, nn.DataParallel):
            model = model.module
        model = model.state_dict()
        if best_flag:
            if tag != "":
                torch.save(model, os.path.join(
                    self.save_path, "%s_best_model.pth" % tag))
            else:
                torch.save(model, os.path.join(
                    self.save_path, "best_model.pth"))
        else:
            if tag != "":
                torch.save(model, os.path.join(
                    self.save_path, "%s_model_%03d.pth" % (tag, index)))
            else:
                torch.save(model, os.path.join(
                    self.save_path, "model_%03d.pth" % index))