Skip to content

Commit 78a8e50

Browse files
committed
添加train.py
1 parent dffaa3a commit 78a8e50

17 files changed

+6364
-0
lines changed

10. ssd/VOCdevkit/VOC2007/voc2ssd.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""
function: Randomly split the VOC annotation files into train.txt and test.txt
"""
import os
import random

# Directory holding the per-image XML annotation files.
xmlfilepath = 'Annotations'
# Directory the generated split lists are written to.
saveBasePath = "ImageSets"

# Fraction of the annotated images assigned to the training split.
train_percent = 0.9

# Only directory entries with this (case-sensitive) suffix are annotations.
_XML_SUFFIX = ".xml"


def split_dataset(xml_dir, save_dir, train_ratio):
    """Randomly split the annotated images into a train list and a test list.

    Every ``*.xml`` entry of ``xml_dir`` contributes its name without the
    suffix. ``int(count * train_ratio)`` of them are drawn at random (using the
    global ``random`` state, so ``random.seed`` makes the split repeatable) and
    written to ``train.txt``; the remainder goes to ``test.txt``. Both files
    are created inside ``save_dir`` with one name per line.

    Args:
        xml_dir: Directory that is scanned for ``.xml`` files.
        save_dir: Directory receiving ``train.txt`` and ``test.txt``; it is
            created when missing.
        train_ratio: Fraction in ``[0, 1]`` of the images used for training.

    Returns:
        A ``(train_names, test_names)`` tuple of lists of image names.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= train_ratio <= 1:
        raise ValueError("train_ratio must be within [0, 1], got %r" % (train_ratio,))

    # Sorted so that, for a fixed random seed, the split does not depend on the
    # arbitrary order in which os.listdir returns the entries.
    names = sorted(
        entry[:-len(_XML_SUFFIX)]
        for entry in os.listdir(xml_dir)
        if entry.endswith(_XML_SUFFIX)
    )

    train_count = int(len(names) * train_ratio)
    # A set gives O(1) membership tests while partitioning below.
    train_indices = set(random.sample(range(len(names)), train_count))

    train_names = []
    test_names = []
    for index, name in enumerate(names):
        if index in train_indices:
            train_names.append(name)
        else:
            test_names.append(name)

    os.makedirs(save_dir, exist_ok=True)
    # Context managers close both files even if a write fails part-way.
    with open(os.path.join(save_dir, 'train.txt'), 'w') as ftrain:
        ftrain.writelines(name + '\n' for name in train_names)
    with open(os.path.join(save_dir, 'test.txt'), 'w') as ftest:
        ftest.writelines(name + '\n' for name in test_names)

    return train_names, test_names


if __name__ == "__main__":
    split_dataset(xmlfilepath, saveBasePath, train_percent)

10. ssd/detect.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import os
2+
import colorsys
3+
import numpy as np
4+
import tensorflow as tf
5+
from networks import ssd
6+
from PIL import ImageFont, ImageDraw
7+
from tensorflow.keras.applications.imagenet_utils import preprocess_input
8+
from networks.utils import BBoxUtility, letterbox_image, ssd_correct_boxes
9+
10+
class SSD(object):
    """Run a trained SSD300 model on single images and draw the detections.

    Settings start from ``_defaults`` and can be overridden per instance with
    keyword arguments, e.g. ``SSD(confidence=0.3)``.
    """

    # To predict with a model you trained yourself, change BOTH "model_path"
    # and "classes_path".
    _defaults = {
        "model_path": 'weights/ep066-loss3.277-val_loss3.753.h5',
        "classes_path": 'files/voc_classes.txt',
        "model_image_size": (300, 300, 3),
        "confidence": 0.5,
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value of setting ``n``.

        For an unknown name an error-message string is returned instead of
        raising, so callers must not treat every string result as a valid value.
        """
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        """Load the class names and the model.

        Args:
            **kwargs: Overrides for any key of ``_defaults`` (``model_path``,
                ``classes_path``, ``model_image_size``, ``confidence``).
        """
        self.__dict__.update(self._defaults)
        # Caller-supplied settings take precedence over the class defaults.
        self.__dict__.update(kwargs)
        self.class_names = self._get_class()
        self.generate()
        self.bbox_util = BBoxUtility(self.num_classes)

    def _get_class(self):
        """Read the class names, one per line, from ``self.classes_path``."""
        classes_path = os.path.expanduser(self.classes_path)
        with open(classes_path) as f:
            return [line.strip() for line in f]

    def generate(self):
        """Build the SSD300 network, load its weights and pick one colour per class."""
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        # One more than the listed classes -- presumably the background class;
        # confirm against networks.ssd.
        self.num_classes = len(self.class_names) + 1

        # Build the architecture first, then load the weights by layer name.
        self.ssd_model = ssd.SSD300(self.model_image_size, self.num_classes)
        # Use the user-expanded path so "~/..." settings resolve correctly.
        self.ssd_model.load_weights(model_path, by_name=True)

        self.ssd_model.summary()
        print('{} model, anchors, and classes loaded.'.format(model_path))

        # Spread the hues evenly so every class gets a distinct box colour.
        class_count = len(self.class_names)
        self.colors = [
            tuple(int(channel * 255)
                  for channel in colorsys.hsv_to_rgb(index / class_count, 1., 1.))
            for index in range(class_count)
        ]

    @tf.function
    def get_pred(self, photo):
        """Run the network in inference mode on a preprocessed image batch."""
        preds = self.ssd_model(photo, training=False)
        return preds

    @staticmethod
    def _label_size(draw, label, font):
        """Return ``(width, height)`` of ``label`` as rendered with ``font``."""
        # ImageDraw.textsize is absent from current Pillow releases; prefer
        # textbbox when it exists and fall back to textsize on old versions.
        if hasattr(draw, "textbbox"):
            _, _, right, bottom = draw.textbbox((0, 0), label, font=font)
            return int(np.ceil(right)), int(np.ceil(bottom))
        return tuple(draw.textsize(label, font))

    def detect_image(self, image):
        """Detect objects in ``image`` and draw labelled boxes onto it.

        Args:
            image: The picture to annotate. It is passed to
                ``ImageDraw.Draw``, so it is assumed to be a PIL image.

        Returns:
            The same ``image`` object; boxes and labels are drawn in place.
            It is returned untouched when nothing reaches ``self.confidence``.
        """
        image_shape = np.array(np.shape(image)[0:2])
        image_h, image_w = int(image_shape[0]), int(image_shape[1])
        input_h, input_w = self.model_image_size[0], self.model_image_size[1]

        crop_img, _, _ = letterbox_image(image, (input_h, input_w))
        photo = np.array(crop_img, dtype=np.float64)

        # Preprocess the image into the batch form the network expects.
        photo = preprocess_input(np.reshape(photo, [1, input_h, input_w, 3]))

        # Forward pass.
        preds = self.get_pred(photo).numpy()

        # Decode the raw predictions into detections.
        results = self.bbox_util.detection_out(preds, confidence_threshold=self.confidence)

        if len(results[0]) <= 0:
            return image

        # Columns are read as: label, confidence, xmin, ymin, xmax, ymax.
        detections = results[0]
        det_label = detections[:, 0]
        det_conf = detections[:, 1]
        det_xmin, det_ymin = detections[:, 2], detections[:, 3]
        det_xmax, det_ymax = detections[:, 4], detections[:, 5]

        # Keep only the boxes scoring at least the confidence threshold.
        top_indices = [i for i, conf in enumerate(det_conf) if conf >= self.confidence]
        if not top_indices:
            return image
        top_conf = det_conf[top_indices]
        top_label_indices = det_label[top_indices].tolist()
        top_xmin = np.expand_dims(det_xmin[top_indices], -1)
        top_ymin = np.expand_dims(det_ymin[top_indices], -1)
        top_xmax = np.expand_dims(det_xmax[top_indices], -1)
        top_ymax = np.expand_dims(det_ymax[top_indices], -1)

        # Remove the letterbox padding (grey bars) from the box coordinates.
        boxes = ssd_correct_boxes(top_ymin, top_xmin, top_ymax, top_xmax,
                                  np.array([input_h, input_w]), image_shape)

        # Scale font and outline with the image; never let either drop to 0,
        # which would make small images fail or show no boxes at all.
        font = ImageFont.truetype(font='files/simhei.ttf',
                                  size=max(1, int(np.floor(3e-2 * image_w + 0.5))))
        thickness = max(1, (image_h + image_w) // input_h)

        draw = ImageDraw.Draw(image)
        for box_index, c in enumerate(top_label_indices):
            # Labels are offset by one relative to class_names.
            class_index = int(c) - 1
            predicted_class = self.class_names[class_index]
            score = top_conf[box_index]
            color = self.colors[class_index]

            # Grow the box by 5 px on every side, round, and clamp to the image.
            top, left, bottom, right = boxes[box_index]
            top = max(0, int(np.floor(top - 5 + 0.5)))
            left = max(0, int(np.floor(left - 5 + 0.5)))
            bottom = min(image_h, int(np.floor(bottom + 5 + 0.5)))
            right = min(image_w, int(np.floor(right + 5 + 0.5)))

            label = '{} {:.2f}'.format(predicted_class, score)
            label_w, label_h = self._label_size(draw, label, font)
            print(label.encode('utf-8'))

            # Put the label above the box when it fits, otherwise just inside.
            text_x = left
            text_y = top - label_h if top - label_h >= 0 else top + 1

            # Nested 1 px rectangles emulate a thick outline.
            for inset in range(thickness):
                if right - inset < left + inset or bottom - inset < top + inset:
                    # The box has collapsed; an inverted rectangle is meaningless.
                    break
                draw.rectangle(
                    [left + inset, top + inset, right - inset, bottom - inset],
                    outline=color)
            draw.rectangle(
                [(text_x, text_y), (text_x + label_w, text_y + label_h)],
                fill=color)
            draw.text((text_x, text_y), label, fill=(0, 0, 0), font=font)
        del draw
        return image
157+

0 commit comments

Comments
 (0)