parallel
mingdong-li committed Oct 14, 2019
1 parent e67f4a8 commit 06f9c07
Showing 3 changed files with 310 additions and 1 deletion.
8 changes: 7 additions & 1 deletion .gitignore
@@ -4,7 +4,13 @@
*.xmind
*.pt

part_run.sh
multi1_initialize.sh
multi2_detect.sh
multi3_run.sh


part_test.sh
*.out

# Byte-compiled / optimized / DLL files
__pycache__/
64 changes: 64 additions & 0 deletions detector/to_img_by_name.py
@@ -0,0 +1,64 @@
import os
import argparse
import cv2
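# Extract every frame of a video under ./video into ./result/img/<seq>/img1
# as zero-padded JPEGs, where <seq> is the first two characters of the file
# name. Hypothetical usage (assuming a file ./video/b1.mp4 exists):
#   python detector/to_img_by_name.py --file b1.mp4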


def to_img(video):
    # video config
    video_root = "./video"
    video_path = os.path.join(video_root, video)
    cap = cv2.VideoCapture(video_path)
    frame_index = 1

    success = cap.isOpened()
    if not success:
        print("Failed to open video: %s" % video_path)

    zj_path = "./result/img"
    os.makedirs(zj_path, exist_ok=True)

    vdo_savepath = os.path.join(zj_path, video[:2])
    os.makedirs(vdo_savepath, exist_ok=True)
    for sub in ['img1', 'det']:
        os.makedirs(os.path.join(vdo_savepath, sub), exist_ok=True)

    img_savepath = os.path.join(vdo_savepath, 'img1')
    det_savepath = os.path.join(vdo_savepath, 'det')
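    # The layout mirrors the MOT convention (an assumption based on the
    # img1/det naming): img1/ receives the extracted frames, det/ is left
    # for the per-sequence detection files consumed by single_demo.py.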

    while success:
        # frame.shape = (1080,1920,3); OpenCV returns frames in BGR order
        success, frame = cap.read()
        if success:
            # cv2.imwrite expects BGR, so the frame can be written as-is
            path = os.path.join(img_savepath, str(frame_index).zfill(6)) + ".jpg"
            cv2.imwrite(path, frame)

            print(video, ": frame", frame_index)
            frame_index += 1

    print("%s finished" % video)
    cap.release()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--file', type=str, required=True, help='video file name under ./video')
    opt = parser.parse_args()

    to_img(opt.file)
239 changes: 239 additions & 0 deletions single_demo.py
@@ -0,0 +1,239 @@
import os
import cv2
import sys
import nms.nms as nms
import glob
import time
import torch
import shutil
import numpy as np
import yacs
import argparse


root = os.getcwd()
track_dir = os.path.join(root, 'pysot')
params_dir = os.path.join(root, 'weights')
sys.path.append(root)
sys.path.append(track_dir)

def main(seq_name):
    # the packages of the trackers
    from pysot.core.config import cfg  # use the modified config file to reset the tracking system
    from pysot.models.model_builder import ModelBuilder
    # modified single-object tracker with a MOT wrapper
    from mot_zj.MUST_sot_builder import build_tracker
    from mot_zj.MUST_utils import (draw_bboxes, find_candidate_detection,
                                   handle_conflicting_trackers, sort_trackers,
                                   traj_interpolate)
    from mot_zj.MUST_ASSO.MUST_asso_model import AssociationModel


    dataset_dir = os.path.join(root, 'result')
    seq_type = 'img'
    # paths of the tracking config file and the model weights
    config_path = os.path.join(track_dir, "mot_zj", "MUST_config_file", "alex_config.yaml")
    model_params = os.path.join(params_dir, "alex_model.pth")
    # enable the visualisation or not
    is_visualisation = False
    # print the information of the tracking process or not
    is_print = True

    results_dir = os.path.join(dataset_dir, 'track')
    os.makedirs(results_dir, exist_ok=True)
    img_traj_dir = os.path.join(track_dir, "img_traj")
    if os.path.exists(os.path.join(img_traj_dir, seq_name)):
        shutil.rmtree(os.path.join(img_traj_dir, seq_name))

    seq_dir = os.path.join(dataset_dir, seq_type)

    # record the processing time
    start_point = time.time()

    # load the tracking config
    cfg.merge_from_file(config_path)

    # set the flag that CUDA is available
    cfg.CUDA = torch.cuda.is_available()
    device = torch.device('cuda' if cfg.CUDA else 'cpu')

    # create the tracker model
    track_model = ModelBuilder()
    # load the tracker weights
    track_model.load_state_dict(torch.load(model_params, map_location=lambda storage, loc: storage.cpu()))
    track_model.eval().to(device)
    # create the association model
    asso_model = AssociationModel()
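    # NOTE: map_location forces the checkpoint tensors onto the CPU first, so
    # loading also works on machines without CUDA; eval().to(device) then
    # switches off training-mode layers and moves the model to the GPU if one
    # is available.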

    seq_det_path = os.path.join(seq_dir, seq_name, 'det')
    seq_img_path = os.path.join(seq_dir, seq_name, 'img1')

    # print path and dataset information
    if is_print:
        print('preparing for the sequence: {}'.format(seq_name))
        print('-----------------------------------------------')
        print("detection result path: {}".format(seq_det_path))
        print("image files path: {}".format(seq_img_path))
        print('-----------------------------------------------')

    # read the detection results
    det_results = np.loadtxt(os.path.join(seq_det_path, 'det.txt'), dtype=float, delimiter=',')

    # read the images of the sequence
    images = sorted(glob.glob(os.path.join(seq_img_path, '*.jpg')))

    # the container of trackers
    trackers = []

    # visualisation settings
    if is_visualisation:
        cv2.namedWindow(seq_name, cv2.WINDOW_NORMAL)

    # init (reset) the identifier
    id_num = 0

    # tracking process in each frame
    for nn, im_path in enumerate(images):
        frame = nn + 1
        img = cv2.imread(im_path)
        print('Frame {} is loaded'.format(frame))

        # load the detection results of this frame [frame, id, x, y, w, h, score]
        pre_frame_det_results = det_results[det_results[:, 0] == frame]

        # non-maximum suppression; nms.boxes returns the indices of the
        # detections to keep
        indices = nms.boxes(pre_frame_det_results[:, 2:6], pre_frame_det_results[:, 6])
        frame_det_results = pre_frame_det_results[indices, :]

        # extract the bboxes (x, y, w, h)
        bboxes = frame_det_results[:, 2:6]

        ############################################
        # ***multiple tracking and associating*** #
        ############################################

        # 1. sort trackers into two priority classes
        index1, index2 = sort_trackers(trackers)

        # 2. record the indices of trackers that have already been processed
        index_processed = []
        for k in range(2):
            # process trackers in the first or the second class
            if k == 0:
                index_track = index1
            else:
                index_track = index2

            # tracked/activated trackers are updated first
            for ind in index_track:
                if trackers[ind].track_state == cfg.STATE.TRACKED or trackers[ind].track_state == cfg.STATE.ACTIVATED:
                    indices = find_candidate_detection([trackers[i] for i in index_processed], bboxes)
                    to_track_bboxes = bboxes[indices, :] if not bboxes.size == 0 else np.array([])
                    # MOT_track (tracking process)
                    trackers[ind].track(img, to_track_bboxes, frame)
                    # if the tracker keeps its previous tracking state (tracked or activated)
                    if trackers[ind].track_state == cfg.STATE.TRACKED or trackers[ind].track_state == cfg.STATE.ACTIVATED:
                        index_processed.append(ind)

            # lost trackers then try to re-associate with the remaining detections
            for ind in index_track:
                if trackers[ind].track_state == cfg.STATE.LOST:
                    indices = find_candidate_detection([trackers[i] for i in index_processed], bboxes)
                    to_associate_bboxes = bboxes[indices, :] if not bboxes.size == 0 else np.array([])
                    # MOT_track (association process); pass the association
                    # candidates, not the to_track_bboxes left over from the
                    # previous loop
                    trackers[ind].track(img, to_associate_bboxes, frame)
                    # add process flag
                    index_processed.append(ind)
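        # Design note (inferred from the arguments of find_candidate_detection):
        # detections already claimed by the trackers in index_processed are
        # excluded, so tracked/activated trackers get first pick and lost
        # trackers only see the leftovers.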

        ############################################
        # ***init new trackers***                 #
        ############################################

        # find the candidate bboxes to init new trackers
        indices = find_candidate_detection(trackers, bboxes)

        # init a new tracker, with a fresh id, for each unclaimed detection
        for index in indices:
            id_num += 1
            new_tracker = build_tracker(track_model)
            new_tracker.init(img, bboxes[index, :], id_num, frame, seq_name, asso_model)
            trackers.append(new_tracker)

        # resolve conflicts between trackers
        trackers = handle_conflicting_trackers(trackers, bboxes)

        # interpolate the tracklet results
        for tracker in trackers:
            if tracker.track_state == cfg.STATE.TRACKED or tracker.track_state == cfg.STATE.ACTIVATED:
                bbox = tracker.tracking_bboxes[-1, :]
                traj_interpolate(tracker, bbox, tracker.frames[-1], 30)
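        # NOTE: the trailing argument 30 is assumed to cap the interpolation
        # gap in frames (see traj_interpolate in mot_zj.MUST_utils).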

        ############################################
        # ***collect tracking results***          #
        ############################################

        # collect the tracking results (all of them, unfiltered) every 500 frames
        if frame % 500 == 0:
            results_bboxes = np.array([])
            for tracker in trackers:
                if results_bboxes.size == 0:
                    results_bboxes = tracker.results_return()
                else:
                    res = tracker.results_return()
                    if not res.size == 0:
                        results_bboxes = np.concatenate((results_bboxes, res), axis=0)
            filename = '{}.txt'.format(seq_name)
            # sort the collected results by frame number
            results_bboxes = results_bboxes[np.argsort(results_bboxes[:, 0])]
            print(results_bboxes.shape[0])
            # detection filter: for sequence b1, drop small boxes
            # (area < 10000) in the lower half of the 1080p frame (y > 540)
            indices = []
            if seq_name == 'b1':
                for ind, result in enumerate(results_bboxes):
                    if result[3] > 540 and result[4] * result[5] < 10000:
                        indices.append(ind)
                results_bboxes = np.delete(results_bboxes, indices, axis=0)
            np.savetxt(os.path.join(results_dir, filename), results_bboxes, fmt='%d,%d,%.1f,%.1f,%.1f,%.1f')
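        # Each saved row is frame,id,x,y,w,h to match the fmt string above
        # (an assumption based on the detection layout [frame, id, x, y, w, h,
        # score] and on the columns used by the b1 filter).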
        ############################################
        # ***crop tracklet image***               #
        ############################################

        for tracker in trackers:
            if tracker.track_state == cfg.STATE.START or tracker.track_state == cfg.STATE.TRACKED or tracker.track_state == cfg.STATE.ACTIVATED:
                bbox = tracker.tracking_bboxes[-1, :]
                # clamp the crop region (x, y, w, h) to the image bounds
                x1 = int(np.floor(np.maximum(1, bbox[0])))
                y1 = int(np.ceil(np.maximum(1, bbox[1])))
                x2 = int(np.ceil(np.minimum(img.shape[1], bbox[0] + bbox[2])))
                y2 = int(np.ceil(np.minimum(img.shape[0], bbox[1] + bbox[3])))
                img_traj = img[y1:y2, x1:x2, :]
                traj_path = os.path.join(img_traj_dir, seq_name, str(tracker.id_num))
                os.makedirs(traj_path, exist_ok=True)
                tracklet_img_path = os.path.join(traj_path, str(tracker.frames[-1]))
                cv2.imwrite("{}.jpg".format(tracklet_img_path), img_traj)
        if is_visualisation:
            ##########################################
            #   information print and visualisation  #
            ##########################################
            # print("The number of new trackers: {}".format(len(indices)))
            active_trackers = [trackers[i].id_num for i in range(len(trackers))
                               if trackers[i].track_state in (cfg.STATE.ACTIVATED, cfg.STATE.TRACKED, cfg.STATE.LOST)]
            print("The number of active trackers: {}".format(len(active_trackers)))
            print(active_trackers)
            anno_img = draw_bboxes(img, bboxes)
            cv2.imshow(seq_name, anno_img)
            cv2.waitKey(1)

print("The total processing time is: {} s".format(time.time()-start_point))

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--seq_name", type=str, default='b1')
args = parser.parse_args()
main(args.seq_name)
