Skip to content

Commit 3023e2e

Browse files
author
xiezheng
committed
init
0 parents  commit 3023e2e

File tree

167 files changed

+25910
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

167 files changed

+25910
-0
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
.idea/
2+
results/
3+
*.pyc
4+
.vscode/

annotations/landmark_imagelist.txt

Lines changed: 10000 additions & 0 deletions
Large diffs are not rendered by default.

annotations/wider_origin_anno.txt

Lines changed: 12880 additions & 0 deletions
Large diffs are not rendered by default.

checkpoint.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import os
2+
import tools.utils as utils
3+
4+
import torch
5+
import torch.nn as nn
6+
7+
class CheckPoint(object):
    """
    Save and restore model / optimizer state to files under
    ``<save_path>/check_point``.

    check_point_params keys: 'model', 'optimizer', 'epoch'.
    """

    def __init__(self, save_path):
        """
        :params save_path: base directory; checkpoints are written to
                           ``<save_path>/check_point`` (created if absent)
        """
        self.save_path = os.path.join(save_path, "check_point")
        self.check_point_params = {'model': None,
                                   'optimizer': None,
                                   'epoch': None}

        # make directory
        if not os.path.isdir(self.save_path):
            os.makedirs(self.save_path)

    def load_state(self, model, state_dict):
        """
        Load a (possibly partial) state_dict into model.

        Keys present in ``state_dict`` but absent from the model are
        silently skipped, so checkpoints from a slightly different
        architecture still load.

        :params model: target module
        :params state_dict: source parameters
        :return: model with the matching parameters loaded
        """
        # Switch to eval mode before touching the state, otherwise the
        # restored accuracy can change (e.g. BatchNorm running stats).
        model.eval()
        model_dict = model.state_dict()

        # Copy only the keys the model actually has; ignore the rest.
        for key, value in state_dict.items():
            if key in model_dict:
                model_dict[key] = value
        model.load_state_dict(model_dict)

        return model

    def load_model(self, model_path):
        """
        Load a model state_dict from file.

        Always mapped to CPU so a checkpoint saved on any GPU layout can
        be loaded on any machine.

        :params model_path: path to the model file
        :return: model_state_dict
        """
        if os.path.isfile(model_path):
            print("|===>Load retrain model from:", model_path)
            model_state_dict = torch.load(model_path, map_location='cpu')
            return model_state_dict
        # Raise explicitly: a bare `assert False` disappears under `python -O`.
        raise AssertionError("file does not exist, model path: " + model_path)

    def load_checkpoint(self, checkpoint_path):
        """
        Load a checkpoint file written by :meth:`save_checkpoint`.

        :params checkpoint_path: path to the checkpoint file
        :return: model_state_dict, optimizer_state_dict, epoch
        """
        if os.path.isfile(checkpoint_path):
            print("|===>Load resume check-point from:", checkpoint_path)
            self.check_point_params = torch.load(checkpoint_path)
            model_state_dict = self.check_point_params['model']
            optimizer_state_dict = self.check_point_params['optimizer']
            epoch = self.check_point_params['epoch']
            return model_state_dict, optimizer_state_dict, epoch
        # Raise explicitly: a bare `assert False` disappears under `python -O`.
        raise AssertionError("file does not exist: " + checkpoint_path)

    def save_checkpoint(self, model, optimizer, epoch, index=0):
        """
        Save model / optimizer state_dicts plus the epoch counter to
        ``checkpoint_<index>.pth``.

        :params model: model (a list of modules is wrapped via list2sequential)
        :params optimizer: optimizer
        :params epoch: training epoch
        :params index: index of saved file, default: 0
        Note: if we add a hook to the grad by using register_hook(hook), the
        hook function cannot be pickled, so we save state_dict() only.
        Saving the whole model would also embed GPU/device information,
        which causes issues when loading on a different machine.
        """
        # get state_dict from model and optimizer
        model = self.list2sequential(model)
        if isinstance(model, nn.DataParallel):
            # unwrap DataParallel so the keys have no 'module.' prefix
            model = model.module
        model = model.state_dict()
        optimizer = optimizer.state_dict()

        # save information to a dict
        self.check_point_params['model'] = model
        self.check_point_params['optimizer'] = optimizer
        self.check_point_params['epoch'] = epoch

        # save to file
        torch.save(self.check_point_params, os.path.join(
            self.save_path, "checkpoint_%03d.pth" % index))

    def list2sequential(self, model):
        """Wrap a list of modules into nn.Sequential; pass anything else through."""
        if isinstance(model, list):
            model = nn.Sequential(*model)
        return model

    def save_model(self, model, best_flag=False, index=0, tag=""):
        """
        Save a model state_dict only (no optimizer/epoch).

        :params model: model to save (a list of modules is wrapped first)
        :params best_flag: if True, the saved model is the one that gets
                           best performance ("[<tag>_]best_model.pth");
                           otherwise "[<tag>_]model_<index>.pth"
        :params index: numeric suffix for non-best saves, default: 0
        :params tag: optional filename prefix
        """
        # get state dict
        model = self.list2sequential(model)
        if isinstance(model, nn.DataParallel):
            model = model.module
        model = model.state_dict()

        if best_flag:
            filename = "%s_best_model.pth" % tag if tag != "" else "best_model.pth"
        else:
            if tag != "":
                filename = "%s_model_%03d.pth" % (tag, index)
            else:
                filename = "model_%03d.pth" % index
        torch.save(model, os.path.join(self.save_path, filename))

config.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os

# Directory where trained model files are written.
MODEL_STORE_DIR = "./models"

# Directory holding annotation list files (see *_ANNO_FILENAME below).
ANNO_STORE_DIR = "./annotations"

# Root directory of the training data.
TRAIN_DATA_DIR = "./data"

# Log directory: "<grandparent of this file>/log".
LOG_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))+"/log"

# Run training on GPU when available.
USE_CUDA = True

TRAIN_BATCH_SIZE = 512

TRAIN_LR = 0.01

# Last training epoch (exclusive upper bound or final epoch — confirm in trainer).
END_EPOCH = 10


# Annotation filenames per network stage. The 12/24/48 suffixes presumably
# refer to the input crop size of each stage — TODO confirm against the
# data-generation scripts.
# NOTE(review): "POSTIVE" is a typo for "POSITIVE", but these names are
# imported elsewhere, so renaming would break callers.
PNET_POSTIVE_ANNO_FILENAME = "pos_12.txt"
PNET_NEGATIVE_ANNO_FILENAME = "neg_12.txt"
PNET_PART_ANNO_FILENAME = "part_12.txt"
PNET_LANDMARK_ANNO_FILENAME = "landmark_12.txt"


RNET_POSTIVE_ANNO_FILENAME = "pos_24.txt"
RNET_NEGATIVE_ANNO_FILENAME = "neg_24.txt"
RNET_PART_ANNO_FILENAME = "part_24.txt"
RNET_LANDMARK_ANNO_FILENAME = "landmark_24.txt"


ONET_POSTIVE_ANNO_FILENAME = "pos_48.txt"
ONET_NEGATIVE_ANNO_FILENAME = "neg_48.txt"
ONET_PART_ANNO_FILENAME = "part_48.txt"
ONET_LANDMARK_ANNO_FILENAME = "landmark_48.txt"

# Combined image-list files used for training each stage.
PNET_TRAIN_IMGLIST_FILENAME = "imglist_anno_12.txt"
RNET_TRAIN_IMGLIST_FILENAME = "imglist_anno_24.txt"
ONET_TRAIN_IMGLIST_FILENAME = "imglist_anno_48.txt"

data/test_images/img_12883.jpg

17.7 KB

data/test_images/img_12884.jpg

23.1 KB

data/test_images/img_12903.jpg

10.6 KB

data/test_images/img_12934.jpg

12 KB

data/test_images/img_12936.jpg

15.5 KB

0 commit comments

Comments
 (0)