Skip to content

Commit

Permalink
Enable passing args to tf.k.Model.predict() when calling predict.clas…
Browse files Browse the repository at this point in the history
…sify()

Fixes #113
  • Loading branch information
colindean authored Feb 24, 2023
1 parent 80bc87d commit 9d7cbd4
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions nsfw_detector/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,24 @@ def load_model(model_path):
return model


def classify(model, input_paths, image_dim=IMAGE_DIM):
""" Classify given a model, input paths (could be single string), and image dimensionality...."""
def classify(model, input_paths, image_dim=IMAGE_DIM, predict_args={}):
"""
Classify given a model, input paths (could be single string), and image dimensionality.
Optionally, pass predict_args that will be passed to tf.keras.Model.predict().
"""
images, image_paths = load_images(input_paths, (image_dim, image_dim))
probs = classify_nd(model, images)
probs = classify_nd(model, images, predict_args)
return dict(zip(image_paths, probs))


def classify_nd(model, nd_images):
""" Classify given a model, image array (numpy)...."""

model_preds = model.predict(nd_images)
def classify_nd(model, nd_images, predict_args={}):
"""
Classify given a model, image array (numpy)
Optionally, pass predict_args that will be passed to tf.keras.Model.predict().
"""
model_preds = model.predict(nd_images, **predict_args)
# preds = np.argsort(model_preds, axis = 1).tolist()

categories = ['drawings', 'hentai', 'neutral', 'porn', 'sexy']
Expand Down

0 comments on commit 9d7cbd4

Please sign in to comment.