-
Notifications
You must be signed in to change notification settings - Fork 3
/
preprocess.py
130 lines (99 loc) · 5.07 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
import os
from datetime import datetime
import numpy as np
import cv2
# from torchvision.utils import save_image
from tqdm import tqdm
import face_alignment
from matplotlib import pyplot as plt
# --- Script configuration ----------------------------------------------------
# Input root; the driver loop below walks it as <person_id>/<video_id>/<clip>.
path_to_mp4 = './voxceleb/dev/mp4'
# Number of frames sampled (and landmark images rendered) per saved clip.
K = 8
# Running index of the next PNG written; also selects its sub-folder (//256).
num_vid = 0
# Use the first GPU when CUDA is available; otherwise fall back to the CPU
# (much slower, but lets the script run on machines without a GPU).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Output root for the generated frame/landmark strips.
saves_dir = './vox2_mp4_preprocess'
# Newer face_alignment releases name the 2-D landmark type TWO_D, older ones
# name it _2D; accept either so the script runs against both.
_landmarks_type = getattr(face_alignment.LandmarksType, 'TWO_D', None)
if _landmarks_type is None:
    _landmarks_type = face_alignment.LandmarksType._2D
face_aligner = face_alignment.FaceAlignment(_landmarks_type, flip_input=False, device=str(device))
# exist_ok avoids the check-then-create race of isdir() + mkdir().
os.makedirs(saves_dir, exist_ok=True)
# (start row, stop row, colour) of each landmark polyline drawn by
# _render_landmark_image. Rows 60:68 (the inner lip) are not drawn.
_LANDMARK_SEGMENTS = (
    (0, 17, 'green'),    # chin
    (17, 22, 'orange'),  # left eyebrow
    (22, 27, 'orange'),  # right eyebrow
    (27, 31, 'blue'),    # nose
    (31, 36, 'blue'),    # nose
    (36, 42, 'red'),     # left eye
    (42, 48, 'red'),     # right eye
    (48, 60, 'purple'),  # outer lip
)


def _render_landmark_image(frame, preds, dpi=100):
    """Rasterise the landmark polylines of *preds* onto a white canvas.

    The figure is sized so that, at ``dpi``, it nominally matches the pixel
    size of *frame*. Returns the rendered canvas as an (H, W, 3) RGB array.
    The figure is always closed, even when drawing fails, so failing frames
    do not accumulate open matplotlib figures.
    """
    height, width = frame.shape[:2]
    fig = plt.figure(figsize=(width / dpi, height / dpi), dpi=dpi)
    try:
        ax = fig.add_subplot(1, 1, 1)
        ax.imshow(np.ones(frame.shape))  # all-ones image -> white background
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        for start, stop, colour in _LANDMARK_SEGMENTS:
            ax.plot(preds[start:stop, 0], preds[start:stop, 1],
                    marker='', markersize=5, linestyle='-', color=colour, lw=2)
        ax.axis('off')
        fig.canvas.draw()
        # buffer_rgba() replaces the deprecated tostring_rgb(); drop alpha and
        # copy so the pixels stay valid after the figure is closed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return rgba[:, :, :3].copy()
    finally:
        plt.close(fig)


def generate_landmarks(frames_list, face_aligner):
    """Pair every frame with a rendered image of its facial landmarks.

    Args:
        frames_list: sequence of frames (e.g. the RGB frames from
            ``pick_images``).
        face_aligner: object with a ``get_landmarks(image)`` method; its result
            is expected to be a sequence of landmark arrays (only the first is
            used), or ``None``/empty when nothing is detected.

    Returns:
        A list of ``(frame, landmark_image)`` tuples. Frames whose landmarks
        cannot be computed or drawn are reported and skipped. The list is then
        padded back to ``len(frames_list)`` by cycling through the successful
        entries in order. If no frame succeeds, an empty list is returned.
    """
    frame_landmark_list = []
    for frame in frames_list:
        try:
            detections = face_aligner.get_landmarks(frame)
            if detections is None or len(detections) == 0:
                landmark_image = None
            else:
                landmark_image = _render_landmark_image(frame, detections[0])
        except Exception:  # best effort: one bad frame must not abort the video
            landmark_image = None
        if landmark_image is None:
            print('Error: Video corrupted or no landmarks visible')
            continue
        frame_landmark_list.append((frame, landmark_image))

    if not frame_landmark_list:
        # Nothing to pad with; the short result lets callers detect failure.
        return frame_landmark_list

    # Fill the slots of failed frames. Because the list grows while indexing,
    # entry i repeats entry i % (number of successes), i.e. a cyclic repeat.
    for i in range(len(frames_list) - len(frame_landmark_list)):
        frame_landmark_list.append(frame_landmark_list[i])
    return frame_landmark_list
def pick_images(video_path, pic_folder, num_images):
    """Decode about ``num_images`` evenly spread frames of a video as RGB.

    Frames are chosen with a stride of ``n_frames // num_images + 1``. When
    that stride would yield fewer than ``num_images`` frames (short clips,
    e.g. 56 frames -> only 7 of 8), evenly spaced distinct indices are used
    instead, so any clip with at least ``num_images`` frames can supply them.

    Args:
        video_path: path of the video file, opened with ``cv2.VideoCapture``.
        pic_folder: unused. It is kept so existing callers keep working; it was
            only read by a commented-out step that dumped frames to disk.
        num_images: number of frames wanted; must be positive.

    Returns:
        A list of RGB frames in playback order. It is shorter than
        ``num_images`` when the video reports fewer frames or the stream ends
        early.

    Raises:
        ValueError: if ``num_images`` is not positive.
    """
    if num_images <= 0:
        raise ValueError('num_images must be positive')
    cap = cv2.VideoCapture(video_path)
    try:
        # The container's frame count may be inaccurate; reading also stops
        # as soon as cap.read() reports the end of the stream.
        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        step = n_frames // num_images + 1
        wanted = set(range(0, n_frames, step))
        if len(wanted) < num_images <= n_frames:
            # floor(i * n / k) with n >= k gives k distinct indices in [0, n).
            wanted = {i * n_frames // num_images for i in range(num_images)}

        frames_list = []
        last_wanted = max(wanted, default=-1)
        frame_idx = 0
        # Decode sequentially (seeking is unreliable in many codecs) and stop
        # after the last frame we need.
        while frame_idx <= last_wanted:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx in wanted:
                frames_list.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame_idx += 1
        return frames_list
    finally:
        # Release the capture even if decoding or colour conversion raises.
        cap.release()
# Driver: for every <person_id>/<video_id> folder, try its clips in order
# until one yields K frame/landmark pairs. Save that clip as a single PNG
# strip (K frames followed by their K landmark images, left to right), then
# move on to the next video_id. PNGs are numbered consecutively and grouped
# into sub-folders of 256. Directory listings are sorted so the numbering is
# reproducible across runs and filesystems.
for person_id in tqdm(sorted(os.listdir(path_to_mp4))):
    person_dir = os.path.join(path_to_mp4, person_id)
    for video_id in tqdm(sorted(os.listdir(person_dir))):
        video_dir = os.path.join(person_dir, video_id)
        for video in sorted(os.listdir(video_dir)):
            # Assigned outside the try so the error message can always use it.
            video_path = os.path.join(video_dir, video)
            try:
                frame_mark = pick_images(
                    video_path,
                    os.path.join(saves_dir, person_id, video_id, video.split('.')[0]),
                    K)
                frame_mark = generate_landmarks(frame_mark, face_aligner)
                if len(frame_mark) != K:
                    continue  # not enough usable frames; try the next clip
                tiles = [frame for frame, _ in frame_mark] + [lm for _, lm in frame_mark]
                # Every tile must share one shape, so the strip can later be
                # cut back into 2*K equal images.
                if any(tile.shape != tiles[0].shape for tile in tiles):
                    raise ValueError('frame and landmark images differ in size')
                # Horizontal concatenation: (H, 2*K*W, 3).
                strip = np.concatenate(tiles, axis=1)
                # Tiles are RGB; cv2.imwrite expects BGR channel order.
                strip = cv2.cvtColor(strip, cv2.COLOR_RGB2BGR)
                out_dir = os.path.join(saves_dir, str(num_vid // 256))
                os.makedirs(out_dir, exist_ok=True)
                out_path = os.path.join(out_dir, str(num_vid) + '.png')
                if not cv2.imwrite(out_path, strip):
                    raise OSError('cv2.imwrite failed for ' + out_path)
                num_vid += 1
                break  # take only one video
            except Exception:  # best effort per clip; Ctrl-C still stops the run
                print('ERROR: ', video_path)
print('done')