add eval detail of each relation #7

Open · wants to merge 17 commits into master
3 changes: 0 additions & 3 deletions .vscode/settings.json

This file was deleted.

11 changes: 7 additions & 4 deletions experiments/chinese_selection_re.json
@@ -7,7 +7,8 @@
"dev": "dev_data.json",
"relation_vocab": "relation_vocab.json",
"print_epoch": 3,
"evaluation_epoch":27,
"evaluation_epoch": 21,
"resume_model": 0,
"max_text_len": 100,
"cell_name": "lstm",
"emb_size": 300,
@@ -17,8 +18,10 @@
"threshold": 0.5,
"activation": "tanh",
"optimizer": "adam",
"epoch_num": 30,
"lr": 1e-3,
"epoch_num": 100,
"train_batch": 100,
"eval_batch": 400,
"gpu":1
}
"patient": 5,
"gpu": 1
}
9 changes: 6 additions & 3 deletions experiments/conll_selection_re.json
@@ -8,6 +8,7 @@
"relation_vocab": "relation_vocab.json",
"print_epoch": 3,
"evaluation_epoch":27,
"resume_model": 0,
"max_text_len": 348,
"cell_name": "lstm",
"emb_size": 300,
@@ -17,8 +18,10 @@
"threshold": 0.5,
"activation": "tanh",
"optimizer": "adam",
"epoch_num": 150,
"lr": 1e-3,
"epoch_num": 30,
"train_batch": 16,
"eval_batch": 32,
"gpu":3
}
"patient": 5,
"gpu": 0
}
28 changes: 28 additions & 0 deletions experiments/medical_selection_re.json
@@ -0,0 +1,28 @@
{
"dataset": "medical",
"model": "multi_head_selection",
"data_root": "data/medical/multi_head_selection",
"raw_data_root": "raw_data/medical_re",
"train": "train.txt",
"dev": "dev.txt",
"test": "test.txt",
"relation_vocab": "relation_vocab.json",
"print_epoch": 3,
"evaluation_epoch":51,
"resume_model": 0,
"max_text_len": 400,
"cell_name": "lstm",
"emb_size": 200,
"rel_emb_size": 100,
"bio_emb_size": 50,
"hidden_size": 200,
"threshold": 0.5,
"activation": "tanh",
"optimizer": "adam",
"lr": 1e-3,
"epoch_num": 100,
"train_batch": 20,
"eval_batch": 20,
"patient":5,
"gpu":0
}
Empty file added lib/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions lib/config/hyper.py
@@ -27,6 +27,9 @@ def __init__(self, path: str):
self.bio_emb_size: int
self.train_batch: int
self.eval_batch: int
self.lr: float
self.resume_model: int
self.patient: int

self.__dict__ = json.load(open(path, 'r'))

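For reference, a minimal sketch of how the new fields (lr, resume_model, patient) flow from an experiment JSON into attributes; it assumes the class defined in lib/config/hyper.py is named Hyper, which this hunk does not show:

from lib.config.hyper import Hyper  # assumed class name

hyper = Hyper('experiments/medical_selection_re.json')
# __init__ ends with self.__dict__ = json.load(open(path, 'r')), so every key in the JSON
# becomes an attribute; the annotations above only document the expected fields and types.
print(hyper.lr, hyper.resume_model, hyper.patient)  # 0.001 0 5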
41 changes: 37 additions & 4 deletions lib/dataloaders/selection_loader.py
@@ -15,9 +15,10 @@


class Selection_Dataset(Dataset):
def __init__(self, hyper, dataset):
def __init__(self, hyper, dataset, type='train|eval'):
self.hyper = hyper
self.data_root = hyper.data_root
self.type = type

self.word_vocab = json.load(
open(os.path.join(self.data_root, 'word_vocab.json'), 'r'))
@@ -30,6 +31,7 @@ def __init__(self, hyper, dataset):
self.text_list = []
self.bio_list = []
self.spo_list = []
self.jobid_list = []

# for bert only
self.bert_tokenizer = BertTokenizer.from_pretrained(
@@ -38,17 +40,25 @@
for line in open(os.path.join(self.data_root, dataset), 'r'):
line = line.strip("\n")
instance = json.loads(line)
self.text_list.append(instance['text'])
if self.type == 'predict':
continue

self.selection_list.append(instance['selection'])
self.bio_list.append(instance['bio'])
self.spo_list.append(instance['spo_list'])
self.jobid_list.append(instance.get('jobid', 0))

def __getitem__(self, index):
selection = self.selection_list[index]
text = self.text_list[index]
if self.type == 'predict':
tokens_id = self.text2tensor(text)
return tokens_id, len(text), text

selection = self.selection_list[index]
bio = self.bio_list[index]
spo = self.spo_list[index]
jobid = self.jobid_list[index]
if self.hyper.cell_name == 'bert':
text, bio, selection = self.pad_bert(text, bio, selection)
tokens_id = torch.tensor(
@@ -58,7 +68,7 @@ def __getitem__(self, index):
bio_id = self.bio2tensor(bio)
selection_id = self.selection2tensor(text, selection)

return tokens_id, bio_id, selection_id, len(text), spo, text, bio
return tokens_id, bio_id, selection_id, len(text), spo, text, bio, jobid

def __len__(self):
return len(self.text_list)
@@ -122,6 +132,7 @@ def __init__(self, data):
self.spo_gold = transposed_data[4]
self.text = transposed_data[5]
self.bio = transposed_data[6]
self.jobid = transposed_data[7]

def pin_memory(self):
self.tokens_id = self.tokens_id.pin_memory()
@@ -135,3 +146,25 @@ def collate_fn(batch):


Selection_loader = partial(DataLoader, collate_fn=collate_fn, pin_memory=True)


# for predict ...
class Batch_reader_predict(object):
def __init__(self, data):
transposed_data = list(zip(*data))
# tokens_id, len(text), text

self.tokens_id = pad_sequence(transposed_data[0], batch_first=True)
self.length = transposed_data[1]
self.text = transposed_data[2]

def pin_memory(self):
self.tokens_id = self.tokens_id.pin_memory()
return self


def collate_fn_predict(batch):
return Batch_reader_predict(batch)


Selection_loader_predict = partial(DataLoader, collate_fn=collate_fn_predict, pin_memory=True)
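For reference, a minimal sketch of how the new predict path could be driven end to end; predict.json is a hypothetical one-record-per-line file under hyper.data_root whose entries contain only a text field, and hyper is a loaded config object:

from lib.dataloaders.selection_loader import Selection_Dataset, Selection_loader_predict

dataset = Selection_Dataset(hyper, 'predict.json', type='predict')  # hyper: loaded config, file name is an example
loader = Selection_loader_predict(dataset, batch_size=hyper.eval_batch, num_workers=0)
for batch in loader:
    # tokens_id is padded by pad_sequence in Batch_reader_predict; length and text are the raw inputs
    print(batch.tokens_id.shape, batch.length, batch.text)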
75 changes: 62 additions & 13 deletions lib/metrics/F1_score.py
@@ -2,28 +2,53 @@
from abc import ABC, abstractmethod
from overrides import overrides

from collections import namedtuple
from collections import defaultdict

class F1_abc(object):
def __init__(self):
self.A = 1e-10
self.B = 1e-10
self.C = 1e-10
self.A = 0
self.B = 0
self.C = 0
self.ABC = namedtuple('ABC', ['A', 'B', 'C'])
self.rel_detail = defaultdict(list)

def reset(self) -> None:
self.A = 1e-10
self.B = 1e-10
self.C = 1e-10
self.A = 0
self.B = 0
self.C = 0
self.rel_detail.clear()

def get_metric(self, reset: bool = False):
if reset:
self.reset()

f1, p, r = 2 * self.A / (self.B +
self.C), self.A / self.B, self.A / self.C
result = {"precision": p, "recall": r, "fscore": f1}

result = self.calc(self.A, self.B, self.C)
return result

@staticmethod
def calc(A, B, C):
p = A / B if B > 0 else 0.
r = A / C if C > 0 else 0.
f1 = 2 * p * r / (p + r) if (p+r) > 0 else 0.
prf1 = {"precision": p, "recall": r, "fscore": f1}
return prf1

@staticmethod
def calc_abc(A, B, C):
p = A / B if B > 0 else 0.
r = A / C if C > 0 else 0.
f1 = 2 * p * r / (p + r) if (p+r) > 0 else 0.
prf1 = {"precision": p, "recall": r, "fscore": f1, "ABC": "{}:{}:{}".format(A, B, C), "sum":A+B+C}
return prf1

def get_metric_detail(self, reset: bool = False):
if reset:
self.reset()
results = {}
for k, v in self.rel_detail.items():
results[k] = self.calc_abc(v[0], v[1], v[2])
return results

def __call__(self, predictions,
gold_labels):
raise NotImplementedError
@@ -50,6 +75,30 @@ def __call__(self, predictions: List[List[Dict[str, str]]],
self.B += len(p_set)
self.C += len(g_set)

# for rel detail
g_set_rel, p_set_rel = defaultdict(list), defaultdict(list)
try:
for gg in g:
g_set_rel[gg['predicate']].append('_'.join((gg['object'], gg['predicate'], gg['subject'])))
for pp in p:
p_set_rel[pp['predicate']].append('_'.join((pp['object'], pp['predicate'], pp['subject'])))
except TypeError:
# object/subject may be token lists rather than strings; join them before building the key
for gg in g:
g_set_rel[gg['predicate']].append('_'.join((''.join(gg['object']), gg['predicate'], ''.join(gg['subject']))))
for pp in p:
p_set_rel[pp['predicate']].append('_'.join((''.join(pp['object']), pp['predicate'], ''.join(pp['subject']))))

rels = set(list(g_set_rel.keys()) + list(p_set_rel.keys()))
for k in rels:
if k not in self.rel_detail:
self.rel_detail[k] = [0, 0, 0]

for k in rels:
vg, vp = g_set_rel.get(k, []), p_set_rel.get(k, [])
self.rel_detail[k][0] += len(set(vg) & set(vp))
self.rel_detail[k][1] += len(set(vp))
self.rel_detail[k][2] += len(set(vg))


class F1_ner(F1_abc):

@@ -63,5 +112,5 @@ def __call__(self, predictions: List[List[str]], gold_labels: List[List[str]]):
bi_p = sum(tok_p in ('B', 'I') for tok_p in p)

self.A += inter
self.B += bi_g
self.C += bi_p
self.B += bi_p
self.C += bi_g
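For reference, a minimal sketch of how the new per-relation bookkeeping could be read out after an evaluation pass; it assumes the triplet subclass defined in this file is named F1_triplet (its name sits outside the shown hunks) and uses toy spo dicts:

from lib.metrics.F1_score import F1_triplet  # assumed class name

metric = F1_triplet()
gold = [[{'subject': 'aspirin', 'predicate': 'treats', 'object': 'fever'}]]
pred = [[{'subject': 'aspirin', 'predicate': 'treats', 'object': 'fever'},
         {'subject': 'aspirin', 'predicate': 'causes', 'object': 'nausea'}]]

metric(pred, gold)
print(metric.get_metric())         # corpus-level precision / recall / fscore
print(metric.get_metric_detail())  # per-predicate dict, e.g. 'causes' -> {'precision': 0.0, ..., 'ABC': '0:1:0', 'sum': 1}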