infer_in_video.py (forked from yastrebksv/TennisCourtDetector)

import argparse

import cv2
import numpy as np
import torch
from tqdm import tqdm

from lib.tracknet import BallTrackerNet
from lib.postprocess import postprocess, refine_kps
from lib.homography import get_trans_matrix, refer_kps
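
# This script runs the TennisCourtDetector court-keypoint model frame by frame
# over a video: each frame is resized to the network input size, 14 court
# keypoints are decoded from the predicted heatmaps, optionally refined and/or
# snapped to a fitted homography, and drawn onto the frame before the annotated
# video is written out.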


def read_video(path_video):
    """Read a video file.

    :params
        path_video: path to the video file
    :return
        frames: list of video frames
        fps: frames per second
    """
    cap = cv2.VideoCapture(path_video)
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
        else:
            break
    cap.release()
    return frames, fps
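
# Note: CAP_PROP_FPS is a float; the int() cast above truncates fractional
# frame rates (e.g. 29.97 -> 29), which can slightly desynchronize the output.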


def write_video(imgs_new, fps, path_output_video):
    """Write a list of frames to a video file."""
    height, width = imgs_new[0].shape[:2]
    out = cv2.VideoWriter(
        path_output_video, cv2.VideoWriter_fourcc(*"DIVX"), fps, (width, height)
    )
    for frame in imgs_new:
        out.write(frame)
    out.release()
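
# Note: the "DIVX" FourCC is normally paired with an .avi container; if the
# writer produces an unplayable file, "mp4v" with an .mp4 output path is a
# common alternative (a suggestion, not part of the original script).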


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, help="path to model weights")
    parser.add_argument("--input_path", type=str, help="path to input video")
    parser.add_argument("--output_path", type=str, help="path to output video")
    parser.add_argument(
        "--use_refine_kps",
        action="store_true",
        help="whether to apply keypoint refinement postprocessing",
    )
    parser.add_argument(
        "--use_homography",
        action="store_true",
        help="whether to apply homography postprocessing",
    )
    args = parser.parse_args()
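
    # Example invocation (file names are illustrative, not from the original repo):
    #   python infer_in_video.py --model_path model_tennis_court_det.pt \
    #       --input_path match_clip.mp4 --output_path match_clip_annotated.avi \
    #       --use_refine_kps --use_homography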

    # Load the court-keypoint model and switch it to inference mode.
    model = BallTrackerNet(out_channels=15)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model.eval()
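
    # The network emits 15 heatmap channels; only the first 14 (one per court
    # keypoint) are decoded below, so the last channel is unused by this script.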

    OUTPUT_WIDTH = 640
    OUTPUT_HEIGHT = 360

    frames, fps = read_video(args.input_path)
    frames_upd = []
    for image in tqdm(frames):
        # Resize to the network input size and convert HWC uint8 to a
        # normalized CHW float tensor with a batch dimension.
        img = cv2.resize(image, (OUTPUT_WIDTH, OUTPUT_HEIGHT))
        inp = img.astype(np.float32) / 255.0
        inp = torch.tensor(np.rollaxis(inp, 2, 0))
        inp = inp.unsqueeze(0)

        out = model(inp.float().to(device))[0]
        pred = torch.sigmoid(out).detach().cpu().numpy()

        # Decode one (x, y) location per keypoint from its heatmap.
        points = []
        for kps_num in range(14):
            heatmap = (pred[kps_num] * 255).astype(np.uint8)
            x_pred, y_pred = postprocess(heatmap, thresh=170, max_radius=25)
            if args.use_refine_kps and kps_num not in [8, 12, 9] and x_pred and y_pred:
                x_pred, y_pred = refine_kps(image, int(y_pred), int(x_pred))
            points.append((x_pred, y_pred))

        # Optionally replace the raw detections with the reference court
        # keypoints projected through a fitted homography.
        if args.use_homography:
            matrix_trans = get_trans_matrix(points)
            if matrix_trans is not None:
                points = cv2.perspectiveTransform(refer_kps, matrix_trans)
                points = [np.squeeze(x) for x in points]

        # Draw the detected keypoints on the original frame.
        for j in range(len(points)):
            if points[j][0] is not None:
                image = cv2.circle(
                    image,
                    (int(points[j][0]), int(points[j][1])),
                    radius=0,
                    color=(0, 0, 255),
                    thickness=10,
                )
        frames_upd.append(image)

    write_video(frames_upd, fps, args.output_path)
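
# Design note: read_video loads every frame into memory before inference, so
# long or high-resolution videos can exhaust RAM; a streaming
# read/process/write loop is a straightforward alternative if that matters.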