
Commit 5c462c7

committed: update eval when train
1 parent 9057909 commit 5c462c7


6 files changed: +275 −37 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ Medical_Datasets/
 lfw/
 logs/
 model_data/
+.temp_map_out/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

get_map.py

Lines changed: 36 additions & 11 deletions
@@ -15,11 +15,11 @@
 
 if __name__ == "__main__":
     '''
-    Unlike AP, Recall and Precision are not area-based metrics, so the network's Recall and Precision differ at different threshold values.
-    The Recall and Precision in the mAP results correspond to a prediction confidence threshold of 0.5.
+    Unlike AP, Recall and Precision are not area-based metrics, so the network's Recall and Precision differ at different threshold (Confidence) values.
+    By default, the Recall and Precision computed by this script correspond to a threshold (Confidence) of 0.5.
 
-    The txt files obtained here in ./map_out/detection-results/ contain more boxes than a direct predict, because the threshold here is low,
-    the purpose being to compute Recall and Precision at different thresholds and thereby compute the mAP.
+    Due to how mAP is computed, the network must obtain nearly all predicted boxes so that Recall and Precision can be computed at different thresholds.
+    Therefore, the txt files this script writes to map_out/detection-results/ usually contain more boxes than a direct predict; the goal is to list all possible predicted boxes.
     '''
     #------------------------------------------------------------------------------------------------------------------#
     #   map_mode specifies what this script computes when it runs
@@ -30,16 +30,41 @@
     #   map_mode = 4 uses the COCO toolbox to compute the dataset's 0.50:0.95 mAP. This requires the prediction results, the ground-truth boxes, and an installed pycocotools.
     #-------------------------------------------------------------------------------------------------------------------#
     map_mode = 0
-    #-------------------------------------------------------#
+    #--------------------------------------------------------------------------------------#
     #   classes_path specifies the classes for which the VOC mAP is measured
     #   It should normally match the classes_path used for training and prediction
-    #-------------------------------------------------------#
+    #--------------------------------------------------------------------------------------#
     classes_path = 'model_data/voc_classes.txt'
-    #-------------------------------------------------------#
-    #   MINOVERLAP specifies the desired mAP0.x
+    #--------------------------------------------------------------------------------------#
+    #   MINOVERLAP specifies the desired mAP0.x; look up what mAP0.x means if it is unfamiliar.
     #   For example, to compute mAP0.75, set MINOVERLAP = 0.75.
-    #-------------------------------------------------------#
+    #
+    #   A predicted box whose overlap with a ground-truth box exceeds MINOVERLAP counts as a positive sample, otherwise as a negative sample.
+    #   The larger MINOVERLAP is, the more accurate a predicted box must be to count as a positive sample, and the lower the resulting mAP.
+    #--------------------------------------------------------------------------------------#
    MINOVERLAP = 0.5
+    #--------------------------------------------------------------------------------------#
+    #   Due to how mAP is computed, the network must obtain nearly all predicted boxes in order to compute mAP.
+    #   Therefore, confidence should be set as small as possible to obtain all possible predicted boxes.
+    #
+    #   This value is generally not adjusted. Because computing mAP requires nearly all predicted boxes, the confidence here must not be changed casually.
+    #   To obtain Recall and Precision at a different threshold, modify score_threhold below.
+    #--------------------------------------------------------------------------------------#
+    confidence = 0.001
+    #--------------------------------------------------------------------------------------#
+    #   The non-maximum suppression IoU used during prediction; larger means less strict NMS.
+    #
+    #   This value is generally not adjusted.
+    #--------------------------------------------------------------------------------------#
+    nms_iou = 0.5
+    #---------------------------------------------------------------------------------------------------------------#
+    #   Unlike AP, Recall and Precision are not area-based metrics, so the network's Recall and Precision differ at different threshold values.
+    #
+    #   By default, the Recall and Precision computed by this script correspond to a threshold of 0.5 (defined here as score_threhold).
+    #   Because computing mAP requires nearly all predicted boxes, the confidence defined above must not be changed casually.
+    #   A separate score_threhold is defined here to represent the threshold, so that the Recall and Precision at that threshold can be reported during the mAP calculation.
+    #---------------------------------------------------------------------------------------------------------------#
+    score_threhold = 0.5
     #-------------------------------------------------------#
     #   map_vis specifies whether to enable visualization of the VOC mAP calculation
     #-------------------------------------------------------#
@@ -69,7 +94,7 @@
 
     if map_mode == 0 or map_mode == 1:
         print("Load model.")
-        yolo = YOLO(confidence = 0.001, nms_iou = 0.5)
+        yolo = YOLO(confidence = confidence, nms_iou = nms_iou)
         print("Load model done.")
 
         print("Get predict result.")
@@ -109,7 +134,7 @@
 
     if map_mode == 0 or map_mode == 3:
         print("Get map.")
-        get_map(MINOVERLAP, True, path = map_out_path)
+        get_map(MINOVERLAP, True, score_threhold = score_threhold, path = map_out_path)
         print("Get map done.")
 
     if map_mode == 4:
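
The comments above split two thresholds: confidence is kept very low (0.001) so that nearly every predicted box reaches the mAP calculation, while score_threhold only picks the single operating point whose Recall and Precision are reported. A small self-contained sketch of that distinction, with made-up scores and ground-truth matches rather than anything from this repository:

import numpy as np

def precision_recall_at(scores, matched, num_gt, threshold):
    # Keep only detections whose score clears the threshold (the score_threhold role).
    keep = scores >= threshold
    tp = int(np.sum(matched[keep]))
    fp = int(np.sum(keep)) - tp
    precision = tp / max(tp + fp, 1)
    recall = tp / max(num_gt, 1)
    return precision, recall

# Detections collected with a very low confidence (e.g. 0.001), so the whole
# score range survives and any reporting threshold can be evaluated afterwards.
scores = np.array([0.95, 0.80, 0.60, 0.40, 0.10, 0.02])
matched = np.array([1, 1, 0, 1, 0, 0])   # 1 = overlaps a ground-truth box by more than MINOVERLAP
num_gt = 4

print(precision_recall_at(scores, matched, num_gt, 0.5))    # reported point, like score_threhold = 0.5
print(precision_recall_at(scores, matched, num_gt, 0.05))   # a looser threshold gives different values

AP itself integrates precision over recall across all such thresholds, which is why the low confidence matters: with a high confidence the tail of the curve would simply be missing.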

train.py

Lines changed: 19 additions & 4 deletions
@@ -10,7 +10,7 @@
 
 from nets.yolo import get_train_model, yolo_body
 from nets.yolo_training import get_lr_scheduler
-from utils.callbacks import LossHistory, ModelCheckpoint
+from utils.callbacks import EvalCallback, LossHistory, ModelCheckpoint
 from utils.dataloader import YoloDatasets
 from utils.utils import get_anchors, get_classes, show_config
 from utils.utils_fit import fit_one_epoch
@@ -204,6 +204,17 @@
     #------------------------------------------------------------------#
     save_dir = 'logs'
     #------------------------------------------------------------------#
+    #   eval_flag       whether to evaluate during training; the evaluation target is the validation set
+    #                   the evaluation experience is better with pycocotools installed.
+    #   eval_period     how many epochs between evaluations; frequent evaluation is not recommended
+    #                   evaluation takes a fair amount of time, and evaluating frequently makes training very slow
+    #   The mAP obtained here will differ from the one obtained by get_map.py, for two reasons:
+    #   (1) The mAP obtained here is the mAP of the validation set.
+    #   (2) The evaluation parameters here are set conservatively in order to speed evaluation up.
+    #------------------------------------------------------------------#
+    eval_flag = True
+    eval_period = 10
+    #------------------------------------------------------------------#
     #   num_workers     whether to use multi-threaded data loading; 1 disables multi-threading
     #                   enabling it speeds up data loading but uses more memory
     #                   in keras, enabling multi-threading is sometimes actually much slower
@@ -385,6 +396,8 @@
         time_str = datetime.datetime.strftime(datetime.datetime.now(),'%Y_%m_%d_%H_%M_%S')
         log_dir = os.path.join(save_dir, "loss_" + str(time_str))
         loss_history = LossHistory(log_dir)
+        eval_callback = EvalCallback(model_body, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, \
+                                        eval_flag=eval_flag, period=eval_period)
         #---------------------------------------#
         #   Start model training
         #---------------------------------------#
@@ -436,7 +449,7 @@
             lr = lr_scheduler_func(epoch)
             K.set_value(optimizer.lr, lr)
 
-            fit_one_epoch(model_body, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val,
+            fit_one_epoch(model_body, loss_history, eval_callback, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val,
                         end_epoch, input_shape, anchors, anchors_mask, num_classes, label_smoothing, save_period, save_dir, strategy)
 
             train_dataloader.on_epoch_end()
@@ -469,7 +482,9 @@
                     monitor = 'val_loss', save_weights_only = True, save_best_only = True, period = 1)
         early_stopping = EarlyStopping(monitor='val_loss', min_delta = 0, patience = 10, verbose = 1)
         lr_scheduler = LearningRateScheduler(lr_scheduler_func, verbose = 1)
-        callbacks = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler]
+        eval_callback = EvalCallback(model_body, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, \
+                                        eval_flag=eval_flag, period=eval_period)
+        callbacks = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler, eval_callback]
 
         if start_epoch < end_epoch:
             print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))
@@ -506,7 +521,7 @@
             #---------------------------------------#
             lr_scheduler_func = get_lr_scheduler(lr_decay_type, Init_lr_fit, Min_lr_fit, UnFreeze_Epoch)
             lr_scheduler = LearningRateScheduler(lr_scheduler_func, verbose = 1)
-            callbacks = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler]
+            callbacks = [logging, loss_history, checkpoint, checkpoint_last, checkpoint_best, lr_scheduler, eval_callback]
 
             for i in range(len(model_body.layers)):
                 model_body.layers[i].trainable = True
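
eval_period only gates how often EvalCallback does its work in on_epoch_end; the rest of training is untouched, which is why infrequent evaluation keeps training fast. A minimal runnable sketch of that gating pattern, using a hypothetical PeriodicEval callback and a toy model rather than the repository's EvalCallback:

import numpy as np
from tensorflow import keras

class PeriodicEval(keras.callbacks.Callback):
    # Hypothetical stand-in for EvalCallback: only every `period` epochs
    # does the (expensive) evaluation actually run.
    def __init__(self, period=10, eval_flag=True):
        super().__init__()
        self.period = period
        self.eval_flag = eval_flag

    def on_epoch_end(self, epoch, logs=None):
        temp_epoch = epoch + 1
        if self.eval_flag and temp_epoch % self.period == 0:
            print("\n[eval] would compute the validation mAP at epoch %d" % temp_epoch)

# Toy data and model, just to show the callback firing on epochs 2 and 4.
x = np.random.rand(32, 4).astype('float32')
y = np.random.rand(32, 1).astype('float32')
model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')
model.fit(x, y, epochs=4, verbose=0, callbacks=[PeriodicEval(period=2)])

In train.py the same idea is driven by eval_flag and eval_period, and the callback is either passed to fit_one_epoch (eager path) or appended to the callbacks list handed to model.fit.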

utils/callbacks.py

Lines changed: 176 additions & 2 deletions
@@ -1,14 +1,25 @@
-import math
 import os
+import math
 import warnings
 
 import matplotlib
 matplotlib.use('Agg')
 from matplotlib import pyplot as plt
-import numpy as np
 import scipy.signal
+
+import shutil
+import numpy as np
+import tensorflow as tf
+
 from tensorflow import keras
 from tensorflow.keras import backend as K
+from tensorflow.keras.layers import Input, Lambda
+from tensorflow.keras.models import Model
+from PIL import Image
+from tqdm import tqdm
+from .utils import cvtColor, preprocess_input, resize_image
+from .utils_bbox import DecodeBox
+from .utils_map import get_coco_map, get_map
 
 
 class LossHistory(keras.callbacks.Callback):
@@ -97,6 +108,169 @@ def on_epoch_end(self, batch, logs=None):
         if self.verbose > 0:
             print('Setting learning rate to %s.' % (learning_rate))
 
+class EvalCallback(keras.callbacks.Callback):
+    def __init__(self, model_body, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir,\
+            map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
+        super(EvalCallback, self).__init__()
+
+        self.model_body = model_body
+        self.input_shape = input_shape
+        self.anchors = anchors
+        self.anchors_mask = anchors_mask
+        self.class_names = class_names
+        self.num_classes = num_classes
+        self.val_lines = val_lines
+        self.log_dir = log_dir
+        self.map_out_path = map_out_path
+        self.max_boxes = max_boxes
+        self.confidence = confidence
+        self.nms_iou = nms_iou
+        self.letterbox_image = letterbox_image
+        self.MINOVERLAP = MINOVERLAP
+        self.eval_flag = eval_flag
+        self.period = period
+
+        #---------------------------------------------------------#
+        #   In the DecodeBox function the predictions are post-processed.
+        #   Post-processing includes decoding, non-maximum suppression, threshold filtering, etc.
+        #---------------------------------------------------------#
+        self.input_image_shape = Input([2,],batch_size=1)
+        inputs = [*self.model_body.output, self.input_image_shape]
+        outputs = Lambda(
+            DecodeBox,
+            output_shape = (1,),
+            name = 'yolo_eval',
+            arguments = {
+                'anchors' : self.anchors,
+                'num_classes' : self.num_classes,
+                'input_shape' : self.input_shape,
+                'anchor_mask' : self.anchors_mask,
+                'confidence' : self.confidence,
+                'nms_iou' : self.nms_iou,
+                'max_boxes' : self.max_boxes,
+                'letterbox_image' : self.letterbox_image
+            }
+        )(inputs)
+        self.yolo_model = Model([self.model_body.input, self.input_image_shape], outputs)
+
+        self.maps = [0]
+        self.epoches = [0]
+        if self.eval_flag:
+            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
+                f.write(str(0))
+                f.write("\n")
+
+    @tf.function
+    def get_pred(self, image_data, input_image_shape):
+        out_boxes, out_scores, out_classes = self.yolo_model([image_data, input_image_shape], training=False)
+        return out_boxes, out_scores, out_classes
+
+    def get_map_txt(self, image_id, image, class_names, map_out_path):
+        f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
+        #---------------------------------------------------------#
+        #   Convert the image to RGB here, to prevent grayscale images from raising errors during prediction.
+        #---------------------------------------------------------#
+        image = cvtColor(image)
+        #---------------------------------------------------------#
+        #   Add gray bars to the image for a distortion-free resize.
+        #   A plain resize can also be used for detection.
+        #---------------------------------------------------------#
+        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
+        #---------------------------------------------------------#
+        #   Add the batch_size dimension and normalize.
+        #---------------------------------------------------------#
+        image_data = np.expand_dims(preprocess_input(np.array(image_data, dtype='float32')), 0)
+
+        #---------------------------------------------------------#
+        #   Feed the image into the network for prediction!
+        #---------------------------------------------------------#
+        input_image_shape = np.expand_dims(np.array([image.size[1], image.size[0]], dtype='float32'), 0)
+        outputs = self.get_pred(image_data, input_image_shape)
+        out_boxes, out_scores, out_classes = [out.numpy() for out in outputs]
+
+        top_100 = np.argsort(out_scores)[::-1][:self.max_boxes]
+        out_boxes = out_boxes[top_100]
+        out_scores = out_scores[top_100]
+        out_classes = out_classes[top_100]
+
+        for i, c in enumerate(out_classes):
+            predicted_class = self.class_names[int(c)]
+            try:
+                score = str(out_scores[i].numpy())
+            except:
+                score = str(out_scores[i])
+            top, left, bottom, right = out_boxes[i]
+            if predicted_class not in class_names:
+                continue
+
+            f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
+
+        f.close()
+        return
+
+    def on_epoch_end(self, epoch, logs=None):
+        temp_epoch = epoch + 1
+        if temp_epoch % self.period == 0 and self.eval_flag:
+            if not os.path.exists(self.map_out_path):
+                os.makedirs(self.map_out_path)
+            if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")):
+                os.makedirs(os.path.join(self.map_out_path, "ground-truth"))
+            if not os.path.exists(os.path.join(self.map_out_path, "detection-results")):
+                os.makedirs(os.path.join(self.map_out_path, "detection-results"))
+            print("Get map.")
+            for annotation_line in tqdm(self.val_lines):
+                line = annotation_line.split()
+                image_id = os.path.basename(line[0]).split('.')[0]
+                #------------------------------#
+                #   Read the image and convert it to RGB
+                #------------------------------#
+                image = Image.open(line[0])
+                #------------------------------#
+                #   Obtain the ground-truth boxes
+                #------------------------------#
+                gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
+                #------------------------------#
+                #   Obtain the prediction txt
+                #------------------------------#
+                self.get_map_txt(image_id, image, self.class_names, self.map_out_path)
+
+                #------------------------------#
+                #   Obtain the ground-truth txt
+                #------------------------------#
+                with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
+                    for box in gt_boxes:
+                        left, top, right, bottom, obj = box
+                        obj_name = self.class_names[obj]
+                        new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
+
+            print("Calculate Map.")
+            try:
+                temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
+            except:
+                temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
+            self.maps.append(temp_map)
+            self.epoches.append(temp_epoch)
+
+            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
+                f.write(str(temp_map))
+                f.write("\n")
+
+            plt.figure()
+            plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')
+
+            plt.grid(True)
+            plt.xlabel('Epoch')
+            plt.ylabel('Map %s'%str(self.MINOVERLAP))
+            plt.title('A Map Curve')
+            plt.legend(loc="upper right")
+
+            plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
+            plt.cla()
+            plt.close("all")
+
+            print("Get map done.")
+            shutil.rmtree(self.map_out_path)
+
 class ModelCheckpoint(keras.callbacks.Callback):
     def __init__(self, filepath, monitor='val_loss', verbose=0,
                  save_best_only=False, save_weights_only=False,
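
Each evaluation appends one mAP value to epoch_map.txt in the log folder (plus the initial 0 written when the callback is constructed), and epoch_map.png is redrawn from the accumulated lists. To re-plot the curve after training, a short standalone script along the lines below should be enough; the log folder name is a placeholder and eval_period is assumed to match the value used during training:

import os
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

log_dir = 'logs/loss_2022_01_01_00_00_00'   # placeholder: use the actual loss_<timestamp> folder
eval_period = 10                            # assumed to match eval_period in train.py

# epoch_map.txt holds one value per line: the initial 0, then one mAP per evaluation.
with open(os.path.join(log_dir, 'epoch_map.txt')) as f:
    maps = [float(line) for line in f if line.strip()]
epochs = [0] + [eval_period * i for i in range(1, len(maps))]

plt.plot(epochs, maps, color='red', linewidth=2, label='val mAP')
plt.grid(True)
plt.xlabel('Epoch')
plt.ylabel('mAP')
plt.legend(loc='upper right')
plt.savefig(os.path.join(log_dir, 'epoch_map_replot.png'))

Note that EvalCallback deletes its temporary .temp_map_out folder after every evaluation (hence the .gitignore entry above), so only the txt/png summaries in the log folder persist.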
