-
Notifications
You must be signed in to change notification settings - Fork 106
/
cam.py
61 lines (54 loc) · 2.75 KB
/
cam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from keras.models import *
from keras.callbacks import *
import keras.backend as K
from model import *
from data import *
import cv2
import argparse
def train(dataset_path):
    """Train the network on the dataset found at ``dataset_path``.

    The model comes from ``get_model()`` and the training arrays from
    ``load_inria_person(dataset_path)``. The last 20% of the data is held out
    for validation (``validation_split=0.2``), and a ``ModelCheckpoint``
    callback writes the model to ``weights.<epoch>-<val_loss>.hdf5`` in the
    current working directory.

    Args:
        dataset_path: Location of the dataset, passed through unchanged to
            ``load_inria_person``.
    """
    model = get_model()
    X, y = load_inria_person(dataset_path)
    # Parenthesised single-argument form prints identically on Python 2 and 3.
    print("Training..")

    # Keras fills the {epoch} and {val_loss} placeholders for each save.
    checkpoint_path = "weights.{epoch:02d}-{val_loss:.2f}.hdf5"
    checkpoint = ModelCheckpoint(
        checkpoint_path,
        monitor='val_loss',
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode='auto')
    # NOTE(review): `nb_epoch` is the Keras 1.x keyword; newer Keras releases
    # spell it `epochs` -- confirm against the installed version.
    model.fit(X, y, nb_epoch=40, batch_size=32, validation_split=0.2,
              verbose=1, callbacks=[checkpoint])
def visualize_class_activation_map(model_path, img_path, output_path):
    """Compute a class activation map for one image and save it as an overlay.

    The saved model is loaded, the image is pushed through it, and the
    feature maps of layer ``"conv5_3"`` are combined using column 1 of the
    final layer's weight matrix. The resulting map is normalised, resized to
    the image size, colour-mapped (JET), thresholded and blended onto the
    original image, which is then written to ``output_path``.

    Args:
        model_path: Path of the saved Keras model, given to ``load_model``.
        img_path: Path of the input image, read with ``cv2.imread``.
        output_path: Path the blended heatmap image is written to.

    Raises:
        IOError: If ``img_path`` cannot be read as an image.
    """
    model = load_model(model_path)

    original_img = cv2.imread(img_path, 1)
    if original_img is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising; fail here with a message naming the path.
        raise IOError("Could not read image: %s" % img_path)
    # OpenCV images are indexed (row, column, channel).
    rows, cols = original_img.shape[:2]

    # Move channels first and add a leading batch axis: (1, 3, rows, cols).
    img = np.array([np.transpose(np.float32(original_img), (2, 0, 1))])

    # Weight matrix feeding the final layer; indexed [feature_map, class].
    class_weights = model.layers[-1].get_weights()[0]
    final_conv_layer = get_output_layer(model, "conv5_3")
    get_output = K.function(
        [model.layers[0].input],
        [final_conv_layer.output, model.layers[-1].output])
    [conv_outputs, predictions] = get_output([img])
    # Drop the batch axis; the first remaining axis is the feature-map index.
    conv_outputs = conv_outputs[0, :, :, :]

    # Class activation map: weighted sum of the feature maps.
    # NOTE(review): column 1 is presumably the "person" class -- confirm
    # against the label order used by the training data loader.
    cam = np.zeros(dtype=np.float32, shape=conv_outputs.shape[1:3])
    for i, w in enumerate(class_weights[:, 1]):
        cam += w * conv_outputs[i, :, :]

    # Formatted into one string so the output is the same on Python 2 and 3.
    print("predictions %s" % (predictions,))

    # Normalise so the peak is 1. Skip when the peak is not positive:
    # dividing would produce inf/nan or flip the sign of the map.
    peak = np.max(cam)
    if peak > 0:
        cam /= peak

    # cv2.resize takes dsize as (width, height), i.e. (cols, rows).
    cam = cv2.resize(cam, (cols, rows))
    # Negative activations are masked out below anyway; clipping first keeps
    # the uint8 conversion well defined.
    cam = np.clip(cam, 0.0, 1.0)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    # Hide weak activations so only the salient regions are tinted.
    heatmap[cam < 0.2] = 0

    overlay = heatmap * 0.5 + original_img
    cv2.imwrite(output_path, overlay)
def _str_to_bool(value):
    """Convert a command-line string such as "true", "0" or "no" to a bool.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got %r" % value)


def get_args():
    """Parse the command-line options for training or CAM visualisation.

    Options:
        --train         Boolean switch. Accepts ``--train`` on its own or
                        ``--train true/false`` (also yes/no, 1/0). Defaults
                        to False.
        --image_path    Image to run the network on.
        --output_path   Where to write the heatmap (default ``heatmap.jpg``).
        --model_path    Trained model to load.
        --dataset_path  Dataset directory used for training.

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    # `type=bool` would treat every non-empty string (including "False") as
    # True, so the value is parsed explicitly. `nargs='?'` with `const=True`
    # additionally allows a bare `--train`.
    parser.add_argument("--train", type=_str_to_bool, nargs='?', const=True,
                        default=False,
                        help='Train the network or visualize a CAM')
    parser.add_argument("--image_path", type=str,
                        help="Path of an image to run the network on")
    parser.add_argument("--output_path", type=str, default="heatmap.jpg",
                        help="Path to write the output heatmap image to")
    parser.add_argument("--model_path", type=str,
                        help="Path of the trained model")
    parser.add_argument("--dataset_path", type=str,
                        help='Path to image dataset. Should have pos/neg '
                             'folders, like in the inria person dataset. '
                             'http://pascal.inrialpes.fr/data/human/')
    args = parser.parse_args()
    return args
# Script entry point: visualise a class activation map unless --train was
# requested, in which case train the network instead.
if __name__ == '__main__':
    args = get_args()
    if not args.train:
        visualize_class_activation_map(
            args.model_path, args.image_path, args.output_path)
    else:
        train(args.dataset_path)