feature_saliency.py
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from datasets.ucf101 import get_class_labels, load_annotation_data
from mean import get_mean, get_std
from model import generate_model
from opts import parse_opts
from plotSaliency import plotHeatMapExampleWise
from spatial_transforms import Compose, Normalize, Scale, ToTensor


def resume_model(opt, model):
    """Load model weights from the checkpoint at opt.resume_path."""
    checkpoint = torch.load(opt.resume_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])


def saliency(clip, model, target_class):
    """Compute gradient-based saliency for the per-frame CNN features of a
    clip (a list of PIL images). Relies on the global `opt` for the
    normalization settings; `target_class` is currently unused because the
    gradient is taken with respect to the top-1 class score.
    """
    if opt.no_mean_norm and not opt.std_norm:
        norm_method = Normalize([0, 0, 0], [1, 1, 1])
    elif not opt.std_norm:
        norm_method = Normalize(opt.mean, [1, 1, 1])
    else:
        norm_method = Normalize(opt.mean, opt.std)
    spatial_transform = Compose([
        Scale((150, 150)),
        ToTensor(opt.norm_value),
        norm_method,
    ])
    clip = [spatial_transform(img) for img in clip]

    model.zero_grad()
    clip = torch.stack(clip, dim=0)  # (T, C, H, W)
    clip = clip.unsqueeze(0)         # (1, T, C, H, W)
    clip.requires_grad = True
    outputs = model(clip, training=False)
    # Per-frame CNN feature maps exposed by the model; the model must retain
    # their gradients for resout[i].grad below to be populated.
    resout = model.resnet_out
    print("resout shape = ", resout[0].shape)
    outputs = F.softmax(outputs, dim=1)
    outputs, _ = torch.topk(outputs, k=1)  # top-1 class probability
    print(outputs)
    outputs.backward()

    grad = []
    saliency = []
    for i in range(clip.shape[1]):
        res_grad = resout[i].grad.data.cpu().numpy()
        print("res_grad shape = ", res_grad.shape)
        grad.append(res_grad)
        saliency.append(np.absolute(res_grad))
    print("sal length = ", len(saliency))
    print("sal ele length = ", saliency[0].shape)
    return np.array(grad).squeeze(), np.array(saliency).squeeze()
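

# A minimal, model-agnostic sketch for cross-checking: saliency() above reads
# gradients from model.resnet_out, which only works when the model retains
# grads on those intermediate features. The helper below instead takes the
# gradient of the top-1 score with respect to the input clip itself. It
# assumes the same forward signature `model(clip, training=False)` used
# above; `input_saliency` is a name introduced here, not part of the repo.
def input_saliency(clip_tensor, model):
    # clip_tensor: already-transformed clip of shape (1, T, C, H, W).
    clip_tensor = clip_tensor.clone().detach().requires_grad_(True)
    model.zero_grad()
    outputs = model(clip_tensor, training=False)
    score, _ = torch.topk(F.softmax(outputs, dim=1), k=1)
    score.backward()
    # Max of |grad| over the channel axis gives one (H, W) map per frame.
    return clip_tensor.grad.abs().max(dim=2)[0].squeeze(0).cpu().numpy()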


def saveSaliency(saliency, folder_name, file_name):
    """Save the saliency array as <folder_name>/<file_name>.npy."""
    os.makedirs(folder_name, exist_ok=True)
    np.save(os.path.join(folder_name, file_name), saliency)
    print("Saliency maps saved as ", file_name)


if __name__ == "__main__":
    opt = parse_opts()
    print(opt)
    save_folder = "./saliency/temporal/"

    data = load_annotation_data(opt.annotation_path)
    class_to_idx = get_class_labels(data)
    device = torch.device("cpu")
    print(class_to_idx)
    idx_to_class = {label: name for name, label in class_to_idx.items()}

    model = generate_model(opt, device)
    ModelTypes = "lstm_cnn"
    if opt.resume_path:
        resume_model(opt, model)
    opt.mean = get_mean(opt.norm_value, dataset=opt.mean_dataset)
    opt.std = get_std(opt.norm_value)

    cam = cv2.VideoCapture(
        './data/kth_trimmed_data/running/0_person01_running_d1_uncomp.avi')
    total_frames = int(cam.get(cv2.CAP_PROP_FRAME_COUNT))
    N = total_frames - 1

    # Collect the first N frames, run saliency once, then keep the last frame.
    clip = []
    for i in range(total_frames):
        ret, img = cam.read()
        if not ret:
            break
        if len(clip) == N:
            grad, sal = saliency(clip, model, 1)
            saveSaliency(sal.squeeze(), save_folder, ModelTypes)
            clip = []
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        clip.append(Image.fromarray(img))
    cam.release()

    # Plot the saliency heat map, rescaled to [0, 255].
    sal = (255 * (sal / np.max(sal))).astype(np.uint8)
    plotHeatMapExampleWise(sal.T, ModelTypes, save_folder)
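
    # Example invocation (a sketch: the flag names are assumptions inferred
    # from the attributes read off `opt` above; the authoritative list lives
    # in opts.py, and the paths are placeholders):
    #   python feature_saliency.py \
    #       --annotation_path <annotation.json> \
    #       --resume_path <checkpoint.pth>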