-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: evaluate.py
32 lines (25 loc) · 1.16 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
import time

from evaluating import Metrics
from models.bilstm_crf import BILSTM_Model
from utils import save_model
def bilstm_train_and_eval(train_data, dev_data, test_data,
                          word2id, tag2id, crf=True, remove_O=False):
    """Train a BiLSTM(-CRF) tagger, save the checkpoint, and report metrics.

    Args:
        train_data: tuple of (word_lists, tag_lists) used for training.
        dev_data: tuple of (word_lists, tag_lists) used for validation
            during training.
        test_data: tuple of (word_lists, tag_lists) used for the final
            evaluation.
        word2id: mapping from word to integer id; its size defines the
            model's vocabulary size.
        tag2id: mapping from tag to integer id; its size defines the
            model's output dimension.
        crf: if True, build the model with a CRF decoding layer and save
            it as "bilstm_crf.pkl"; otherwise plain "bilstm.pkl".
        remove_O: if True, exclude the 'O' tag when computing metrics.

    Returns:
        The predicted tag lists for the test set.
    """
    train_word_lists, train_tag_lists = train_data
    dev_word_lists, dev_tag_lists = dev_data
    test_word_lists, test_tag_lists = test_data

    start = time.time()
    vocab_size = len(word2id)
    out_size = len(tag2id)
    bilstm_model = BILSTM_Model(vocab_size, out_size, crf=crf)
    bilstm_model.train(train_word_lists, train_tag_lists,
                       dev_word_lists, dev_tag_lists, word2id, tag2id)

    model_name = "bilstm_crf" if crf else "bilstm"
    # Ensure the checkpoint directory exists; saving would otherwise
    # raise FileNotFoundError on a fresh clone.
    os.makedirs("./ckpts", exist_ok=True)
    save_model(bilstm_model, f"./ckpts/{model_name}.pkl")

    print("训练完毕,共用时{}秒.".format(int(time.time() - start)))
    print("评估{}模型中...".format(model_name))
    # NOTE(review): model.test appears to return a (possibly re-ordered or
    # re-padded) copy of the gold tag lists, so the original test_tag_lists
    # is deliberately replaced here — confirm against BILSTM_Model.test.
    pred_tag_lists, test_tag_lists = bilstm_model.test(
        test_word_lists, test_tag_lists, word2id, tag2id)

    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    return pred_tag_lists