Skip to content

Commit

Permalink
added code for loading model and running predictions in Keras
Browse files Browse the repository at this point in the history
  • Loading branch information
bedapudi6788 committed Feb 14, 2019
1 parent 4055bb1 commit 5441bd4
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions keras_predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import keras
import numpy as np


def load_images(image_paths, image_size):
'''
Function for loading images into numpy arrays for passing to model.predict
inputs:
image_paths: list of image paths to load
image_size: size into which images should be resized
outputs:
loaded_images: loaded images on which keras model can run predictions
loaded_image_indexes: paths of images which the function is able to process
'''
loaded_images = []
loaded_image_paths = []

for i, img_path in enumerate(image_paths):
try:
image = keras.preprocessing.image.load_img(img_path, target_size = image_size)
image = keras.preprocessing.image.img_to_array(image)
image /= 255
loaded_images.append(image)
loaded_image_paths.append(img_path)
except Exception as ex:
print(i, img_path, ex)

return np.asarray(loaded_images), loaded_image_paths

class keras_predictor():
'''
Class for loading model and running predictions.
For example on how to use take a look the if __name__ == '__main__' part.
'''
nsfw_model = None

def __init__(self, model_path):
'''
model = keras_predictor('path_to_weights')
'''
keras_predictor.nsfw_model = keras.models.load_model(model_path)


def predict(self, image_paths = [], batch_size = 32, image_size = (299, 299), categories = ['drawings', 'hentai', 'neutral', 'porn', 'sexy']):
'''
inputs:
image_paths: list of image paths or can be a string too (for single image)
batch_size: batch_size for running predictions
image_size: size to which the image needs to be resized
categories: since the model predicts numbers, categories is the list of actual names of categories
'''
if isinstance(image_paths, str):
image_paths = [image_paths]

loaded_images, loaded_image_paths = load_images(image_paths, image_size)

if not loaded_image_paths:
return {}

model_preds = keras_predictor.nsfw_model.predict(loaded_images, batch_size = batch_size)
preds = np.argmax(model_preds, axis = 1)

probs = []
for i, pred in enumerate(preds):
probs.append(model_preds[i][pred])

preds = [categories[pred] for pred in preds]

images_preds = {}

for i, loaded_image_path in enumerate(loaded_image_paths):
images_preds[loaded_image_path] = {'class': preds[i], 'prob': probs[i]}

return images_preds


if __name__ == '__main__':
print('\n Enter path for the keras weights, leave empty to use "./nsfw.299x299.h5" \n')
weights_path = input().strip()
if not weights_path: weights_path = "./nsfw.299x299.h5"

m = keras_predictor(weights_path)

while 1:
print('\n Enter single image path or multiple images seperated by || (2 pipes) \n')
images = input().split('||')
images = [image.strip() for image in images]
print(m.predict(images), '\n')

0 comments on commit 5441bd4

Please sign in to comment.