-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathtest.py
85 lines (61 loc) · 1.95 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import os
from utils import util
import multiprocessing
from predictor import Predictor
# Directory whose entries are listed by the driver below; each entry is opened
# and read line by line as JSON.
data_path = "input_path" # The directory of the input data
# Directory in which the driver writes one result file per input file, using
# the same file name as the input.
output_path = "output_path" # The directory of the output data
def format_result(result):
    """Normalise one raw prediction into a plain, JSON-serialisable dict.

    Args:
        result: Mapping providing the keys ``"accusation"`` and
            ``"articles"`` (iterables whose items are either ``None`` or
            convertible with ``int()``) and ``"imprisonment"`` (``None`` or
            a value convertible with ``int()``).

    Returns:
        A new dict with the keys ``"accusation"``, ``"articles"`` and
        ``"imprisonment"``, in that order. The two lists contain the
        non-``None`` input items converted to ``int``; ``"imprisonment"`` is
        ``int(result["imprisonment"])``, or ``-3`` when that value is
        ``None``.

    Raises:
        KeyError: If ``result`` lacks one of the three keys.
        TypeError, ValueError: Propagated from ``int()`` when an item cannot
            be converted.
    """
    # Keys are read in the same order as before (accusation, imprisonment,
    # articles) so a malformed prediction fails on the same key.
    accusation = [int(x) for x in result["accusation"] if x is not None]

    raw_imprisonment = result["imprisonment"]
    imprisonment = -3 if raw_imprisonment is None else int(raw_imprisonment)

    articles = [int(x) for x in result["articles"] if x is not None]

    # Insertion order determines the key order of the serialised JSON line.
    return {
        "accusation": accusation,
        "articles": articles,
        "imprisonment": imprisonment,
    }
if __name__ == "__main__":
    # Driver: for every file in ``data_path``, read one JSON object per line,
    # collect its "fact" field into batches, run the predictor on each batch
    # and write one JSON line per prediction to the same-named file in
    # ``output_path``.
    user = Predictor()
    cnt = 0  # running total of predictions written across all files

    def get_batch():
        """Return ``user.batch_size`` after validating it.

        Raises:
            NotImplementedError: If the value is not a positive ``int``.
        """
        v = user.batch_size
        # Exact type check on purpose: isinstance() would also accept bool.
        if type(v) is not int or v <= 0:
            raise NotImplementedError
        return v

    def solve(fact):
        """Predict a batch of facts and normalise every prediction in place."""
        result = user.predict(fact)
        for a, raw in enumerate(result):
            result[a] = format_result(raw)
        return result

    def write_batch(fact, ouf):
        """Predict ``fact``, write one JSON line per result to ``ouf``.

        Returns the number of results written.
        """
        result = solve(fact)
        for x in result:
            print(json.dumps(x), file=ouf)
        return len(result)

    for file_name in os.listdir(data_path):
        # ``with`` closes both files even if prediction or parsing raises.
        # The input is decoded as UTF-8 explicitly instead of relying on the
        # platform's locale default; json.dumps() emits ASCII by default, so
        # the output bytes do not depend on the encoding chosen here.
        with open(os.path.join(data_path, file_name), "r", encoding="utf-8") as inf, \
                open(os.path.join(output_path, file_name), "w", encoding="utf-8") as ouf:
            fact = []
            for line in inf:
                fact.append(json.loads(line)["fact"])
                # The batch size is re-read for every line, as before, so an
                # invalid value is reported as soon as the first line is read.
                if len(fact) == get_batch():
                    cnt += write_batch(fact, ouf)
                    fact = []

            # Flush the final, partially filled batch.
            if fact:
                cnt += write_batch(fact, ouf)

    if util.DEBUG:
        print("DEBUG: prediction work finished.")