-
Notifications
You must be signed in to change notification settings - Fork 1
/
classify_nsfw.py
70 lines (49 loc) · 2.28 KB
/
classify_nsfw.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python
# Usage example: python classify_nsfw.py -m data/open_nsfw-weights.npy test.jpg
import sys
import argparse
import tensorflow as tf
from model import OpenNsfwModel, InputType
from image_utils import create_tensorflow_image_loader
from image_utils import create_yahoo_image_loader
import numpy as np
# Values accepted by the -l/--image_loader command-line option; the yahoo
# loader is the default.
IMAGE_LOADER_TENSORFLOW, IMAGE_LOADER_YAHOO = "tensorflow", "yahoo"
def _build_arg_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()

    # Implicit string concatenation: a backslash-newline inside the literal
    # would splice the next line's leading whitespace into the help text.
    parser.add_argument("input_file",
                        help="Path to the input image. "
                             "Only jpeg images are supported.")

    parser.add_argument("-m", "--model_weights", required=True,
                        help="Path to trained model weights file")

    parser.add_argument("-l", "--image_loader",
                        default=IMAGE_LOADER_YAHOO,
                        help="image loading mechanism",
                        choices=[IMAGE_LOADER_YAHOO, IMAGE_LOADER_TENSORFLOW])

    parser.add_argument("-i", "--input_type",
                        default=InputType.TENSOR.name.lower(),
                        help="input type",
                        choices=[InputType.TENSOR.name.lower(),
                                 InputType.BASE64_JPEG.name.lower()])
    return parser


def _load_base64_jpeg(filename):
    """Read *filename* and return its bytes URL-safe base64 encoded.

    The encoded bytes are wrapped in a one-element numpy array, which is the
    value fed to ``model.input`` when the model is built for
    ``InputType.BASE64_JPEG``.
    """
    import base64

    # `with` closes the file; the previous inline open(...).read() leaked it.
    with open(filename, "rb") as image_file:
        return np.array([base64.urlsafe_b64encode(image_file.read())])


def main(argv=None):
    """Classify one image with the OpenNSFW model and print its scores.

    Args:
        argv: Command line in ``sys.argv`` form (program name first, then the
            arguments). ``None`` means use ``sys.argv``.

    Raises:
        ValueError: if the parsed input type has no image loader.
    """
    if argv is None:
        argv = sys.argv
    # Skip the program name, exactly as parse_args() does with sys.argv.
    args = _build_arg_parser().parse_args(argv[1:])

    model = OpenNsfwModel()
    # Extra session created only for the tensorflow image loader.
    loader_sess = None

    try:
        with tf.Session() as sess:
            input_type = InputType[args.input_type.upper()]
            model.build(weights_path=args.model_weights, input_type=input_type)

            if input_type == InputType.TENSOR:
                if args.image_loader == IMAGE_LOADER_TENSORFLOW:
                    # Deliberately not entered with `with`: entering a Session
                    # makes it (and its graph) the default, which would affect
                    # the model ops created below. It is closed in `finally`.
                    loader_sess = tf.Session(graph=tf.Graph())
                    fn_load_image = create_tensorflow_image_loader(loader_sess)
                else:
                    fn_load_image = create_yahoo_image_loader()
            elif input_type == InputType.BASE64_JPEG:
                fn_load_image = _load_base64_jpeg
            else:
                raise ValueError(
                    "Unsupported input type: {}".format(input_type))

            sess.run(tf.global_variables_initializer())

            image = fn_load_image(args.input_file)

            predictions = sess.run(model.predictions,
                                   feed_dict={model.input: image})
    finally:
        if loader_sess is not None:
            loader_sess.close()

    print("Results for '{}'".format(args.input_file))
    print("\tSFW score:\t{}\n\tNSFW score:\t{}".format(*predictions[0]))
if __name__ == "__main__":
    # Script entry point: hand main() the full command line, program name
    # included.
    main(argv=sys.argv)