Skip to content

Commit 486da5a

Browse files
...
1 parent 98c1f7f commit 486da5a

File tree

1 file changed

+156
-0
lines changed

1 file changed

+156
-0
lines changed

scripts/ObjectTrackingDeepSORT.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
2+
import cv2
3+
import numpy as np
4+
import sys
5+
import glob
6+
7+
import time
8+
import torch
9+
10+
11+
12+
class YoloDetector():
13+
14+
def __init__(self, model_name):
15+
16+
self.model = self.load_model(model_name)
17+
self.classes = self.model.names
18+
#print(self.classes)
19+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
20+
print("Using Device: ", self.device)
21+
22+
23+
def load_model(self, model_name):
24+
25+
if model_name:
26+
model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_name, force_reload=True)
27+
else:
28+
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
29+
return model
30+
31+
def score_frame(self, frame):
32+
33+
self.model.to(self.device)
34+
downscale_factor = 2
35+
width = int(frame.shape[1] / downscale_factor)
36+
height = int(frame.shape[0] / downscale_factor)
37+
frame = cv2.resize(frame, (width,height))
38+
#frame = frame.to(self.device)
39+
40+
results = self.model(frame)
41+
42+
labels, cord = results.xyxyn[0][:, -1], results.xyxyn[0][:, :-1]
43+
44+
return labels, cord
45+
46+
def class_to_label(self, x):
47+
48+
return self.classes[int(x)]
49+
50+
51+
def plot_boxes(self, results, frame, height, width, confidence=0.3):
52+
53+
labels, cord = results
54+
detections = []
55+
56+
n = len(labels)
57+
x_shape, y_shape = width, height
58+
59+
60+
61+
for i in range(n):
62+
row = cord[i]
63+
64+
if row[4] >= confidence:
65+
x1, y1, x2, y2 = int(row[0]*x_shape), int(row[1]*y_shape), int(row[2]*x_shape), int(row[3]*y_shape)
66+
67+
if self.class_to_label(labels[i]) == 'cup':
68+
69+
x_center = x1 + (x2 - x1)
70+
y_center = y1 + ((y2 - y1) / 2)
71+
72+
tlwh = np.asarray([x1, y1, int(x2-x1), int(y2-y1)], dtype=np.float32)
73+
confidence = float(row[4].item())
74+
feature = 'person'
75+
76+
detections.append(([x1, y1, int(x2-x1), int(y2-y1)], row[4].item(), 'person'))
77+
78+
79+
return frame, detections
80+
81+
82+
cap = cv2.VideoCapture(0)
83+
84+
85+
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
86+
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
87+
88+
89+
detector = YoloDetector(model_name=None)
90+
91+
import os
92+
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
93+
94+
95+
from deep_sort_realtime.deepsort_tracker import DeepSort
96+
97+
object_tracker = DeepSort(max_age=5,
98+
n_init=2,
99+
nms_max_overlap=1.0,
100+
max_cosine_distance=0.3,
101+
nn_budget=None,
102+
override_track_class=None,
103+
embedder="mobilenet",
104+
half=True,
105+
bgr=True,
106+
embedder_gpu=True,
107+
embedder_model_name=None,
108+
embedder_wts=None,
109+
polygon=False,
110+
today=None)
111+
112+
113+
while cap.isOpened():
114+
115+
succes, img = cap.read()
116+
117+
start = time.perf_counter()
118+
119+
results = detector.score_frame(img)
120+
img, detections = detector.plot_boxes(results, img, height=img.shape[0], width=img.shape[1], confidence=0.5)
121+
122+
123+
tracks = object_tracker.update_tracks(detections, frame=img) # bbs expected to be a list of detections, each in tuples of ( [left,top,w,h], confidence, detection_class )
124+
125+
126+
for track in tracks:
127+
if not track.is_confirmed():
128+
continue
129+
track_id = track.track_id
130+
ltrb = track.to_ltrb()
131+
132+
bbox = ltrb
133+
134+
cv2.rectangle(img,(int(bbox[0]), int(bbox[1])),(int(bbox[2]), int(bbox[3])),(0,0,255),2)
135+
cv2.putText(img, "ID: " + str(track_id), (int(bbox[0]), int(bbox[1] - 10)), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
136+
137+
138+
139+
end = time.perf_counter()
140+
totalTime = end - start
141+
fps = 1 / totalTime
142+
143+
144+
cv2.putText(img, f'FPS: {int(fps)}', (20,70), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0,255,0), 2)
145+
cv2.imshow('img',img)
146+
147+
148+
if cv2.waitKey(1) & 0xFF == 27:
149+
break
150+
151+
152+
# Release and destroy all windows before termination
153+
cap.release()
154+
155+
cv2.destroyAllWindows()
156+

0 commit comments

Comments
 (0)