# muses_fusion.py
# Fuses per-video I3D RGB features with audio features (AvgTrim / DupTrim / Concat).
import argparse
import os
from glob import glob

import torch
from tqdm import tqdm
def avg_trim(video_feats, audio_feats):
    """Average-pool the audio rows down to the video length, then concatenate.

    The audio rows are split into consecutive, non-overlapping windows of
    ``k = len(audio) // len(video)`` rows; each window is averaged into one
    row, and both streams are trimmed to the same number of rows before being
    concatenated along dim 1 (video columns first).

    Both inputs are expected to be 2-D tensors whose dim 0 is the axis being
    aligned (presumably time -- confirm against the feature extractors).

    Differences from the previous inline implementation:
      * windows advance by ``k`` rows (the stride used to be a hard-coded 2,
        which overlapped or skipped rows whenever ``k != 2``);
      * the audio width is taken from the tensor instead of assuming 128;
      * when the audio is shorter than the video, ``k`` falls back to 1
        (plain trim) instead of averaging empty slices into NaN rows;
      * empty inputs yield an empty result instead of ZeroDivisionError.
    """
    n_video, n_audio = video_feats.shape[0], audio_feats.shape[0]
    window = max(1, n_audio // max(1, n_video))
    # Only full windows are kept: whenever n_audio >= n_video there are at
    # least n_video of them, so a trailing partial window is always trimmed.
    n_windows = min(n_audio // window, n_video)
    pooled = (
        audio_feats[: n_windows * window]
        .reshape(n_windows, window, audio_feats.shape[1])
        .mean(dim=1)
    )
    return torch.cat([video_feats[:n_windows], pooled], dim=1)


def dup_trim(video_feats, audio_feats):
    """Duplicate rows of the shorter stream, trim to equal length, concatenate.

    Each row of the shorter stream is repeated ``longer // shorter`` times in
    place (a, a, b, b, ...), both streams are cut to the common number of
    rows, and the result is concatenated along dim 1 (video columns first).
    If either stream has no rows, nothing is duplicated and the result has
    zero rows (previously this raised ZeroDivisionError).
    """
    n_video, n_audio = video_feats.shape[0], audio_feats.shape[0]
    if n_video and n_audio:
        if n_audio > n_video:
            video_feats = video_feats.repeat_interleave(n_audio // n_video, dim=0)
        else:
            audio_feats = audio_feats.repeat_interleave(n_video // n_audio, dim=0)
    common = min(video_feats.shape[0], audio_feats.shape[0])
    return torch.cat([video_feats[:common], audio_feats[:common]], dim=1)


def concat_feats(video_feats, audio_feats):
    """Concatenate the two streams along dim 1 without any length alignment.

    ``torch.cat`` raises if the two tensors differ in size along dim 0.
    """
    return torch.cat([video_feats, audio_feats], dim=1)


# Maps the --type command-line value to the fusion routine it selects.
FUSION_MODES = {
    '0': avg_trim,
    '1': dup_trim,
    '2': concat_feats,
}


def build_parser():
    """Return the command-line parser for this script.

    ``--type`` is limited to the keys of ``FUSION_MODES`` so that an invalid
    value is rejected by argparse itself, rather than by an ``assert`` (which
    is stripped under ``python -O``) or by silently matching no mode.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--type', type=str, default='0', choices=sorted(FUSION_MODES),
                        help='fusion type')  # 0: AvgTrim 1: DupTrim 2: Concat
    return parser


def main(argv=None):
    """Fuse every file under ``I3D_RGB/`` with its ``AudioFeats/`` counterpart.

    For each path matched by ``I3D_RGB/*`` the audio file is looked up by
    replacing ``I3D_RGB`` with ``AudioFeats`` in the path, and the fused
    tensor is saved under the same path with ``combinedFeats`` substituted.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    opt = build_parser().parse_args(argv)
    fuse = FUSION_MODES[opt.type]
    for video_path in tqdm(glob('I3D_RGB/*')):
        # NOTE(security): torch.load unpickles its input; only run this on
        # feature files from a trusted source.
        video_feats = torch.load(video_path)
        audio_feats = torch.load(video_path.replace('I3D_RGB', 'AudioFeats'))
        out_path = video_path.replace('I3D_RGB', 'combinedFeats')
        # Create the output directory on first use instead of failing in save.
        os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
        torch.save(fuse(video_feats, audio_feats), out_path)


if __name__ == "__main__":
    main()