
Commit 4e891a8

Authored by Jiaqiang-Ruan, hunterhector, and jasonyanwenl
Support prediction for NER new task. (#331)
* RJQ: [ner_new_task] predict process for ner task
* apply cherry-pick yanwen & RJQ: [main-train-tagging] move prediction to folder tagging

Co-authored-by: Hector <hunterhector@gmail.com>
Co-authored-by: Yanwen Lin <lyw1124278064@gmail.com>
1 parent 1e3ba68 commit 4e891a8

File tree

3 files changed (+200 -0 lines)
@@ -0,0 +1,4 @@
test_path: "data/conll03_english/test"
model_path: "best_crf_model.ckpt"
train_state_path: "train_state.pkl"
batch_size: 10
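
For orientation, here is a minimal sketch (not part of the commit) of how these four settings are consumed; it simply mirrors the loading code in the prediction script later in this diff, which opens the file as configs/config_predict.yml relative to the examples/tagging directory.

import yaml
import torch

# Read the settings listed above; the relative path follows the
# open("configs/config_predict.yml") call in the prediction script below.
with open("configs/config_predict.yml", "r") as f:
    config_predict = yaml.safe_load(f)

print(config_predict["test_path"])    # "data/conll03_english/test"
print(config_predict["batch_size"])   # 10

# The two checkpoint paths are then handed to torch.load, as the script does.
saved_model = torch.load(config_predict["model_path"])
train_state = torch.load(config_predict["train_state_path"])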

examples/tagging/evaluator.py (+124 lines)
@@ -0,0 +1,124 @@
# Copyright 2020 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable-msg=too-many-locals
"""Evaluator for Conll03 NER tag."""
import os
from pathlib import Path
from forte.data.base_pack import PackType
from forte.evaluation.base import Evaluator
from forte.data.extractor.utils import bio_tagging
from ft.onto.base_ontology import Sentence, Token, EntityMention


def _post_edit(element):
    # bio_tagging yields (entry, prefix) pairs; map them to "O" or to
    # "<prefix>-<ner_type>" strings such as "B-PER".
    if element[0] is None:
        return "O"
    return "%s-%s" % (element[1], element[0].ner_type)


def _get_tag(data, pack):
    # Rebuild BIO tags for one sentence from its Token and EntityMention ids.
    based_on = [pack.get_entry(x) for x in data["Token"]["tid"]]
    entry = [pack.get_entry(x) for x in data["EntityMention"]["tid"]]
    tag = bio_tagging(based_on, entry)
    tag = [_post_edit(x) for x in tag]
    return tag


def _write_tokens_to_file(pred_pack, pred_request,
                          refer_pack, refer_request,
                          output_filename):
    # One token per line: index, word, POS, chunk, gold tag, predicted tag;
    # sentences are separated by a blank line. This is the column layout
    # consumed by the conll03eval.v2 script below.
    opened_file = open(output_filename, "w+")
    for pred_data, refer_data in zip(
            pred_pack.get_data(**pred_request),
            refer_pack.get_data(**refer_request)):
        pred_tag = _get_tag(pred_data, pred_pack)
        refer_tag = _get_tag(refer_data, refer_pack)
        words = refer_data["Token"]["text"]
        pos = refer_data["Token"]["pos"]
        chunk = refer_data["Token"]["chunk"]

        for i, (word, position, chun, tgt, pred) in \
                enumerate(zip(words, pos, chunk, refer_tag, pred_tag), 1):
            opened_file.write(
                "%d %s %s %s %s %s\n" % (i, word, position, chun, tgt, pred)
            )
        opened_file.write("\n")
    opened_file.close()


class CoNLLNEREvaluator(Evaluator):
    """Evaluator for Conll NER task."""
    def __init__(self):
        super().__init__()
        # self.test_component = CoNLLNERPredictor().name
        self.output_file = "tmp_eval.txt"
        self.score_file = "tmp_eval.score"
        self.scores = {}

    def consume_next(self, pred_pack: PackType, ref_pack: PackType):
        pred_getdata_args = {
            "context_type": Sentence,
            "request": {
                Token: {
                    "fields": ["chunk", "pos"]
                },
                EntityMention: {
                    "fields": ["ner_type"],
                },
                Sentence: [],  # span by default
            }
        }

        refer_getdata_args = {
            "context_type": Sentence,
            "request": {
                Token: {
                    "fields": ["chunk", "pos", "ner"]
                },
                EntityMention: {
                    "fields": ["ner_type"],
                },
                Sentence: [],  # span by default
            }
        }

        _write_tokens_to_file(pred_pack=pred_pack,
                              pred_request=pred_getdata_args,
                              refer_pack=ref_pack,
                              refer_request=refer_getdata_args,
                              output_filename=self.output_file)
        # Run the official CoNLL-2003 evaluation script and redirect its
        # report to the score file.
        eval_script = \
            Path(os.path.abspath(__file__)).parents[2] / \
            "forte/utils/eval_scripts/conll03eval.v2"
        os.system(f"perl {eval_script} < {self.output_file} > "
                  f"{self.score_file}")
        # The second line of the report has the form:
        #   accuracy:  XX.XX%; precision:  XX.XX%; recall:  XX.XX%; FB1:  XX.XX
        with open(self.score_file, "r") as fin:
            fin.readline()
            line = fin.readline()
            fields = line.split(";")
            acc = float(fields[0].split(":")[1].strip()[:-1])
            precision = float(fields[1].split(":")[1].strip()[:-1])
            recall = float(fields[2].split(":")[1].strip()[:-1])
            f_1 = float(fields[3].split(":")[1].strip())

            self.scores = {
                "accuracy": acc,
                "precision": precision,
                "recall": recall,
                "f1": f_1,
            }

    def get_result(self):
        return self.scores
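
To make the score parsing in consume_next concrete, here is a small self-contained sketch (not part of the commit) that applies the same string handling to an illustrative conll03eval.v2 report line; the numbers are invented.

# Illustrative second line of a conll03eval.v2 report (values are made up).
sample_line = "accuracy:  97.84%; precision:  90.12%; recall:  89.45%; FB1:  89.78"

fields = sample_line.split(";")
scores = {
    "accuracy": float(fields[0].split(":")[1].strip()[:-1]),   # drop trailing '%'
    "precision": float(fields[1].split(":")[1].strip()[:-1]),
    "recall": float(fields[2].split(":")[1].strip()[:-1]),
    "f1": float(fields[3].split(":")[1].strip()),               # FB1 carries no '%'
}
print(scores)  # {'accuracy': 97.84, 'precision': 90.12, 'recall': 89.45, 'f1': 89.78}
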
+72 lines
@@ -0,0 +1,72 @@
# Copyright 2020 The Forte Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file predicts the NER tags for the conll03 dataset."""
import sys
import yaml
import torch
from forte.pipeline import Pipeline
from forte.data.readers.conll03_reader_new import CoNLL03Reader
from forte.predictor import Predictor
from ft.onto.base_ontology import Sentence, EntityMention, Token
from examples.tagging.evaluator import CoNLLNEREvaluator


def predict_forward_fn(model, batch):
    """Use the model and batch data to predict NER tags."""
    word = batch["text_tag"]["data"]
    char = batch["char_tag"]["data"]
    word_masks = batch["text_tag"]["masks"][0]
    output = model.decode(input_word=word, input_char=char, mask=word_masks)
    output = output.numpy()
    return {"output_tag": output}


# The task ("ner" or "pos") is passed as the first command-line argument.
task = sys.argv[1]
assert task in ["ner", "pos"], \
    "Unsupported nlp task type: {}".format(task)

# Load the prediction settings, the trained model, and the saved train state.
config_predict = yaml.safe_load(open("configs/config_predict.yml", "r"))
saved_model = torch.load(config_predict["model_path"])
train_state = torch.load(config_predict["train_state_path"])

reader = CoNLL03Reader()
predictor = Predictor(batch_size=config_predict["batch_size"],
                      model=saved_model,
                      predict_forward_fn=predict_forward_fn,
                      feature_resource=train_state["feature_resource"])
evaluator = CoNLLNEREvaluator()


# Assemble the pipeline: read the test data, predict, then evaluate.
pl = Pipeline()
pl.set_reader(reader)
pl.add(predictor)
pl.add(evaluator)
pl.initialize()


# Print the predicted tags per sentence, then the overall evaluation scores.
for pack in pl.process_dataset(config_predict["test_path"]):
    print("---- pack ----")
    for instance in pack.get(Sentence):
        sent = instance.text
        output_tags = []
        if task == "ner":
            for entry in pack.get(EntityMention, instance):
                output_tags.append((entry.text, entry.ner_type))
        else:
            for entry in pack.get(Token, instance):
                output_tags.append((entry.text, entry.pos))
        print("---- example -----")
        print("sentence: ", sent)
        print("output_tags: ", output_tags)
print(evaluator.get_result())
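
This diff does not show the script's file name. Assuming it is saved under examples/tagging as, say, main_predict_tagging.py (a hypothetical name), it would be run from that directory so the relative paths in the prediction config resolve, roughly as:

cd examples/tagging
python main_predict_tagging.py ner    # hypothetical file name; pass "pos" for POS tags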
