-
Notifications
You must be signed in to change notification settings - Fork 15
/
eval_sup.py
78 lines (63 loc) · 2.4 KB
/
eval_sup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Batch test script: for each selected speaker, predict meshes from the voice fbank files and save them as .obj files.
import os
import glob
import torch
import torchvision.utils as vutils
import webrtcvad
import scipy.io as sio
import csv
import numpy as np
from mfcc import MFCC
from config import NETWORKS_PARAMETERS
from network import get_network, Generator1D_directMLP
from utils import write_obj_with_colors, voice2face_processed_MeshOut
# initialization
# Voice pre-processing helpers (VAD and MFCC front end).
vad_obj = webrtcvad.Vad(2)
mfc_obj = MFCC(
    nfilt=64,
    lowerf=20.,
    upperf=7200.,
    samprate=16000,
    nfft=1024,
    wlen=0.025,
)

# Voice-side network; get_network returns a pair and only the first item is
# used by this script.
e_net, _ = get_network('e', NETWORKS_PARAMETERS, train=False)

# Generator: instantiate, move to the GPU, switch to eval mode, then restore
# the weights stored at the checkpoint path from the configuration.
g_net = Generator1D_directMLP()
g_net = g_net.cuda()
g_net = g_net.eval()
generator_ckpt_path = NETWORKS_PARAMETERS['g']['model_path']
g_net_ckpt = torch.load(generator_ckpt_path)
g_net.load_state_dict(g_net_ckpt)
# test
# Voice feature folders to evaluate, in sorted order.
voice_list = sorted(glob.glob('data/fbank/*'))
# Triangle array read from the 'tri' entry of the .mat file; passed to the
# mesh writer as `triangles`.
tri = sio.loadmat('./train.configs/tri.mat')['tri']

# Map the first metadata column to the second one. The evaluation loop uses
# this to translate an fbank folder name into the name used by the
# ground-truth folders.
id_name = {}
# NOTE(review): confirm that the delimiter below matches the real layout of
# data/vox1_meta.csv -- TODO verify against the file.
# The file is opened with newline='' as the csv module documentation
# recommends, and closed deterministically by the context manager.
with open('data/vox1_meta.csv', newline='') as csv_file:
    rows = csv.reader(csv_file, delimiter=' ')
    # The first row is a header; the default avoids StopIteration on an
    # empty file.
    headers = next(rows, None)
    for row in rows:
        if not row:
            # csv.reader yields an empty list for blank lines.
            continue
        id_name[row[0]] = row[1]

# Names of the entries under data/A2E_val, kept in a set because they are only
# used for membership tests.
available_GT = {os.path.basename(path) for path in glob.glob('data/A2E_val/*')}

# [TODO] Change this variable to your result output folder
FOLDER_ROOT = 'supervised_output/'
# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs(FOLDER_ROOT, exist_ok=True)
# Folders whose name sorts after this id are not evaluated.
LAST_EVAL_ID = 'id10309'  # The end of E is 10309
# Number of meshes written for each speaker.
MESHES_PER_SPEAKER = 3

# For every speaker folder, predict a mesh from its first few fbank files and
# save each prediction as an .obj file under FOLDER_ROOT/<corr_name>/.
for folder in voice_list:
    index = os.path.basename(folder)
    print(index)
    # Plain string comparison; voice_list is sorted, so once one folder is
    # past the limit every remaining folder is too.
    if index > LAST_EVAL_ID:
        break

    # Raises KeyError if the folder name is missing from the metadata table.
    corr_name = id_name[index]
    # check if the fbank id is in the fitted model database
    if corr_name not in available_GT:
        continue

    speaker_dir = os.path.join(FOLDER_ROOT, corr_name)
    os.makedirs(speaker_dir, exist_ok=True)

    count = 0
    for sequence in sorted(glob.glob(os.path.join(folder, '*'))):
        print(sequence)
        sequence_name = os.path.basename(sequence)
        for fbank in sorted(glob.glob(os.path.join(sequence, '*.npy'))):
            print(fbank)
            # File name without its '.npy' extension.
            fbank_name = os.path.splitext(os.path.basename(fbank))[0]
            # The result is detached right away, so this script never needs
            # gradients; skip building the autograd graph.
            with torch.no_grad():
                prediction = voice2face_processed_MeshOut(
                    e_net, g_net, fbank, NETWORKS_PARAMETERS['GPU']
                ).squeeze().detach().cpu()
            save_name = os.path.join(
                speaker_dir, sequence_name + '_' + fbank_name
            )
            write_obj_with_colors(save_name + '.obj', prediction, triangles=tri)
            count += 1
            # the first three in all the fbank sequences
            if count >= MESHES_PER_SPEAKER:
                break
        # Propagate the inner break: stop scanning this speaker's sequences.
        if count >= MESHES_PER_SPEAKER:
            break