forked from GantMan/nsfw_model
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added code for loading model and running predictions in Keras
- Loading branch information
1 parent
4055bb1
commit 5441bd4
Showing
1 changed file
with
90 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import keras | ||
import numpy as np | ||
|
||
|
||
def load_images(image_paths, image_size): | ||
''' | ||
Function for loading images into numpy arrays for passing to model.predict | ||
inputs: | ||
image_paths: list of image paths to load | ||
image_size: size into which images should be resized | ||
outputs: | ||
loaded_images: loaded images on which keras model can run predictions | ||
loaded_image_indexes: paths of images which the function is able to process | ||
''' | ||
loaded_images = [] | ||
loaded_image_paths = [] | ||
|
||
for i, img_path in enumerate(image_paths): | ||
try: | ||
image = keras.preprocessing.image.load_img(img_path, target_size = image_size) | ||
image = keras.preprocessing.image.img_to_array(image) | ||
image /= 255 | ||
loaded_images.append(image) | ||
loaded_image_paths.append(img_path) | ||
except Exception as ex: | ||
print(i, img_path, ex) | ||
|
||
return np.asarray(loaded_images), loaded_image_paths | ||
|
||
class keras_predictor(): | ||
''' | ||
Class for loading model and running predictions. | ||
For example on how to use take a look the if __name__ == '__main__' part. | ||
''' | ||
nsfw_model = None | ||
|
||
def __init__(self, model_path): | ||
''' | ||
model = keras_predictor('path_to_weights') | ||
''' | ||
keras_predictor.nsfw_model = keras.models.load_model(model_path) | ||
|
||
|
||
def predict(self, image_paths = [], batch_size = 32, image_size = (299, 299), categories = ['drawings', 'hentai', 'neutral', 'porn', 'sexy']): | ||
''' | ||
inputs: | ||
image_paths: list of image paths or can be a string too (for single image) | ||
batch_size: batch_size for running predictions | ||
image_size: size to which the image needs to be resized | ||
categories: since the model predicts numbers, categories is the list of actual names of categories | ||
''' | ||
if isinstance(image_paths, str): | ||
image_paths = [image_paths] | ||
|
||
loaded_images, loaded_image_paths = load_images(image_paths, image_size) | ||
|
||
if not loaded_image_paths: | ||
return {} | ||
|
||
model_preds = keras_predictor.nsfw_model.predict(loaded_images, batch_size = batch_size) | ||
preds = np.argmax(model_preds, axis = 1) | ||
|
||
probs = [] | ||
for i, pred in enumerate(preds): | ||
probs.append(model_preds[i][pred]) | ||
|
||
preds = [categories[pred] for pred in preds] | ||
|
||
images_preds = {} | ||
|
||
for i, loaded_image_path in enumerate(loaded_image_paths): | ||
images_preds[loaded_image_path] = {'class': preds[i], 'prob': probs[i]} | ||
|
||
return images_preds | ||
|
||
|
||
if __name__ == '__main__': | ||
print('\n Enter path for the keras weights, leave empty to use "./nsfw.299x299.h5" \n') | ||
weights_path = input().strip() | ||
if not weights_path: weights_path = "./nsfw.299x299.h5" | ||
|
||
m = keras_predictor(weights_path) | ||
|
||
while 1: | ||
print('\n Enter single image path or multiple images seperated by || (2 pipes) \n') | ||
images = input().split('||') | ||
images = [image.strip() for image in images] | ||
print(m.predict(images), '\n') |