Skip to content
This repository was archived by the owner on Dec 19, 2023. It is now read-only.

Commit 6379d01

Browse files
committed
add yolov4 model
1 parent d56d0b7 commit 6379d01

File tree

5 files changed

+69
-683
lines changed

5 files changed

+69
-683
lines changed

tensorlayer/app/computer_vision.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#! /usr/bin/python
22
# -*- coding: utf-8 -*-
33

4-
from tensorlayer.app import YOLOv4
4+
from tensorlayer.app import YOLOv4, get_anchors, decode, filter_boxes
55
import numpy as np
66
import tensorflow as tf
7+
from tensorlayer import logging
8+
import cv2
79

810

911
class object_detection(object):
@@ -17,19 +19,70 @@ def __init__(self, model_name='yolo4-mscoco'):
1719

1820
def __call__(self, input_data):
1921
if self.model_name == 'yolo4-mscoco':
20-
image_data = input_data / 255.
21-
images_data = []
22-
for i in range(1):
23-
images_data.append(image_data)
24-
images_data = np.asarray(images_data).astype(np.float32)
25-
batch_data = tf.constant(images_data)
26-
output = self.model(batch_data, is_train=False)
22+
batch_data = yolo4_input_processing(input_data)
23+
feature_maps = self.model(batch_data, is_train=False)
24+
output = yolo4_output_processing(feature_maps)
2725
else:
2826
raise NotImplementedError
2927

3028
return output
3129

3230
def __repr__(self):
33-
s = ('{classname}(model_name={model_name}, model_structure={model}')
31+
s = ('(model_name={model_name}, model_structure={model}')
3432
s += ')'
3533
return s.format(classname=self.__class__.__name__, **self.__dict__)
34+
35+
@property
36+
def list(self):
37+
logging.info("The model name list: yolov4-mscoco")
38+
39+
40+
def yolo4_input_processing(original_image):
41+
image_data = cv2.resize(original_image, (416, 416))
42+
image_data = image_data / 255.
43+
images_data = []
44+
for i in range(1):
45+
images_data.append(image_data)
46+
images_data = np.asarray(images_data).astype(np.float32)
47+
batch_data = tf.constant(images_data)
48+
return batch_data
49+
50+
51+
def yolo4_output_processing(feature_maps):
52+
STRIDES = [8, 16, 32]
53+
ANCHORS = get_anchors([12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401])
54+
NUM_CLASS = 80
55+
XYSCALE = [1.2, 1.1, 1.05]
56+
iou_threshold = 0.45
57+
score_threshold = 0.25
58+
59+
bbox_tensors = []
60+
prob_tensors = []
61+
score_thres = 0.2
62+
for i, fm in enumerate(feature_maps):
63+
if i == 0:
64+
output_tensors = decode(fm, 416 // 8, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
65+
elif i == 1:
66+
output_tensors = decode(fm, 416 // 16, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
67+
else:
68+
output_tensors = decode(fm, 416 // 32, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE)
69+
bbox_tensors.append(output_tensors[0])
70+
prob_tensors.append(output_tensors[1])
71+
pred_bbox = tf.concat(bbox_tensors, axis=1)
72+
pred_prob = tf.concat(prob_tensors, axis=1)
73+
boxes, pred_conf = filter_boxes(
74+
pred_bbox, pred_prob, score_threshold=score_thres, input_shape=tf.constant([416, 416])
75+
)
76+
pred = {'concat': tf.concat([boxes, pred_conf], axis=-1)}
77+
78+
for key, value in pred.items():
79+
boxes = value[:, :, 0:4]
80+
pred_conf = value[:, :, 4:]
81+
82+
boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
83+
boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
84+
scores=tf.reshape(pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
85+
max_output_size_per_class=50, max_total_size=50, iou_threshold=iou_threshold, score_threshold=score_threshold
86+
)
87+
output = [boxes.numpy(), scores.numpy(), classes.numpy(), valid_detections.numpy()]
88+
return output

tensorlayer/app/computer_vision_object_detection/model/coco.names

Lines changed: 0 additions & 80 deletions
This file was deleted.

0 commit comments

Comments
 (0)