Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add triple extraction #59

Merged
merged 1 commit into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions examples/text_triple_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) 2021 OpenKS Authors, DCD Research Lab, Zhejiang University.
# All Rights Reserved.

import argparse
from openks.models.pytorch import semeval_constant as constant
from openks.loaders import loader_config, SourceType, FileType, Loader
from openks.models import OpenKSModel

# ---------------------------------------------------------------------------
# Load data
# ---------------------------------------------------------------------------
# TODO: no dataset loading is wired in yet; the executor below receives None.
dataset = None

# ---------------------------------------------------------------------------
# Train the text information-extraction (triple extraction) model
# ---------------------------------------------------------------------------
# List the modules that have been registered/loaded.
OpenKSModel.list_modules()

# Command-line configuration for the algorithm/model.
parser = argparse.ArgumentParser(description='Triple args.')
parser.add_argument("--model", default=None, type=str, required=True)
parser.add_argument("--output_dir", default=None, type=str, required=True,
                    help="The output directory where the model predictions and checkpoints will be written.")
parser.add_argument("--eval_per_epoch", default=10, type=int,
                    help="How many times it evaluates on dev set per epoch")
parser.add_argument("--num_generated_triples", default=3, type=int,
                    help="How many triples it generates for one instance")
parser.add_argument("--max_seq_length", default=128, type=int,
                    help="The maximum total input sequence length after WordPiece tokenization. \n"
                         "Sequences longer than this will be truncated, and sequences shorter \n"
                         "than this will be padded.")
parser.add_argument("--negative_label", default="no_relation", type=str)
parser.add_argument('--fp16', action='store_true',
                    help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
parser.add_argument("--train_file", default=None, type=str, help="The path of the training data.")
parser.add_argument("--train_mode", type=str, default='random_sorted', choices=['random', 'sorted', 'random_sorted'])
parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--eval_test", action="store_true", help="Whether to evaluate on final test set.")
parser.add_argument("--eval_with_gold", action="store_true", help="Whether to evaluate the relation model with gold entities provided.")
parser.add_argument("--train_batch_size", default=32, type=int,
                    help="Total batch size for training.")
parser.add_argument("--eval_batch_size", default=8, type=int,
                    help="Total batch size for eval.")
parser.add_argument("--eval_metric", default="f1", type=str)
parser.add_argument("--learning_rate", default=None, type=float,
                    help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs", default=3.0, type=float,
                    help="Total number of training epochs to perform.")
# "%%" is required here: argparse %-formats help strings.
parser.add_argument("--warmup_proportion", default=0.1, type=float,
                    help="Proportion of training to perform linear learning rate warmup for. "
                         "E.g., 0.1 = 10%% of training.")
parser.add_argument("--no_cuda", action='store_true',
                    help="Whether not to use CUDA when available")
parser.add_argument('--seed', type=int, default=0,
                    help="random seed for initialization")
parser.add_argument("--bertadam", action="store_true", help="If bertadam, then set correct_bias = False")
parser.add_argument("--entity_output_dir", type=str, default=None, help="The directory of the prediction files of the entity model")
parser.add_argument("--entity_predictions_dev", type=str, default="ent_pred_dev.json", help="The entity prediction file of the dev set")
parser.add_argument("--entity_predictions_test", type=str, default="ent_pred_test.json", help="The entity prediction file of the test set")
parser.add_argument("--prediction_file", type=str, default="predictions.json", help="The prediction filename for the relation model")
# The help text used to be a verbatim copy of --prediction_file's.
# NOTE(review): wording inferred from the option name; confirm against the executor.
parser.add_argument("--feature_file", type=str, default="feature_default", help="The feature filename for the relation model")
parser.add_argument('--task', type=str, default=None, required=True, choices=['ace04', 'ace05', 'scierc'])
parser.add_argument('--context_window', type=int, default=0)
parser.add_argument('--add_new_tokens', action='store_true',
                    help="Whether to add new tokens as marker tokens instead of using [unusedX] tokens.")
parser.add_argument('--loss_scale', type=float, default=0,
                    help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                         "0 (default value): dynamic loss scaling.\n"
                         "Positive power of 2: static loss scaling value.\n")
args = parser.parse_args()

# Registry keys used to look up the executor and model classes.
platform = 'PyTorch'
executor = 'TripleExtraction'
model = 'TripleExtraction'
print("根据配置,使用 {} 框架,{} 执行器训练 {} 模型。".format(platform, executor, model))
print("-----------------------------------------------")
# Model training: `executor` is rebound from the registry key to the class
# returned by the registry, which is then instantiated and run.
executor = OpenKSModel.get_module(platform, executor)
seq2set = executor(dataset=dataset, model=OpenKSModel.get_module(platform, model), args=args)
seq2set.run()

print("-----------------------------------------------")
19 changes: 19 additions & 0 deletions openks/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,25 @@ def save_model(self, *args):
def run(self, *args):
return NotImplemented

class TripleExtractionModel(OpenKSModel):
    """Base class for triple extraction trainers.

    Every hook defined here is a stub that returns ``NotImplemented``;
    presumably concrete trainers override them (confirm against subclasses).
    """

    def __init__(self, name: str = 'model-name', args: List = None):
        # Only the name is recorded; ``args`` is accepted but not stored.
        self.name = name

    def data_reader(self, *args):
        """Stub for the data-reading hook; returns ``NotImplemented``."""
        return NotImplemented

    def evaluate(self, *args):
        """Stub for the evaluation hook; returns ``NotImplemented``."""
        return NotImplemented

    def load_model(self, *args):
        """Stub for the model-loading hook; returns ``NotImplemented``."""
        return NotImplemented

    def save_model(self, *args):
        """Stub for the model-saving hook; returns ``NotImplemented``."""
        return NotImplemented

    def run(self, *args):
        """Stub for the entry-point hook; returns ``NotImplemented``."""
        return NotImplemented

class HypernymDiscoveryModel(OpenKSModel):
def __init__(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .matcher import *
from .set_criterion import *
64 changes: 64 additions & 0 deletions openks/models/pytorch/ke_modules/bipartite_modules/matcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
Modules to compute the matching cost and solve the corresponding LSAP.
"""
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn


class HungarianMatcher(nn.Module):
    """Compute a one-to-one assignment between predicted and gold triples.

    The targets do not include the "no relation" class, so in general there
    are more predictions than gold triples. The best predictions are matched
    1-to-1 with the gold triples; the remaining predictions stay unmatched
    (and are treated as non-objects by the caller).
    """

    def __init__(self, loss_weight, matcher):
        """Create the matcher.

        Args:
            loss_weight: mapping read for the keys "relation", "head_entity"
                and "tail_entity"; the values weight the matching cost terms.
            matcher: cost-combination strategy, "avg" or "min". Any other
                value makes ``forward`` raise ``ValueError``.
        """
        super().__init__()
        self.cost_relation = loss_weight["relation"]
        self.cost_head = loss_weight["head_entity"]
        self.cost_tail = loss_weight["tail_entity"]
        self.matcher = matcher

    @staticmethod
    def _pair_probs(outputs, targets, logit_key, gold_key):
        """Return the probability every prediction assigns to every gold label.

        Predictions are flattened over (batch, generated triple) and gold
        labels are concatenated over the batch, so the result has shape
        [bsz * num_generated_triples, total_gold_triples].
        """
        probs = outputs[logit_key].flatten(0, 1).softmax(-1)
        gold = torch.cat([v[gold_key] for v in targets])
        return probs[:, gold]

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Perform the matching.

        Args:
            outputs: dict that contains at least these entries:
                "pred_rel_logits": Tensor of dim
                    [batch_size, num_generated_triples, num_classes] with the
                    classification logits
                "{head, tail}_{start, end}_logits": Tensor of dim
                    [batch_size, num_generated_triples, seq_len] with the
                    predicted index logits
            targets: list of dicts (len(targets) == batch_size); each dict
                holds index tensors under "relation", "head_start_index",
                "head_end_index", "tail_start_index" and "tail_end_index".

        Returns:
            A list of size batch_size, containing tuples (index_i, index_j):
                - index_i: indices of the selected predictions (in order)
                - index_j: indices of the corresponding selected targets
            For each batch element,
                len(index_i) == len(index_j)
                             == min(num_generated_triples, num_gold_triples).
            A batch element without gold triples yields two empty tensors.

        Raises:
            ValueError: if ``self.matcher`` is neither "avg" nor "min".
        """
        bsz, num_generated_triples = outputs["pred_rel_logits"].shape[:2]

        # Costs are computed for the whole batch at once (every prediction
        # against every gold triple of every sample) and split per sample below.
        rel = self._pair_probs(outputs, targets, "pred_rel_logits", "relation")
        head_start = self._pair_probs(outputs, targets, "head_start_logits", "head_start_index")
        head_end = self._pair_probs(outputs, targets, "head_end_logits", "head_end_index")
        tail_start = self._pair_probs(outputs, targets, "tail_start_logits", "tail_start_index")
        tail_end = self._pair_probs(outputs, targets, "tail_end_logits", "tail_end_index")

        if self.matcher == "avg":
            # Weighted sum of the relation probability and the averaged
            # start/end probabilities of each entity span.
            cost = (
                - self.cost_relation * rel
                - self.cost_head * 1 / 2 * (head_start + head_end)
                - self.cost_tail * 1 / 2 * (tail_start + tail_end)
            )
        elif self.matcher == "min":
            # The weakest of the five probabilities decides the match quality.
            stacked = torch.stack([head_start, rel, head_end, tail_start, tail_end], dim=1)
            cost = - torch.min(stacked, dim=1)[0]
        else:
            raise ValueError("Wrong matcher")

        num_gold_triples = [len(v["relation"]) for v in targets]
        # Pass the last dimension explicitly: ``view(..., -1)`` is ambiguous,
        # and raises, for a 0-element tensor, which is exactly what we get
        # when no sample in the batch has a gold triple.
        cost = cost.view(bsz, num_generated_triples, cost.shape[-1]).cpu()
        # The i-th column chunk holds the gold triples of sample i; only the
        # i-th row block of that chunk is relevant to sample i.
        indices = [
            linear_sum_assignment(chunk[i])
            for i, chunk in enumerate(cost.split(num_gold_triples, -1))
        ]
        return [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
            for i, j in indices
        ]

128 changes: 128 additions & 0 deletions openks/models/pytorch/ke_modules/bipartite_modules/set_criterion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import torch.nn.functional as F
import torch.nn as nn
import torch, math
from .matcher import HungarianMatcher


class SetCriterion(nn.Module):
    """Compute the set-prediction loss for triple extraction (Set_RE).

    The process happens in two steps:
        1) compute the Hungarian assignment between the gold triples and the
           outputs of the model;
        2) supervise each matched (gold, prediction) pair: relation class,
           head-entity span and tail-entity span.
    """

    def __init__(self, num_classes, loss_weight, na_coef, losses, matcher):
        """Create the criterion.

        Args:
            num_classes: number of relation categories (excluding the extra
                "no relation" class, which is placed at index ``num_classes``).
            loss_weight: dict mapping loss names to their relative weight.
                It is also handed to the matcher, which reads the keys
                "relation", "head_entity" and "tail_entity".
            na_coef: classification weight applied to the "no relation"
                class (assigned as a single scalar to the last class weight).
            losses: list of the losses to apply; see ``get_loss`` for the
                available names.
            matcher: matching strategy forwarded to ``HungarianMatcher``.
        """
        super().__init__()
        self.num_classes = num_classes
        self.loss_weight = loss_weight
        self.matcher = HungarianMatcher(loss_weight, matcher)
        self.losses = losses
        # Per-class weights for the relation cross-entropy; the last slot is
        # the "no relation" class.
        rel_weight = torch.ones(self.num_classes + 1)
        rel_weight[-1] = na_coef
        self.register_buffer('rel_weight', rel_weight)

    def forward(self, outputs, targets):
        """Compute the weighted total loss.

        Args:
            outputs: dict of tensors produced by the model (see
                ``relation_loss`` / ``entity_loss`` for the keys read).
            targets: list of dicts, such that len(targets) == batch_size.
                The expected keys depend on the losses applied.

        Returns:
            The sum of every computed loss whose name appears in
            ``self.loss_weight``, each multiplied by its weight.
        """
        # Retrieve the matching between the model outputs and the targets.
        indices = self.matcher(outputs, targets)
        # Compute all the requested losses.
        losses = {}
        for loss in self.losses:
            # Without any gold triple there is no span to supervise.
            if loss == "entity" and self.empty_targets(targets):
                continue
            losses.update(self.get_loss(loss, outputs, targets, indices))
        # Entries without a configured weight (e.g. logging-only metrics)
        # are left out of the total.
        losses = sum(losses[k] * self.loss_weight[k] for k in losses.keys() if k in self.loss_weight)
        return losses

    def relation_loss(self, outputs, targets, indices):
        """Classification loss (weighted cross-entropy) over relation types.

        Each target dict must contain the key "relation" holding the class
        indices of its gold triples. Unmatched predictions are supervised
        towards the "no relation" class (index ``self.num_classes``).
        """
        src_logits = outputs['pred_rel_logits']  # [bsz, num_generated_triples, num_rel+1]
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["relation"][i] for t, (_, i) in zip(targets, indices)])
        # Start from "no relation" everywhere, then fill in the matched slots.
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o
        loss = F.cross_entropy(src_logits.flatten(0, 1), target_classes.flatten(0, 1), weight=self.rel_weight)
        return {'relation': loss}

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices):
        """Compute the cardinality error.

        This is the absolute error in the number of predicted non-empty
        triples. It is not really a loss: it is intended for logging only
        and does not propagate gradients.
        """
        pred_rel_logits = outputs['pred_rel_logits']
        device = pred_rel_logits.device
        # Gold triples are stored under "relation", the same key the matcher
        # and the other losses read.
        tgt_lengths = torch.as_tensor([len(v["relation"]) for v in targets], device=device)
        # Count the predictions that are NOT "no relation" (the last class).
        card_pred = (pred_rel_logits.argmax(-1) != pred_rel_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        return {'cardinality_error': card_err}

    def _get_src_permutation_idx(self, indices):
        """Return (batch_idx, src_idx) addressing the matched predictions."""
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Return (batch_idx, tgt_idx) addressing the matched targets."""
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, **kwargs):
        """Dispatch to the loss named ``loss``.

        Available names: 'relation', 'cardinality', 'entity'.
        An unknown name raises ``KeyError``.
        """
        loss_map = {
            'relation': self.relation_loss,
            'cardinality': self.loss_cardinality,
            'entity': self.entity_loss
        }
        return loss_map[loss](outputs, targets, indices, **kwargs)

    def entity_loss(self, outputs, targets, indices):
        """Compute the losses on the head-entity and tail-entity positions.

        For every matched prediction, the start and end logits of each
        entity are scored with cross-entropy against the gold token indices;
        the start and end losses of an entity are averaged.
        """
        idx = self._get_src_permutation_idx(indices)

        def span_loss(logit_key, gold_key):
            # Cross-entropy between matched predictions and their gold index.
            selected = outputs[logit_key][idx]
            gold = torch.cat([t[gold_key][i] for t, (_, i) in zip(targets, indices)])
            return F.cross_entropy(selected, gold)

        head_start_loss = span_loss("head_start_logits", "head_start_index")
        head_end_loss = span_loss("head_end_logits", "head_end_index")
        tail_start_loss = span_loss("tail_start_logits", "tail_start_index")
        tail_end_loss = span_loss("tail_end_logits", "tail_end_index")
        return {
            'head_entity': 1/2*(head_start_loss + head_end_loss),
            "tail_entity": 1/2*(tail_start_loss + tail_end_loss),
        }

    @staticmethod
    def empty_targets(targets):
        """Return True when no target in the batch holds any gold triple."""
        return all(len(target["relation"]) == 0 for target in targets)
Loading