Skip to content

Commit a4193c9

Browse files
committed
update
1 parent 99bbada commit a4193c9

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed

cnn_class2/class_activation_maps.py

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# https://deeplearningcourses.com/c/advanced-computer-vision
2+
# https://www.udemy.com/advanced-computer-vision
3+
4+
from __future__ import print_function, division
5+
from builtins import range, input
6+
# Note: you may need to update your version of future
7+
# sudo pip install -U future
8+
9+
from keras.layers import Input, Lambda, Dense, Flatten
10+
from keras.models import Model
11+
from keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
12+
# from keras.applications.inception_v3 import InceptionV3, preprocess_input
13+
from keras.preprocessing import image
14+
from keras.preprocessing.image import ImageDataGenerator
15+
16+
from sklearn.metrics import confusion_matrix
17+
import numpy as np
18+
import scipy as sp
19+
import matplotlib.pyplot as plt
20+
21+
from glob import glob
22+
23+
from skimage.transform import rescale, resize
24+
25+
26+
27+
# useful for getting number of files
28+
image_files = glob('../large_files/256_ObjectCategories/*/*.jp*g')
29+
image_files += glob('../large_files/101_ObjectCategories/*/*.jp*g')
30+
31+
32+
33+
# look at an image for fun
34+
plt.imshow(image.load_img(np.random.choice(image_files)))
35+
plt.show()
36+
37+
38+
# add preprocessing layer to the front of VGG
39+
resnet = ResNet50(input_shape=(224, 224, 3), weights='imagenet', include_top=True)
40+
41+
# view the structure of the model
42+
# if you want to confirm we need activation_49
43+
resnet.summary()
44+
45+
# make a model to get output before flatten
46+
activation_layer = resnet.get_layer('activation_49')
47+
48+
# create a model object
49+
model = Model(inputs=resnet.input, outputs=activation_layer.output)
50+
51+
# get the feature map weights
52+
final_dense = resnet.get_layer('fc1000')
53+
W = final_dense.get_weights()[0]
54+
55+
56+
while True:
57+
img = image.load_img(np.random.choice(image_files), target_size=(224, 224))
58+
x = preprocess_input(np.expand_dims(img, 0))
59+
fmaps = model.predict(x)[0] # 7 x 7 x 2048
60+
61+
# get predicted class
62+
probs = resnet.predict(x)
63+
classnames = decode_predictions(probs)[0]
64+
print(classnames)
65+
classname = classnames[0][1]
66+
pred = np.argmax(probs[0])
67+
68+
# get the 2048 weights for the relevant class
69+
w = W[:, pred]
70+
71+
# "dot" w with fmaps
72+
cam = fmaps.dot(w)
73+
74+
# upsample to 224 x 224
75+
cam = sp.ndimage.zoom(cam, (32, 32), order=1)
76+
77+
plt.subplot(1,2,1)
78+
plt.imshow(img, alpha=0.8)
79+
plt.imshow(cam, cmap='jet', alpha=0.5)
80+
plt.subplot(1,2,2)
81+
plt.imshow(img)
82+
plt.title(classname)
83+
plt.show()
84+
85+
ans = input("Continue? (Y/n)")
86+
if ans and ans[0].lower() == 'n':
87+
break
88+
89+
90+
91+
# def slowversion(A, w):
92+
# N = len(w)
93+
# result = np.zeros(A.shape[:-1])
94+
# for i in range(N):
95+
# result += A[:,:,i]*w[i]
96+
# return result
97+

0 commit comments

Comments
 (0)