-
Notifications
You must be signed in to change notification settings - Fork 479
/
validate_cnn.py
50 lines (41 loc) · 1.41 KB
/
validate_cnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
"""
Classify a few images through our CNN.
"""
import numpy as np
import operator
import random
import glob
import os.path
from data import DataSet
from processor import process_image
from keras.models import load_model
def main(nb_images=5):
    """Spot-check `nb_images` randomly chosen test images with the saved CNN.

    For every draw this prints a separator line, the path of the chosen
    image, and the five highest-scoring classes with their scores. Images
    are drawn independently, so the same file may be picked more than once.

    Args:
        nb_images: Number of images to classify.

    Raises:
        ValueError: If at least one image is requested but no ``.jpg`` file
            matches the test-image glob pattern.
    """
    data = DataSet()
    model = load_model('data/checkpoints/inception.057-1.16.hdf5')

    # Get all our test images. glob is called without recursive=True, so the
    # '**' component behaves like '*': this matches data/test/<dir>/*.jpg.
    pattern = os.path.join('data', 'test', '**', '*.jpg')
    images = glob.glob(pattern)

    # Fail with an actionable message instead of the opaque "empty range"
    # error that drawing from an empty list would otherwise produce.
    if nb_images > 0 and not images:
        raise ValueError("No test images found matching %r" % pattern)

    for _ in range(nb_images):
        print('-' * 80)

        # Get a random image.
        image = random.choice(images)
        print(image)

        # Turn the image into an array and add a leading axis so the model
        # receives a batch containing this single image.
        image_arr = process_image(image, (299, 299, 3))
        image_arr = np.expand_dims(image_arr, axis=0)

        # Predict.
        predictions = model.predict(image_arr)
        # Row 0 holds the scores for our one image
        # (presumably one score per entry in data.classes -- TODO confirm).
        scores = predictions[0]

        # Show how much we think it's each one. Indexing (rather than zip)
        # keeps a class/score count mismatch loud instead of truncating.
        label_predictions = {
            label: scores[i] for i, label in enumerate(data.classes)
        }
        sorted_lps = sorted(
            label_predictions.items(),
            key=operator.itemgetter(1),
            reverse=True,
        )

        # Just get the top five.
        for label, score in sorted_lps[:5]:
            print("%s: %.2f" % (label, score))
# Run the spot check only when executed as a script, not when imported.
if __name__ == "__main__":
    main()