|
1 |
| -import math |
2 | 1 | import os
|
| 2 | +import math |
3 | 3 | import warnings
|
4 | 4 |
|
5 | 5 | import matplotlib
|
6 | 6 | matplotlib.use('Agg')
|
7 | 7 | from matplotlib import pyplot as plt
|
8 |
| -import numpy as np |
9 | 8 | import scipy.signal
|
| 9 | + |
| 10 | +import shutil |
| 11 | +import numpy as np |
| 12 | +import tensorflow as tf |
| 13 | + |
10 | 14 | from tensorflow import keras
|
11 | 15 | from tensorflow.keras import backend as K
|
| 16 | +from tensorflow.keras.layers import Input, Lambda |
| 17 | +from tensorflow.keras.models import Model |
| 18 | +from PIL import Image |
| 19 | +from tqdm import tqdm |
| 20 | +from .utils import cvtColor, preprocess_input, resize_image |
| 21 | +from .utils_bbox import DecodeBox |
| 22 | +from .utils_map import get_coco_map, get_map |
12 | 23 |
|
13 | 24 |
|
14 | 25 | class LossHistory(keras.callbacks.Callback):
|
@@ -97,6 +108,169 @@ def on_epoch_end(self, batch, logs=None):
|
97 | 108 | if self.verbose > 0:
|
98 | 109 | print('Setting learning rate to %s.' % (learning_rate))
|
99 | 110 |
|
| 111 | +class EvalCallback(keras.callbacks.Callback): |
| 112 | + def __init__(self, model_body, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir,\ |
| 113 | + map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1): |
| 114 | + super(EvalCallback, self).__init__() |
| 115 | + |
| 116 | + self.model_body = model_body |
| 117 | + self.input_shape = input_shape |
| 118 | + self.anchors = anchors |
| 119 | + self.anchors_mask = anchors_mask |
| 120 | + self.class_names = class_names |
| 121 | + self.num_classes = num_classes |
| 122 | + self.val_lines = val_lines |
| 123 | + self.log_dir = log_dir |
| 124 | + self.map_out_path = map_out_path |
| 125 | + self.max_boxes = max_boxes |
| 126 | + self.confidence = confidence |
| 127 | + self.nms_iou = nms_iou |
| 128 | + self.letterbox_image = letterbox_image |
| 129 | + self.MINOVERLAP = MINOVERLAP |
| 130 | + self.eval_flag = eval_flag |
| 131 | + self.period = period |
| 132 | + |
| 133 | + #---------------------------------------------------------# |
| 134 | + # 在DecodeBox函数中,我们会对预测结果进行后处理 |
| 135 | + # 后处理的内容包括,解码、非极大抑制、门限筛选等 |
| 136 | + #---------------------------------------------------------# |
| 137 | + self.input_image_shape = Input([2,],batch_size=1) |
| 138 | + inputs = [*self.model_body.output, self.input_image_shape] |
| 139 | + outputs = Lambda( |
| 140 | + DecodeBox, |
| 141 | + output_shape = (1,), |
| 142 | + name = 'yolo_eval', |
| 143 | + arguments = { |
| 144 | + 'anchors' : self.anchors, |
| 145 | + 'num_classes' : self.num_classes, |
| 146 | + 'input_shape' : self.input_shape, |
| 147 | + 'anchor_mask' : self.anchors_mask, |
| 148 | + 'confidence' : self.confidence, |
| 149 | + 'nms_iou' : self.nms_iou, |
| 150 | + 'max_boxes' : self.max_boxes, |
| 151 | + 'letterbox_image' : self.letterbox_image |
| 152 | + } |
| 153 | + )(inputs) |
| 154 | + self.yolo_model = Model([self.model_body.input, self.input_image_shape], outputs) |
| 155 | + |
| 156 | + self.maps = [0] |
| 157 | + self.epoches = [0] |
| 158 | + if self.eval_flag: |
| 159 | + with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: |
| 160 | + f.write(str(0)) |
| 161 | + f.write("\n") |
| 162 | + |
| 163 | + @tf.function |
| 164 | + def get_pred(self, image_data, input_image_shape): |
| 165 | + out_boxes, out_scores, out_classes = self.yolo_model([image_data, input_image_shape], training=False) |
| 166 | + return out_boxes, out_scores, out_classes |
| 167 | + |
| 168 | + def get_map_txt(self, image_id, image, class_names, map_out_path): |
| 169 | + f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w") |
| 170 | + #---------------------------------------------------------# |
| 171 | + # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 |
| 172 | + #---------------------------------------------------------# |
| 173 | + image = cvtColor(image) |
| 174 | + #---------------------------------------------------------# |
| 175 | + # 给图像增加灰条,实现不失真的resize |
| 176 | + # 也可以直接resize进行识别 |
| 177 | + #---------------------------------------------------------# |
| 178 | + image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image) |
| 179 | + #---------------------------------------------------------# |
| 180 | + # 添加上batch_size维度,并进行归一化 |
| 181 | + #---------------------------------------------------------# |
| 182 | + image_data = np.expand_dims(preprocess_input(np.array(image_data, dtype='float32')), 0) |
| 183 | + |
| 184 | + #---------------------------------------------------------# |
| 185 | + # 将图像输入网络当中进行预测! |
| 186 | + #---------------------------------------------------------# |
| 187 | + input_image_shape = np.expand_dims(np.array([image.size[1], image.size[0]], dtype='float32'), 0) |
| 188 | + outputs = self.get_pred(image_data, input_image_shape) |
| 189 | + out_boxes, out_scores, out_classes = [out.numpy() for out in outputs] |
| 190 | + |
| 191 | + top_100 = np.argsort(out_scores)[::-1][:self.max_boxes] |
| 192 | + out_boxes = out_boxes[top_100] |
| 193 | + out_scores = out_scores[top_100] |
| 194 | + out_classes = out_classes[top_100] |
| 195 | + |
| 196 | + for i, c in enumerate(out_classes): |
| 197 | + predicted_class = self.class_names[int(c)] |
| 198 | + try: |
| 199 | + score = str(out_scores[i].numpy()) |
| 200 | + except: |
| 201 | + score = str(out_scores[i]) |
| 202 | + top, left, bottom, right = out_boxes[i] |
| 203 | + if predicted_class not in class_names: |
| 204 | + continue |
| 205 | + |
| 206 | + f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom)))) |
| 207 | + |
| 208 | + f.close() |
| 209 | + return |
| 210 | + |
| 211 | + def on_epoch_end(self, epoch, logs=None): |
| 212 | + temp_epoch = epoch + 1 |
| 213 | + if temp_epoch % self.period == 0 and self.eval_flag: |
| 214 | + if not os.path.exists(self.map_out_path): |
| 215 | + os.makedirs(self.map_out_path) |
| 216 | + if not os.path.exists(os.path.join(self.map_out_path, "ground-truth")): |
| 217 | + os.makedirs(os.path.join(self.map_out_path, "ground-truth")) |
| 218 | + if not os.path.exists(os.path.join(self.map_out_path, "detection-results")): |
| 219 | + os.makedirs(os.path.join(self.map_out_path, "detection-results")) |
| 220 | + print("Get map.") |
| 221 | + for annotation_line in tqdm(self.val_lines): |
| 222 | + line = annotation_line.split() |
| 223 | + image_id = os.path.basename(line[0]).split('.')[0] |
| 224 | + #------------------------------# |
| 225 | + # 读取图像并转换成RGB图像 |
| 226 | + #------------------------------# |
| 227 | + image = Image.open(line[0]) |
| 228 | + #------------------------------# |
| 229 | + # 获得预测框 |
| 230 | + #------------------------------# |
| 231 | + gt_boxes = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) |
| 232 | + #------------------------------# |
| 233 | + # 获得预测txt |
| 234 | + #------------------------------# |
| 235 | + self.get_map_txt(image_id, image, self.class_names, self.map_out_path) |
| 236 | + |
| 237 | + #------------------------------# |
| 238 | + # 获得真实框txt |
| 239 | + #------------------------------# |
| 240 | + with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f: |
| 241 | + for box in gt_boxes: |
| 242 | + left, top, right, bottom, obj = box |
| 243 | + obj_name = self.class_names[obj] |
| 244 | + new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom)) |
| 245 | + |
| 246 | + print("Calculate Map.") |
| 247 | + try: |
| 248 | + temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1] |
| 249 | + except: |
| 250 | + temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path) |
| 251 | + self.maps.append(temp_map) |
| 252 | + self.epoches.append(temp_epoch) |
| 253 | + |
| 254 | + with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f: |
| 255 | + f.write(str(temp_map)) |
| 256 | + f.write("\n") |
| 257 | + |
| 258 | + plt.figure() |
| 259 | + plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map') |
| 260 | + |
| 261 | + plt.grid(True) |
| 262 | + plt.xlabel('Epoch') |
| 263 | + plt.ylabel('Map %s'%str(self.MINOVERLAP)) |
| 264 | + plt.title('A Map Curve') |
| 265 | + plt.legend(loc="upper right") |
| 266 | + |
| 267 | + plt.savefig(os.path.join(self.log_dir, "epoch_map.png")) |
| 268 | + plt.cla() |
| 269 | + plt.close("all") |
| 270 | + |
| 271 | + print("Get map done.") |
| 272 | + shutil.rmtree(self.map_out_path) |
| 273 | + |
100 | 274 | class ModelCheckpoint(keras.callbacks.Callback):
|
101 | 275 | def __init__(self, filepath, monitor='val_loss', verbose=0,
|
102 | 276 | save_best_only=False, save_weights_only=False,
|
|
0 commit comments