-
Notifications
You must be signed in to change notification settings - Fork 548
/
Copy pathtrt_googlenet_async.py
184 lines (151 loc) · 5.89 KB
/
trt_googlenet_async.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""trt_googlenet.py
This is the 'async' version of trt_googlenet.py implementation.
Refer to trt_ssd_async.py for description about the design and
synchronization between the main and child threads.
"""
import sys
import time
import argparse
import threading
import numpy as np
import cv2
from utils.camera import add_camera_args, Camera
from utils.display import open_window, set_display, show_fps
from pytrt import PyTrtGooglenet
# Per-channel means subtracted from the resized image before inference;
# shape (1, 1, 3) so it broadcasts over an (H, W, 3) image.
PIXEL_MEANS = np.array([[[104.0, 117.0, 123.0]]], dtype=np.float32)

# Serialized TensorRT engine and the shapes handed to PyTrtGooglenet.
DEPLOY_ENGINE = 'googlenet/deploy.engine'
ENGINE_SHAPE0 = (3, 224, 224)   # engine input, (C, H, W) after transpose
ENGINE_SHAPE1 = (1000, 1, 1)    # engine output shape

# Target size for cv2.resize, which takes (width, height); derived from
# the engine input so the two cannot drift apart.
RESIZED_SHAPE = (ENGINE_SHAPE0[2], ENGINE_SHAPE0[1])

WINDOW_NAME = 'TrtGooglenetDemo'
MAIN_THREAD_TIMEOUT = 10.0  # seconds the main thread waits for a new frame

# 'Shared' globals: written by the child thread while holding the
# condition's lock, read by the main thread after being notified.
s_img = s_probs = s_labels = None
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Camera options are contributed by ``add_camera_args``; on top of
    those this demo adds ``--crop`` (stored as ``crop_center``), which
    enables center-square cropping before inference.
    """
    parser = argparse.ArgumentParser(
        description='Capture and display live camera video, while doing '
                    'real-time image classification with TrtGooglenet '
                    'on Jetson Nano')
    parser = add_camera_args(parser)
    parser.add_argument(
        '--crop', dest='crop_center', action='store_true',
        help='crop center square of image for inferencing [False]')
    return parser.parse_args()
def classify(img, net, labels, do_cropping):
    """Classify one image (or its center square) with the TRT engine.

    # Arguments
        img: input image, unpacked as (height, width, channels)
        net: inference object; its forward() result is indexed by 'prob'
        labels: numpy array of class labels, indexed by class id
        do_cropping: if True, only the centered square region is used

    # Returns
        A tuple (scores, labels) for the 3 highest-scoring classes,
        ordered from highest to lowest score.
    """
    if do_cropping:
        height, width, _ = img.shape
        side = min(height, width)
        top = (height - side) // 2
        left = (width - side) // 2
        img = img[top:top + side, left:left + side, :]

    # Preprocess: resize, subtract per-channel means, HWC -> CHW.
    blob = cv2.resize(img, RESIZED_SHAPE).astype(np.float32) - PIXEL_MEANS
    chw = blob.transpose((2, 0, 1))

    # Prepend a batch dimension of 1 before running inference.
    result = net.forward(chw[np.newaxis])

    probs = np.squeeze(result['prob'][0])
    best = np.argsort(probs)[::-1][:3]
    return (probs[best], labels[best])
class TrtGooglenetThread(threading.Thread):
    """Child thread that runs TRT GoogLeNet classification on camera frames.

    The TRT engine is created and destroyed inside run(), i.e. on this
    thread. After each frame is classified, the frame and its top-3
    result are published via the module-level 's_img', 's_probs' and
    's_labels' variables, and the main thread is notified through the
    shared condition variable.
    """

    def __init__(self, condition, cam, labels, do_cropping):
        """__init__

        # Arguments
            condition: the condition variable used to notify main
                       thread about new frame and detection result
            cam: the camera object for reading input image frames
            labels: a numpy array of class labels
            do_cropping: whether to do center-cropping of input image
        """
        super().__init__()
        self.condition = condition
        self.cam = cam
        self.labels = labels
        self.do_cropping = do_cropping
        self.running = False
        # Set by stop(). run() sets 'running' to True only after the
        # engine has loaded, so a stop() issued during loading would
        # otherwise be overwritten and join() would block forever.
        self._stop_requested = threading.Event()

    def run(self):
        """Classify frames until stop() is called or the camera runs dry."""
        global s_img, s_probs, s_labels

        print('TrtGooglenetThread: loading the TRT Googlenet engine...')
        self.net = PyTrtGooglenet(DEPLOY_ENGINE, ENGINE_SHAPE0, ENGINE_SHAPE1)
        print('TrtGooglenetThread: start running...')
        self.running = True
        try:
            while self.running and not self._stop_requested.is_set():
                img = self.cam.read()
                if img is None:
                    break
                top_probs, top_labels = classify(
                    img, self.net, self.labels, self.do_cropping)
                with self.condition:
                    s_img, s_probs, s_labels = img, top_probs, top_labels
                    self.condition.notify()
        finally:
            # Release the engine on this thread even if classify() raised.
            self.running = False
            del self.net
        print('TrtGooglenetThread: stopped...')

    def stop(self):
        """Request the thread to finish, then wait until it has exited."""
        self._stop_requested.set()
        self.running = False
        self.join()
def show_top_preds(img, top_probs, top_labels):
    """Draw each (score, label) prediction onto `img` in place.

    One text line is drawn per prediction. The first line is at (10, 40)
    and each following line is 20 pixels lower.
    """
    for row, (score, name) in enumerate(zip(top_probs, top_labels)):
        text = '{:.4f} {:20s}'.format(score, name)
        origin = (10, 40 + 20 * row)
        cv2.putText(img, text, origin, cv2.FONT_HERSHEY_PLAIN, 1.0,
                    (0, 0, 240), 1, cv2.LINE_AA)
def loop_and_display(condition):
    """Display classification results produced by the child thread.

    For each notification on `condition`, this reads the latest frame and
    top-3 predictions from the module-level 's_img', 's_probs' and
    's_labels' variables. It then draws the predictions and an FPS
    readout onto the frame and shows it. The loop ends when the window
    is closed or ESC is pressed. 'F' toggles fullscreen.

    # Arguments
        condition: condition variable the child thread notifies after
                   publishing a new frame and result

    # Raises
        SystemExit: if no notification arrives within
                    MAIN_THREAD_TIMEOUT seconds
    """
    full_scrn = False
    # Toggled by 'H'. It was previously never initialised, so pressing H
    # raised UnboundLocalError. Nothing in this function renders help yet.
    show_help = False
    fps = 0.0
    tic = time.time()
    while True:
        if cv2.getWindowProperty(WINDOW_NAME, 0) < 0:
            break  # window was closed
        with condition:
            if not condition.wait(timeout=MAIN_THREAD_TIMEOUT):
                raise SystemExit('ERROR: timeout waiting for img from child')
            img, top_probs, top_labels = s_img, s_probs, s_labels

        show_top_preds(img, top_probs, top_labels)
        img = show_fps(img, fps)
        cv2.imshow(WINDOW_NAME, img)

        toc = time.time()
        elapsed = toc - tic
        if elapsed > 0:  # guard against a zero interval on coarse clocks
            curr_fps = 1.0 / elapsed
            # exponentially decaying average of the fps number
            fps = curr_fps if fps == 0.0 else (fps*0.95 + curr_fps*0.05)
        tic = toc

        key = cv2.waitKey(1)
        if key == 27:  # ESC key: quit program
            break
        elif key == ord('H') or key == ord('h'):  # Toggle help message
            show_help = not show_help
        elif key == ord('F') or key == ord('f'):  # Toggle fullscreen
            full_scrn = not full_scrn
            set_display(WINDOW_NAME, full_scrn)
def main():
    """Open the camera and display window, then run the async demo.

    TRT inference runs in a TrtGooglenetThread while this (main) thread
    displays results. Cleanup runs in a `finally` clause, so it happens
    even when loop_and_display exits via SystemExit on timeout. Cleanup
    means stopping the child thread, releasing the camera and closing
    the windows. Previously that path skipped stop(), leaving the
    non-daemon child thread running and the process unable to exit.
    """
    args = parse_args()
    labels = np.loadtxt('googlenet/synset_words.txt', str, delimiter='\t')
    cam = Camera(args)
    if not cam.isOpened():
        raise SystemExit('ERROR: failed to open camera!')

    open_window(
        WINDOW_NAME, 'Camera TensorRT GoogLeNet Demo',
        cam.img_width, cam.img_height)
    condition = threading.Condition()
    trt_thread = TrtGooglenetThread(condition, cam, labels, args.crop_center)
    trt_thread.start()  # start the child thread
    try:
        loop_and_display(condition)
    finally:
        trt_thread.stop()  # stop the child thread
        cam.release()
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()