1
1
#! /usr/bin/python
2
2
# -*- coding: utf-8 -*-
3
3
4
- from tensorlayer .app import YOLOv4
4
+ from tensorlayer .app import YOLOv4 , get_anchors , decode , filter_boxes
5
5
import numpy as np
6
6
import tensorflow as tf
7
+ from tensorlayer import logging
8
+ import cv2
7
9
8
10
9
11
class object_detection (object ):
@@ -17,19 +19,70 @@ def __init__(self, model_name='yolo4-mscoco'):
17
19
18
20
def __call__ (self , input_data ):
19
21
if self .model_name == 'yolo4-mscoco' :
20
- image_data = input_data / 255.
21
- images_data = []
22
- for i in range (1 ):
23
- images_data .append (image_data )
24
- images_data = np .asarray (images_data ).astype (np .float32 )
25
- batch_data = tf .constant (images_data )
26
- output = self .model (batch_data , is_train = False )
22
+ batch_data = yolo4_input_processing (input_data )
23
+ feature_maps = self .model (batch_data , is_train = False )
24
+ output = yolo4_output_processing (feature_maps )
27
25
else :
28
26
raise NotImplementedError
29
27
30
28
return output
31
29
32
30
def __repr__ (self ):
33
- s = ('{classname} (model_name={model_name}, model_structure={model}' )
31
+ s = ('(model_name={model_name}, model_structure={model}' )
34
32
s += ')'
35
33
return s .format (classname = self .__class__ .__name__ , ** self .__dict__ )
34
+
35
+ @property
36
+ def list (self ):
37
+ logging .info ("The model name list: yolov4-mscoco" )
38
+
39
+
40
+ def yolo4_input_processing (original_image ):
41
+ image_data = cv2 .resize (original_image , (416 , 416 ))
42
+ image_data = image_data / 255.
43
+ images_data = []
44
+ for i in range (1 ):
45
+ images_data .append (image_data )
46
+ images_data = np .asarray (images_data ).astype (np .float32 )
47
+ batch_data = tf .constant (images_data )
48
+ return batch_data
49
+
50
+
51
+ def yolo4_output_processing (feature_maps ):
52
+ STRIDES = [8 , 16 , 32 ]
53
+ ANCHORS = get_anchors ([12 , 16 , 19 , 36 , 40 , 28 , 36 , 75 , 76 , 55 , 72 , 146 , 142 , 110 , 192 , 243 , 459 , 401 ])
54
+ NUM_CLASS = 80
55
+ XYSCALE = [1.2 , 1.1 , 1.05 ]
56
+ iou_threshold = 0.45
57
+ score_threshold = 0.25
58
+
59
+ bbox_tensors = []
60
+ prob_tensors = []
61
+ score_thres = 0.2
62
+ for i , fm in enumerate (feature_maps ):
63
+ if i == 0 :
64
+ output_tensors = decode (fm , 416 // 8 , NUM_CLASS , STRIDES , ANCHORS , i , XYSCALE )
65
+ elif i == 1 :
66
+ output_tensors = decode (fm , 416 // 16 , NUM_CLASS , STRIDES , ANCHORS , i , XYSCALE )
67
+ else :
68
+ output_tensors = decode (fm , 416 // 32 , NUM_CLASS , STRIDES , ANCHORS , i , XYSCALE )
69
+ bbox_tensors .append (output_tensors [0 ])
70
+ prob_tensors .append (output_tensors [1 ])
71
+ pred_bbox = tf .concat (bbox_tensors , axis = 1 )
72
+ pred_prob = tf .concat (prob_tensors , axis = 1 )
73
+ boxes , pred_conf = filter_boxes (
74
+ pred_bbox , pred_prob , score_threshold = score_thres , input_shape = tf .constant ([416 , 416 ])
75
+ )
76
+ pred = {'concat' : tf .concat ([boxes , pred_conf ], axis = - 1 )}
77
+
78
+ for key , value in pred .items ():
79
+ boxes = value [:, :, 0 :4 ]
80
+ pred_conf = value [:, :, 4 :]
81
+
82
+ boxes , scores , classes , valid_detections = tf .image .combined_non_max_suppression (
83
+ boxes = tf .reshape (boxes , (tf .shape (boxes )[0 ], - 1 , 1 , 4 )),
84
+ scores = tf .reshape (pred_conf , (tf .shape (pred_conf )[0 ], - 1 , tf .shape (pred_conf )[- 1 ])),
85
+ max_output_size_per_class = 50 , max_total_size = 50 , iou_threshold = iou_threshold , score_threshold = score_threshold
86
+ )
87
+ output = [boxes .numpy (), scores .numpy (), classes .numpy (), valid_detections .numpy ()]
88
+ return output
0 commit comments