-
Notifications
You must be signed in to change notification settings - Fork 26
/
duie_v1.py
282 lines (247 loc) · 9.88 KB
/
duie_v1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
#! -*- coding:utf-8 -*-
# 关系抽取任务,基于GPLinker
# 文章介绍:https://kexue.fm/archives/8888
# 数据集:http://ai.baidu.com/broad/download?dataset=sked
import json
import numpy as np
from bert4keras.backend import keras, K
from bert4keras.backend import sparse_multilabel_categorical_crossentropy
from bert4keras.tokenizers import Tokenizer
from bert4keras.layers import EfficientGlobalPointer as GlobalPointer
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open, to_array
from tqdm import tqdm
# Maximum sequence length passed to the tokenizer (tokenize / encode).
maxlen = 128
# Number of samples per batch handed to data_generator.
batch_size = 64
# Number of training epochs passed to model.fit.
epochs = 20
# Files of the pretrained checkpoint: model config, weights and vocabulary.
config_path = '/root/kg/bert/chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'
def normalize(text):
    """Simple text normalization.

    Splits ``text`` on runs of whitespace and re-joins the pieces with a
    single space, which also drops leading and trailing whitespace.
    """
    pieces = text.split()
    return ' '.join(pieces)
def load_data(filename):
    """Load a dataset stored as one JSON object per line.

    Each line must provide a ``'text'`` field and a ``'spo_list'`` field whose
    entries have ``'subject'``, ``'predicate'`` and ``'object'`` keys.

    Returns a list of samples of the form
    ``{'text': text, 'spoes': [(subject, predicate, object), ...]}`` in which
    the text, subjects and objects have been passed through ``normalize``
    (predicates are kept as-is).
    """
    samples = []
    with open(filename, encoding='utf-8') as reader:
        for raw_line in reader:
            record = json.loads(raw_line)
            text = normalize(record['text'])
            triples = []
            for spo in record['spo_list']:
                subject = normalize(spo['subject'])
                predicate = spo['predicate']
                triples.append((subject, predicate, normalize(spo['object'])))
            samples.append({'text': text, 'spoes': triples})
    return samples
# Load the training and validation sets.
train_data = load_data('/root/kg/datasets/train_data.json')
valid_data = load_data('/root/kg/datasets/dev_data.json')

# Give every distinct predicate in the schema file a contiguous integer id,
# in order of first appearance, and keep the reverse lookup as well.
predicate2id, id2predicate = {}, {}
# The encoding is stated explicitly (as load_data does) so that decoding the
# predicate names does not depend on the platform's default encoding.
with open('/root/kg/datasets/all_50_schemas', encoding='utf-8') as f:
    for line in f:
        predicate = json.loads(line)['predicate']
        if predicate not in predicate2id:
            id2predicate[len(predicate2id)] = predicate
            predicate2id[predicate] = len(predicate2id)

# Build the tokenizer.
tokenizer = Tokenizer(dict_path, do_lower_case=True)
class data_generator(DataGenerator):
    """Data generator producing model inputs and sparse GPLinker labels.
    """
    def __iter__(self, random=False):
        """Yield ``([token_ids, segment_ids], [entity, head, tail])`` batches.

        For every sample the text is lower-cased and tokenized, each triple is
        located in the text (first occurrence of subject and object) and
        converted to token indices, and three groups of label sets are built:

        * entity labels: set 0 holds subject ``(start, end)`` pairs, set 1
          holds object ``(start, end)`` pairs;
        * head labels: one set per predicate with ``(subject_start,
          object_start)`` pairs;
        * tail labels: one set per predicate with ``(subject_end,
          object_end)`` pairs.

        A batch is emitted once ``self.batch_size`` samples are collected or
        the data is exhausted.
        """
        token_batch, segment_batch = [], []
        entity_batch, head_batch, tail_batch = [], [], []
        for is_end, sample in self.sample(random):
            text = sample['text'].lower()
            tokens = tokenizer.tokenize(text, maxlen=maxlen)
            # Presumably one list of character positions per token (empty for
            # tokens without source characters) - verify against rematch.
            char_spans = tokenizer.rematch(text, tokens)
            # Lookup tables: first / last character position -> token index.
            first_char_to_token, last_char_to_token = {}, {}
            for token_idx, span in enumerate(char_spans):
                if span:
                    first_char_to_token[span[0]] = token_idx
                    last_char_to_token[span[-1]] = token_idx
            token_ids = tokenizer.tokens_to_ids(tokens)
            segment_ids = [0] * len(token_ids)

            # Collect the triples as token-level tuples
            # (subject_start, subject_end, predicate_id, object_start, object_end).
            triples = set()
            for subj, pred, obj in sample['spoes']:
                subj, obj = subj.lower(), obj.lower()
                if not (subj in text and obj in text):
                    continue
                subj_start, obj_start = text.index(subj), text.index(obj)
                subj_end = subj_start + len(subj) - 1
                obj_end = obj_start + len(obj) - 1
                # Skip entities whose boundaries do not coincide with token
                # boundaries (this includes entities cut off by maxlen).
                if not (
                    subj_start in first_char_to_token and
                    subj_end in last_char_to_token
                ):
                    continue
                if not (
                    obj_start in first_char_to_token and
                    obj_end in last_char_to_token
                ):
                    continue
                sh = first_char_to_token[subj_start]
                st = last_char_to_token[subj_end]
                oh = first_char_to_token[obj_start]
                ot = last_char_to_token[obj_end]
                if sh <= st and oh <= ot:
                    triples.add((sh, st, predicate2id[pred], oh, ot))

            # Build the labels.
            entity_labels = [set() for _ in range(2)]
            head_labels = [set() for _ in range(len(predicate2id))]
            tail_labels = [set() for _ in range(len(predicate2id))]
            for sh, st, pid, oh, ot in triples:
                entity_labels[0].add((sh, st))
                entity_labels[1].add((oh, ot))
                head_labels[pid].add((sh, oh))
                tail_labels[pid].add((st, ot))
            for label_set in entity_labels + head_labels + tail_labels:
                # Every set needs at least one entry; empty ones get (0, 0).
                if not label_set:
                    label_set.add((0, 0))
            entity_labels = sequence_padding(
                [list(pairs) for pairs in entity_labels]
            )
            head_labels = sequence_padding(
                [list(pairs) for pairs in head_labels]
            )
            tail_labels = sequence_padding(
                [list(pairs) for pairs in tail_labels]
            )

            # Accumulate the batch.
            token_batch.append(token_ids)
            segment_batch.append(segment_ids)
            entity_batch.append(entity_labels)
            head_batch.append(head_labels)
            tail_batch.append(tail_labels)
            if len(token_batch) == self.batch_size or is_end:
                inputs = [
                    sequence_padding(token_batch),
                    sequence_padding(segment_batch),
                ]
                targets = [
                    sequence_padding(entity_batch, seq_dims=2),
                    sequence_padding(head_batch, seq_dims=2),
                    sequence_padding(tail_batch, seq_dims=2),
                ]
                yield inputs, targets
                token_batch, segment_batch = [], []
                entity_batch, head_batch, tail_batch = [], [], []
def globalpointer_crossentropy(y_true, y_pred):
    """Cross-entropy designed for GlobalPointer outputs.

    ``y_true`` holds index pairs in its last axis; each pair ``(a, b)`` is
    flattened to the single index ``a * width + b`` where ``width`` is
    ``K.shape(y_pred)[2]``. ``y_pred`` is reshaped so that all axes from
    index 2 onwards are merged into one, matching the flattened indices.
    The per-head losses are summed over axis 1 and averaged over the batch.
    """
    # assumes y_pred is (batch, heads, seq_len, seq_len) - TODO confirm
    pred_shape = K.shape(y_pred)
    width = K.cast(pred_shape[2], K.floatx())
    flat_true = y_true[..., 0] * width + y_true[..., 1]
    flat_pred = K.reshape(
        y_pred, (pred_shape[0], -1, K.prod(pred_shape[2:]))
    )
    # NOTE(review): the third positional argument is presumably bert4keras'
    # ``mask_zero`` flag - confirm against the library version in use.
    per_head_loss = sparse_multilabel_categorical_crossentropy(
        flat_true, flat_pred, True
    )
    return K.mean(K.sum(per_head_loss, axis=1))
# Load the pretrained model.
base = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    return_keras_model=False
)
encoder_output = base.model.output

# Prediction heads. The entity pointer has two heads (0: subject spans,
# 1: object spans); the head/tail pointers have one head per predicate and
# score (subject, object) start pairs and end pairs respectively.
entity_output = GlobalPointer(heads=2, head_size=64)(encoder_output)
relation_pointer_kwargs = dict(
    heads=len(predicate2id), head_size=64, RoPE=False, tril_mask=False
)
head_output = GlobalPointer(**relation_pointer_kwargs)(encoder_output)
tail_output = GlobalPointer(**relation_pointer_kwargs)(encoder_output)
outputs = [entity_output, head_output, tail_output]

# Build and compile the model.
model = keras.models.Model(base.model.inputs, outputs)
model.compile(loss=globalpointer_crossentropy, optimizer=Adam(2e-5))
model.summary()
def extract_spoes(text, threshold=0):
    """Extract the triples contained in ``text``.

    Runs the model on ``text`` (truncated to ``maxlen`` tokens), keeps every
    span / pair whose score exceeds ``threshold`` and returns a list of
    ``(subject, predicate, object)`` string triples cut out of ``text``.
    """
    tokens = tokenizer.tokenize(text, maxlen=maxlen)
    char_spans = tokenizer.rematch(text, tokens)
    token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
    token_ids, segment_ids = to_array([token_ids], [segment_ids])
    predictions = model.predict([token_ids, segment_ids])
    # Drop the batch axis of each of the three model outputs.
    entity_scores, head_scores, tail_scores = [p[0] for p in predictions]

    # Extract subjects and objects. The first and last positions are pushed
    # to -inf so that no span can start or end there.
    entity_scores[:, [0, -1]] -= np.inf
    entity_scores[:, :, [0, -1]] -= np.inf
    subjects, objects = set(), set()
    for kind, start, end in zip(*np.where(entity_scores > threshold)):
        bucket = subjects if kind == 0 else objects
        bucket.add((start, end))

    # Identify the predicates: a predicate links a subject and an object when
    # both the start pair and the end pair score above the threshold.
    triples = set()
    for sh, st in subjects:
        for oh, ot in objects:
            by_head = np.where(head_scores[:, sh, oh] > threshold)[0]
            by_tail = np.where(tail_scores[:, st, ot] > threshold)[0]
            for p in set(by_head) & set(by_tail):
                subject_text = text[char_spans[sh][0]:char_spans[st][-1] + 1]
                object_text = text[char_spans[oh][0]:char_spans[ot][-1] + 1]
                triples.add((subject_text, id2predicate[p], object_text))
    return list(triples)
class SPO(tuple):
    """Tuple subclass used to store one (subject, predicate, object) triple.

    Indexing, iteration and serialization behave like the underlying tuple.
    Equality and hashing are instead based on ``spox``, the triple with the
    subject and object replaced by their token sequences, so two triples
    whose subject / object strings tokenize identically count as the same
    triple.
    """
    def __init__(self, spo):
        # tuple.__new__ has already stored the raw items; only the key used
        # for comparison and hashing is built here.
        self.spox = (
            tuple(tokenizer.tokenize(spo[0])),
            spo[1],
            tuple(tokenizer.tokenize(spo[2])),
        )

    def __hash__(self):
        return hash(self.spox)

    def __eq__(self, spo):
        try:
            other_key = spo.spox
        except AttributeError:
            # Not an SPO-like object: let Python try the other operand
            # instead of failing with AttributeError.
            return NotImplemented
        return self.spox == other_key

    def __ne__(self, spo):
        # tuple provides its own __ne__ that compares the raw items; without
        # this override ``a != b`` could disagree with ``a == b``.
        equal = self.__eq__(spo)
        if equal is NotImplemented:
            return NotImplemented
        return not equal
def evaluate(data):
    """Evaluation function: compute f1, precision and recall.

    ``data`` is an iterable of samples as produced by ``load_data``
    (``{'text': ..., 'spoes': [(s, p, o), ...]}``); samples that store the
    gold triples under ``'spo_list'`` are accepted as well. For every sample
    the predicted and gold triples are compared as ``SPO`` objects, and one
    JSON record with the gold triples, the predictions, the spurious ones
    (``'new'``) and the missed ones (``'lack'``) is written to
    ``dev_pred.json``.

    Returns the cumulative ``(f1, precision, recall)`` over all samples.
    """
    # Counts start at a tiny epsilon so the ratios are always defined.
    X, Y, Z = 1e-10, 1e-10, 1e-10
    # Defined up front so that an empty ``data`` still returns a result.
    f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
    # Context managers close the file and the progress bar even if
    # prediction fails midway.
    with open('dev_pred.json', 'w', encoding='utf-8') as f, tqdm() as pbar:
        for d in data:
            # load_data stores the gold triples under 'spoes'.
            gold = d['spoes'] if 'spoes' in d else d['spo_list']
            R = set([SPO(spo) for spo in extract_spoes(d['text'])])
            T = set([SPO(spo) for spo in gold])
            X += len(R & T)
            Y += len(R)
            Z += len(T)
            f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
            pbar.update()
            pbar.set_description(
                'f1: %.5f, precision: %.5f, recall: %.5f' %
                (f1, precision, recall)
            )
            s = json.dumps({
                'text': d['text'],
                'spo_list': list(T),
                'spo_list_pred': list(R),
                'new': list(R - T),
                'lack': list(T - R),
            },
                           ensure_ascii=False,
                           indent=4)
            f.write(s + '\n')
    return f1, precision, recall
class Evaluator(keras.callbacks.Callback):
    """Evaluate after every epoch and save the best weights.
    """
    def __init__(self):
        # Run the base initializer so the callback carries the attributes
        # the Keras Callback base class sets up.
        super(Evaluator, self).__init__()
        # Best validation f1 seen so far.
        self.best_val_f1 = 0.

    def on_epoch_end(self, epoch, logs=None):
        """Score ``valid_data`` and save the weights when f1 does not drop.

        The weights are written to ``best_model.weights`` whenever the new
        f1 is greater than or equal to the best value seen so far.
        """
        f1, precision, recall = evaluate(valid_data)
        if f1 >= self.best_val_f1:
            self.best_val_f1 = f1
            model.save_weights('best_model.weights')
        print(
            'f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' %
            (f1, precision, recall, self.best_val_f1)
        )
if __name__ == '__main__':
    # Run as a script: train, evaluating on the validation set each epoch.
    train_generator = data_generator(train_data, batch_size)
    evaluator = Evaluator()
    fit_options = dict(
        steps_per_epoch=len(train_generator),
        epochs=epochs,
        callbacks=[evaluator],
    )
    model.fit(train_generator.forfit(), **fit_options)
else:
    # Imported as a module: restore the weights saved by the Evaluator.
    model.load_weights('best_model.weights')