From 9d7cbd40d67a93ba3b2845b841cd5ff4b5c91b68 Mon Sep 17 00:00:00 2001 From: Colin Dean Date: Thu, 23 Feb 2023 22:44:38 -0500 Subject: [PATCH] Enable passing args to tf.k.Model.predict() when calling predict.classify() Fixes #113 --- nsfw_detector/predict.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/nsfw_detector/predict.py b/nsfw_detector/predict.py index 4a63809..9740a9d 100644 --- a/nsfw_detector/predict.py +++ b/nsfw_detector/predict.py @@ -57,17 +57,24 @@ def load_model(model_path): return model -def classify(model, input_paths, image_dim=IMAGE_DIM): - """ Classify given a model, input paths (could be single string), and image dimensionality....""" +def classify(model, input_paths, image_dim=IMAGE_DIM, predict_args={}): + """ + Classify given a model, input paths (could be single string), and image dimensionality. + + Optionally, pass predict_args that will be passed to tf.keras.Model.predict(). + """ images, image_paths = load_images(input_paths, (image_dim, image_dim)) - probs = classify_nd(model, images) + probs = classify_nd(model, images, predict_args) return dict(zip(image_paths, probs)) -def classify_nd(model, nd_images): - """ Classify given a model, image array (numpy)....""" - - model_preds = model.predict(nd_images) +def classify_nd(model, nd_images, predict_args={}): + """ + Classify given a model, image array (numpy) + + Optionally, pass predict_args that will be passed to tf.keras.Model.predict(). + """ + model_preds = model.predict(nd_images, **predict_args) # preds = np.argsort(model_preds, axis = 1).tolist() categories = ['drawings', 'hentai', 'neutral', 'porn', 'sexy']