-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathclassification_pred.py
executable file
·82 lines (51 loc) · 2.11 KB
/
classification_pred.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/usr/bin/env python2
import os
import imp
import argparse
from convnetskeras.convnets import preprocess_image_batch
import ipdb
def main(model_file, img_list, weights_path, outputFile,
         batch_size=32, batch_per_cache=None):
    """Classify the images listed in a text file and write per-class scores.

    Args:
        model_file: path to a Python script defining a module-level ``model``
            object; the script is loaded as module ``convnet``.
        img_list: text file with one image path per line. Line terminators
            (``\\n`` or ``\\r\\n``) are stripped and blank lines are ignored.
        weights_path: passed unchanged to ``model.load_weights``.
        outputFile: destination file (overwritten). One line per image:
            ``<basename>;0;<score0>;1;<score1>;...``.
        batch_size: number of images preprocessed and predicted together.
        batch_per_cache: accepted for command-line compatibility; currently
            unused.
    """
    # NOTE(review): ``imp`` is deprecated and was removed in Python 3.12; this
    # script targets python2 (see shebang), so it is kept here.
    imp.load_source("convnet", model_file)
    from convnet import model
    model.load_weights(weights_path)

    # Strip only the line terminator. The old ``l[:-1]`` chopped a real
    # character off the last path when the file lacked a trailing newline,
    # and left a stray "\r" on CRLF files. Blank lines are dropped because an
    # empty path would make preprocessing of its whole batch fail.
    with open(img_list, "r") as list_file:
        data = [line.rstrip("\r\n") for line in list_file]
    data = [path for path in data if path]
    print("File paths loaded")

    # ``with`` guarantees the output is flushed and closed even if
    # prediction raises part-way through.
    with open(outputFile, "w") as out:
        for i in range(0, len(data), batch_size):
            print(i)
            batch = data[i:i + batch_size]
            try:
                X = preprocess_image_batch(batch, img_size=(224, 224))
            except Exception as e:
                # Best-effort: skip unreadable batches, but say which one and
                # why instead of swallowing everything with a bare ``except``.
                print("BUG: skipping batch starting at %d (%r)" % (i, e))
                continue
            # Presumably one row of class scores per image in ``batch``.
            Y_pred = model.predict_on_batch(X)
            for path_img, scores in zip(batch, Y_pred):
                out.write(os.path.basename(path_img) + ";")
                out.write(";".join(str(label) + ";" + str(score)
                                   for label, score in enumerate(scores)))
                out.write("\n")
if __name__ == "__main__":
    # Command-line entry point: parse arguments and run the classification.
    parser = argparse.ArgumentParser()
    # Implicit string concatenation previously produced "modelvariable";
    # the trailing space below fixes that.
    parser.add_argument("MODELS",
                        help=("Model script. It should define a model "
                              "variable"))
    # main() reads this file line by line as image paths; it is not an image.
    parser.add_argument("DATA",
                        help="Text file listing the images to classify, "
                             "one path per line")
    parser.add_argument("WEIGHTSPATH",
                        help="Trained weights, passed to model.load_weights")
    parser.add_argument("OUTPUT",
                        help="Output CSV file (semicolon-separated)")
    parser.add_argument("-bs", "--batchsize", type=int, default=32,
                        help="Batch size. Default : 32")
    # main() accepts but never reads batch_per_cache.
    parser.add_argument("-bpc", "--batchpercache", type=int, default=100,
                        help="Number of batches per cache (currently unused). "
                             "Default : 100")
    args = parser.parse_args()
    main(args.MODELS, args.DATA, args.WEIGHTSPATH, args.OUTPUT,
         batch_size=args.batchsize, batch_per_cache=args.batchpercache)