-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
33 lines (26 loc) · 1.25 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from utils import load_model, extend_maps, prepocess_data_for_lstmcrf
from data import build_corpus
from evaluating import Metrics
# Paths of the pickled model checkpoints.
BiLSTM_MODEL_PATH: str = "./ckpts/bilstm.pkl"  # plain BiLSTM; not loaded by main()
BiLSTMCRF_MODEL_PATH: str = "./ckpts/bilstm_crf.pkl"  # BiLSTM-CRF; evaluated by main()

# Whether to drop the "O" tag when computing evaluation metrics
# (passed to Metrics as remove_O).
REMOVE_O: bool = False
def main(*, model_path=None, remove_O=None):
    """Load the saved BiLSTM-CRF model and evaluate it on the test split.

    The training split is read only to obtain the ``word2id`` / ``tag2id``
    maps. The maps are extended for CRF use, the model is loaded from disk
    and run on the preprocessed test data. The predictions are then scored
    with ``Metrics`` (``report_scores`` and ``report_confusion_matrix``).

    Args:
        model_path: Path of the pickled model to load. ``None`` (default)
            means ``BiLSTMCRF_MODEL_PATH``.
        remove_O: Forwarded to ``Metrics(remove_O=...)``. ``None`` (default)
            means the module-level ``REMOVE_O`` setting.
    """
    # Resolve defaults at call time so changes to the module constants
    # are still honoured, exactly as when they were read inline.
    if model_path is None:
        model_path = BiLSTMCRF_MODEL_PATH
    if remove_O is None:
        remove_O = REMOVE_O

    print("读取数据...")
    # Only the vocabulary / tag maps are needed from the training split;
    # its word and tag lists are not used here.
    _, _, word2id, tag2id = build_corpus("train")
    test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False)

    print("加载并评估bilstm+crf模型...")
    crf_word2id, crf_tag2id = extend_maps(word2id, tag2id, for_crf=True)
    bilstm_crf_model = load_model(model_path)
    # The original note says this call removes a warning. It is presumably the
    # PyTorch RNN "weights are not part of single contiguous chunk" warning
    # seen after unpickling (TODO confirm).
    bilstm_crf_model.model.bilstm.bilstm.flatten_parameters()

    test_word_lists, test_tag_lists = prepocess_data_for_lstmcrf(
        test_word_lists, test_tag_lists, test=True
    )
    lstmcrf_pred, target_tag_list = bilstm_crf_model.test(
        test_word_lists, test_tag_lists, crf_word2id, crf_tag2id
    )

    metrics = Metrics(target_tag_list, lstmcrf_pred, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()