-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathpredict.py
99 lines (91 loc) · 3.82 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# -*- coding:UTF-8 -*-
from src.predictor import Predictor
import re
from argparse import ArgumentParser
from tqdm import tqdm
import time
import deepspeed
def read_batch(path, batch_size, segmented):
    """Yield batches of tokenized lines read from *path*.

    Each line is stripped, then split into tokens: on single spaces when
    *segmented* is truthy, otherwise into individual characters.

    Args:
        path: text file to read (UTF-8, one sentence per line).
        batch_size: maximum number of lines per yielded batch.
        segmented: truthy if the input is already space-segmented.

    Yields:
        Lists of token lists, each at most *batch_size* long. The final
        partial batch is yielded too. Unlike the previous version, an
        empty batch is never yielded (e.g. for an empty input file).
    """
    batch = []
    with open(path, "r", encoding="utf8") as fr:
        for line in fr:
            line = line.strip()
            tokens = line.split(" ") if segmented else list(line)
            batch.append(tokens)
            if len(batch) == batch_size:
                yield batch
                batch = []
    if batch:
        yield batch
def detokenize(text: str) -> str:
    """Collapse WordPiece-style output back into a plain string.

    First joins sub-tokens by deleting " ##" when it is immediately
    followed by a non-space character, then removes all remaining
    whitespace.

    Args:
        text: space-joined token string, possibly with "##" sub-token
            continuation markers.

    Returns:
        The detokenized string with no whitespace.
    """
    # Raw strings: "\S"/"\s" in plain literals are invalid escape
    # sequences and raise SyntaxWarning on recent Python versions.
    text = re.sub(r" ##(?=\S)", "", text)
    return re.sub(r"\s+", "", text)
def main(args):
    """Run batched prediction over ``args.input_path``.

    Reads the input in batches, corrects each batch with ``Predictor``,
    and emits one prediction per line: written to ``args.out_path`` when
    given, printed to stdout otherwise. When writing to a file, a
    prediction containing no token characters (only spaces and '#')
    falls back to the source line.

    Args:
        args: parsed CLI namespace (see the argument parser below).
    """
    predictor = Predictor(args)
    cnt_corrections = 0
    print(f"model path: {args.ckpt_path}")
    print("start predicting ...")
    start = time.time()
    # Open lazily; try/finally guarantees the handle is closed even if
    # prediction raises (the previous version leaked it on error).
    fw = open(args.out_path, "w", encoding="utf8") if args.out_path else None
    try:
        for batch_text in tqdm(read_batch(args.input_path, args.batch_size, args.segmented)):
            pred_batch, cnt = predictor.handle_batch(batch_text)
            cnt_corrections += cnt
            for idx, pred_tokens in enumerate(pred_batch):
                pred_line = " ".join(pred_tokens)
                # Fallback check is only done on the file-writing path,
                # matching the original behavior.
                if fw is not None and not re.search("[^ #]", pred_line):
                    pred_line = " ".join(batch_text[idx])
                    print(
                        "prediction for current line is none, thus we replace it with source line...")
                    # NOTE(review): args.ckpt_id is not defined by the
                    # parser below — this line likely raises
                    # AttributeError; confirm (ckpt_path intended?).
                    print(args.ckpt_id)
                    print(pred_line)
                if bool(args.detokenize):
                    pred_line = detokenize(pred_line)
                if fw is not None:
                    fw.write(pred_line + "\n")
                else:
                    print(pred_line)
    finally:
        if fw is not None:
            fw.close()
    print(f"total cost: {time.time() - start}s")
if __name__ == "__main__":
    # CLI entry point. Flags mirror the Predictor's configuration:
    # decoding knobs, vocabulary/checkpoint paths, and I/O options.
    cli = ArgumentParser()
    cli.add_argument("--local_rank", type=int)  # injected by the launcher
    cli.add_argument("--device", type=str, default=None)
    cli.add_argument("--batch_size", type=int, required=True)
    cli.add_argument("--iteration_count", type=int, default=5)
    cli.add_argument("--min_seq_len", type=int, default=3,
                     help="<= min_seq_len will be skipped")
    cli.add_argument("--max_num_tokens", type=int, default=128,
                     help="max seq length after tokenization")
    cli.add_argument("--min_error_probability", type=float, default=0.0)
    cli.add_argument("--additional_confidence", type=float, default=0.0)
    cli.add_argument("--sub_token_mode", type=str, default="average")
    cli.add_argument("--max_pieces_per_token", type=int, default=5)
    cli.add_argument("--unk2keep", type=int, default=0,
                     help="replace oov label with keep")
    cli.add_argument("--ckpt_path", type=str, required=True)
    cli.add_argument("--detect_vocab_path", type=str, required=True)
    cli.add_argument("--correct_vocab_path", type=str, required=True)
    cli.add_argument("--pretrained_transformer_path", type=str, required=True)
    cli.add_argument("--input_path", type=str, required=True)
    cli.add_argument("--out_path", type=str, default=None)
    cli.add_argument("--special_tokens_fix", type=int, default=0)
    cli.add_argument("--segmented", type=int, default=0)
    cli.add_argument("--detokenize", type=int, default=0)
    # Let deepspeed register its own config flags on the same parser.
    cli = deepspeed.add_config_arguments(cli)
    main(cli.parse_args())