图1 模型整体框架
使用NLP社区中先进的抽取算法进行三元组的抽取,代码不在这里给出
核心思想是:将树生成问题转化为填表问题,表中任意一个单元格的label表示对于两个三元组之间的关系,当模型填表完成后,通过解码算法将树从表中解码出来,具体如下:
在抽取出文本中的三元组后,使用一个双仿射模型来预测任意两个三元组在诊疗决策树中的关系。两个三元组在决策树中的关系包含四种:
图2 解码框架
解码分为三个步骤:1节点解码、逻辑关系预测与树结构解码,见图2。
节点解码:节点解码的关键在于,处于同一节点的三元组,它们和其他任一三元组的在诊疗决策树中关系都是相同的,因此它们在概率张量
逻辑关系预测:如果多个三元组属于同一节点,需要预测多个三元组之间的逻辑关系。
树结构解码:本文首先将三元组对的概率张量
其中,
解码的最后阶段是将
数据集获取:
将数据置于data/Text2DT下
python Text2DT_TreeDecoder.py \
--config_file config.yml \
--save_dir ckpt/Text2DT \
--data_dir data/Text2DT \
--bert_model_name chinese_wwm_pytorch \
--epochs 100 \
--fine_tune \
--device 0
这里的代码主要是将文本与其中的三元组转化为决策树,不包含三元组抽取的代码与同一节点中三元组之间逻辑关系判断的代码。三元组抽取的代码可以参考NLP社区中先进的抽取算法,同一节点中三元组之间逻辑关系判断我们给出了一个启发式规则的判断方法,位于utils/logic_predictor。
在训练和验证阶段,我们使用三元组抽取的ground truth作为输入,同时在评估时忽略了三元组之间逻辑关系。
在测试阶段,模型输入的格式为
[
{"text": text1,
"triple_list": [tri1a, ..., tri1n]}
{"text": text2,
"triple_list": [tri2a, ..., tri2n]}
]