From 6b95691c041f330dbc97b70ff756710df02e7c5b Mon Sep 17 00:00:00 2001 From: acproject Date: Thu, 26 Nov 2020 17:01:42 +0800 Subject: [PATCH] update --- .gitignore | 2 + GNN/Model/MultiHeadAttention.py | 49 +- GNN/Model/TreeLSTM.py | 266 ++++ NAACL/__pycache__/__init__.cpython-36.pyc | Bin 0 -> 118 bytes NAACL/__pycache__/settings.cpython-36.pyc | Bin 0 -> 135 bytes NAACL/__pycache__/util.cpython-36.pyc | Bin 0 -> 1066 bytes NAACL/__pycache__/vocabulary.cpython-36.pyc | Bin 0 -> 4210 bytes NAACL/backoffnet.py | 1208 ++++++++++++++----- NAACL/ensemble.py | 172 +-- NAACL/out/log.txt | 1 + NAACL/prune_pred_gv_map.py | 100 +- NAACL/settings.py | 2 +- NAACL/util.py | 80 +- NAACL/vocabulary.py | 184 +-- maths.md | 15 + run.sh | 36 + 16 files changed, 1537 insertions(+), 578 deletions(-) create mode 100644 GNN/Model/TreeLSTM.py create mode 100644 NAACL/__pycache__/__init__.cpython-36.pyc create mode 100644 NAACL/__pycache__/settings.cpython-36.pyc create mode 100644 NAACL/__pycache__/util.cpython-36.pyc create mode 100644 NAACL/__pycache__/vocabulary.cpython-36.pyc create mode 100644 NAACL/out/log.txt create mode 100644 maths.md create mode 100644 run.sh diff --git a/.gitignore b/.gitignore index 538ddf4..77295d2 100644 --- a/.gitignore +++ b/.gitignore @@ -39,3 +39,5 @@ yarn-error.log* src/components/jvm_dll_path.json /saved_models /saved_models/** +/data +/data/** diff --git a/GNN/Model/MultiHeadAttention.py b/GNN/Model/MultiHeadAttention.py index a7af543..93f1c65 100644 --- a/GNN/Model/MultiHeadAttention.py +++ b/GNN/Model/MultiHeadAttention.py @@ -1,17 +1,32 @@ -import torch -import torch.nn as nn - - - -class MultiHeadAttention(nn.Module): - def __init__(self, h, dim_model): - ''' - - :param h: number of heads - :param dim_model: hidden dimension - ''' - super(MultiHeadAttention, self).__init__() - self.d_k = dim_model // h - self.h = h - # W_q, W_k, W_v, W_o - self.linears = clones(nn.Linear(dim_model, dim_model), 4) +import torch +import torch.nn as nn +import torch as th +import numpy as np +import networkx as nx +import matplotlib.pyplot as plt + +from modules.layers import * +from modules.functions import * +from modules.embedding import * +from modules.viz import att_animation, get_attention_map +from optims import NoamOpt +from loss import LabelSmoothing, SimpleLossCompute +from dataset import get_dataset, GraphPool + +import dgl.function as fn +import torch.nn.init as INIT +from torch.nn import LayerNorm + + +class MultiHeadAttention(nn.Module): + def __init__(self, h, dim_model): + ''' + + :param h: number of heads + :param dim_model: hidden dimension + ''' + super(MultiHeadAttention, self).__init__() + self.d_k = dim_model // h + self.h = h + # W_q, W_k, W_v, W_o + self.linears = clones(nn.Linear(dim_model, dim_model), 4) diff --git a/GNN/Model/TreeLSTM.py b/GNN/Model/TreeLSTM.py new file mode 100644 index 0000000..36f7f00 --- /dev/null +++ b/GNN/Model/TreeLSTM.py @@ -0,0 +1,266 @@ +from collections import namedtuple + +import dgl +from dgl.data.tree import SSTDataset + +SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label']) + +trainset = SSTDataset(mode='tiny') +tiny_sst = trainset.trees +num_vocabs = trainset.num_vocabs +num_classes = trainset.num_classes + +vocab = trainset.vocab # vocabulary dict: key -> id +inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word + +a_tree = tiny_sst[0] +for token in a_tree.ndata['x'].tolist(): + if token != trainset.PAD_WORD: + print(inv_vocab[token], end=" ") + 
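##############################################################################
# The loop above decodes the first tree back into words by skipping padding
# ids and mapping the rest through ``inv_vocab``. A small helper makes the
# same decoding reusable for any tree in ``tiny_sst``; this is an
# illustrative sketch only, and the name ``tree_to_sentence`` is not part of
# the dataset API.

def tree_to_sentence(tree, inv_vocab, pad_word):
    # Node feature 'x' stores one word id per node; nodes carrying the
    # padding id are skipped, leaving the leaf tokens in node order.
    tokens = [inv_vocab[t] for t in tree.ndata['x'].tolist() if t != pad_word]
    return ' '.join(tokens)

# e.g. print(tree_to_sentence(tiny_sst[0], inv_vocab, trainset.PAD_WORD))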
+
##############################################################################
# Step 1: Batching
# ----------------
#
# Add all the trees to one graph, using
# the :func:`~dgl.batched_graph.batch` API.
#
import networkx as nx
import matplotlib.pyplot as plt

graph = dgl.batch(tiny_sst)

def plot_tree(g):
    # this plot requires the pygraphviz package
    pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
    nx.draw(g, pos, with_labels=False, node_size=10,
            node_color=[[.5, .5, .5]], arrowsize=4)
    # plt.show()
# plot_tree(graph.to_networkx())

##############################################################################
# Step 2: Tree-LSTM cell with message-passing APIs
# ------------------------------------------------
#
# Researchers have proposed two types of Tree-LSTMs: Child-Sum
# Tree-LSTMs and :math:`N`-ary Tree-LSTMs. In this tutorial, you focus
# on applying a *binary* Tree-LSTM to binarized constituency trees. This
# application is also known as *Constituency Tree-LSTM*. Use PyTorch
# as the backend framework to set up the network.
#
# In an :math:`N`-ary Tree-LSTM, each unit at node :math:`j` maintains a hidden
# representation :math:`h_j` and a memory cell :math:`c_j`. The unit
# :math:`j` takes the input vector :math:`x_j` and the hidden
# representations of the child units, :math:`h_{jl}, 1\leq l\leq N`, as
# input, then updates its hidden representation :math:`h_j` and memory
# cell :math:`c_j` by:
#
# .. math::
#
#    i_j & = & \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right),  & (1)\\
#    f_{jk} & = & \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \right), & (2)\\
#    o_j & = & \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \right), & (3) \\
#    u_j & = & \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \right), & (4)\\
#    c_j & = & i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, & (5) \\
#    h_j & = & o_j \odot \textrm{tanh}(c_j), & (6) \\
#
# The computation can be decomposed into three phases: ``message_func``,
# ``reduce_func`` and ``apply_node_func``.
#
# .. note::
#    ``apply_node_func`` is a new node UDF that has not been introduced before. In
#    ``apply_node_func``, a user specifies what to do with node features,
#    without considering edge features and messages. In the Tree-LSTM case,
#    ``apply_node_func`` is a must, since there exist (leaf) nodes with
#    :math:`0` incoming edges, which would not be updated by
#    ``reduce_func``.
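##############################################################################
# The same message/reduce/apply decomposition also covers the Child-Sum
# variant (Tai et al., 2015): the children's hidden states are summed rather
# than concatenated, and each child gets its own forget gate. The cell below
# is a minimal sketch added for contrast only; it is not used by the rest of
# this script, which defines the binary (N=2) cell next.

import torch as th
import torch.nn as nn

class ChildSumTreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(ChildSumTreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(h_size, h_size)

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        # h_tilde: sum of the children's hidden states
        h_tilde = th.sum(nodes.mailbox['h'], 1)
        # one forget gate per child, applied to that child's memory cell
        f = th.sigmoid(self.U_f(nodes.mailbox['h']))
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_tilde), 'c': c}

    def apply_node_func(self, nodes):
        # same as the N-ary cell: split i/o/u, then update c and h
        iou = nodes.data['iou'] + self.b_iou
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        c = i * u + nodes.data['c']
        h = o * th.tanh(c)
        return {'h': h, 'c': c}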
+
#
import torch as th
import torch.nn as nn

class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        # concatenate h_jl for equation (1), (2), (3), (4)
        h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
        # equation (2)
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
        # second term of equation (5)
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_cat), 'c': c}

    def apply_node_func(self, nodes):
        # equation (1), (3), (4)
        iou = nodes.data['iou'] + self.b_iou
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # equation (5)
        c = i * u + nodes.data['c']
        # equation (6)
        h = o * th.tanh(c)
        return {'h': h, 'c': c}


##############################################################################
# Step 3: Define traversal
# ------------------------
#
# After you define the message-passing functions, you need to induce the
# right order to trigger them. This is a significant departure from models
# such as GCN, where all nodes pull messages from upstream ones
# *simultaneously*.
#
# In the case of Tree-LSTM, messages start from the leaves of the tree and
# are propagated and processed upwards until they reach the roots. A
# visualization is as follows:
#
# .. figure:: https://i.loli.net/2018/11/09/5be4b5d2df54d.gif
#    :alt:
#
# DGL defines a generator to perform the topological sort; each item is a
# tensor recording the nodes from the bottom level up to the roots. One can
# appreciate the degree of parallelism by inspecting the difference between
# the following two calls:
#
# convert to a heterogeneous graph
trv_a_tree = dgl.graph(a_tree.edges())
print('Traversing one tree:')
print(dgl.topological_nodes_generator(trv_a_tree))

# convert to a heterogeneous graph
trv_graph = dgl.graph(graph.edges())
print('Traversing many trees at the same time:')
print(dgl.topological_nodes_generator(trv_graph))

##############################################################################
# Call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:

import dgl.function as fn
import torch as th

trv_graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
traversal_order = dgl.topological_nodes_generator(trv_graph)
trv_graph.prop_nodes(traversal_order,
                     message_func=fn.copy_src('a', 'a'),
                     reduce_func=fn.sum('a', 'a'))

class TreeLSTM(nn.Module):
    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 pretrained_emb=None):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print('Using glove')
            self.embedding.weight.data.copy_(pretrained_emb)
            self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        self.cell = TreeLSTMCell(x_size, h_size)

    def forward(self, batch, h, c):
        '''
        Compute Tree-LSTM prediction given a batch.
        :param batch: dgl.data.SSTBatch
        :param h: Tensor, initial hidden state
        :param c: Tensor, initial cell state
        :return: logits: Tensor
            The prediction of each node.
+ ''' + g = batch.graph + # to heterogenous graph + g = dgl.graph(g.edges()) + # feed embedding + embeds = self.embedding(batch.wordid * batch.mask) + g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1) + g.ndata['h'] = h + g.ndata['c'] = c + # propagate + dgl.prop_nodes_topo(g, + message_func=self.cell.message_func, + reduce_func=self.cell.reduce_func, + apply_node_func=self.cell.apply_node_func) + # compute logits + h = self.dropout(g.ndata.pop('h')) + logits = self.linear(h) + return logits + +############################################################################## +# Main Loop +# --------- +# +# Finally, you could write a training paradigm in PyTorch. +# + +from torch.utils.data import DataLoader +import torch.nn.functional as F + +device = th.device('cpu') + +# hyper parameters +x_size = 256 +h_size = 256 +dropout = 0.5 +lr = 0.05 +weight_decay = 1e-4 +epochs = 10 + +# create the model +model = TreeLSTM(trainset.num_vocabs, + x_size, + h_size, + trainset.num_classes, + dropout) + +print(model) + +# create the optimizer +optimizer = th.optim.Adagrad(model.parameters(), + lr=lr, + weight_decay=weight_decay) + +def batcher(dev): + def batcher_dev(batch): + batch_trees = dgl.batch(batch) + return SSTBatch(graph=batch_trees, + mask=batch_trees.ndata['mask'].to(device), + wordid=batch_trees.ndata['x'].to(device), + label=batch_trees.ndata['y'].to(device)) + + return batcher_dev + +train_loader = DataLoader(dataset=tiny_sst, + batch_size=5, + collate_fn=batcher(device), + shuffle=False, + num_workers=0) + +# training loop +for epoch in range(epochs): + for step, batch in enumerate(train_loader): + g = batch.graph + n = g.number_of_nodes() + h = th.zeros((n, h_size)) + c = th.zeros((n, h_size)) + logits = model(batch, h, c) + logp =F.log_softmax(logits, 1) + loss = F.nll_loss(logp, batch.label, reduction='sum') + optimizer.zero_grad() + loss.backward() + optimizer.step() + pred = th.argmax(logits, 1) + acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label) + print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format( + epoch, step, loss.item(), acc)) \ No newline at end of file diff --git a/NAACL/__pycache__/__init__.cpython-36.pyc b/NAACL/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8aadd899f142046eaffa5258b2954125e22a103 GIT binary patch literal 118 zcmXr!<>k7ub#FWa5IhDEFu(|8H~?`m3y?@*2xib^^jpbL1QJFNzm%P=V#@Q2vWp86 tlT%~d{rrk!{2U#fePZI{GxIV_;^XxSDsOSvmUnZEw6V0|UcjAcg}*Aj<)Wi#dQq3PTh_3S%&XCetmJl*E!mKTXD494?L_ zj`1#@K`R-Gn1O1*#4lxMtC;fqqU_>=#N^Z%cR#=47(YixXP=nj)RL0Sy!2wdg34PQ SHo5sJr8%i~AWMpYm;nHq10baU literal 0 HcmV?d00001 diff --git a/NAACL/__pycache__/util.cpython-36.pyc b/NAACL/__pycache__/util.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a59b88e8ba890b9db7ffa3b50135d4257d5210fb GIT binary patch literal 1066 zcmY*XO>fgc5Zw>kaZ6K@Dk_Ld;7blcRYghU)=ijyZy%?!Brbf3G^tAk#moc~{F*emu!SkCYuYqKu0bjb(;YNSAH#JioE})4n6Mg5{eTNe zOs;51pO6WSNeomXugGM^Gw=*N3)wxMgXiFR$TL31SuPCuS+JPKEao0xB@OhY>{r1x zm+f_%;c4cZFz$Gkb`W7+nDRzgG10NeWpl-@PjOrl<;y~pUu=pq?S?wckPcO3m&Gte zQ~k%0S>C96N>7`fsu&6H`(d)wIH)z+#|J0v*ZZeQ{w{j+;p53%^8MR7It=A@Rbq(F zrAo}RU@Q>~`@Z@~WF9;Gdsf?SeF@}+8g;sE>!@B=t@{3c?Okge_6HlI%jAJxpdAKn z6-s^c>M0zdOLL!1m#IxxU-31={+L`5bYl30>}E55 zjZ6%!*looYD$KbG@Tf3jBQ|dF9!R+FHUi|2l literal 0 HcmV?d00001 diff --git a/NAACL/__pycache__/vocabulary.cpython-36.pyc b/NAACL/__pycache__/vocabulary.cpython-36.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..f1dc65df1a6d653bde8862570d27d05383e0690f GIT binary patch literal 4210 zcmb_fTW{M&7M|fnq7}<_94B$yv}MyIwRTe@Xt&+ntk;d<6v1u{^^)xcMhb$_jAP0a zr3~rB7UU-dXkQjZ|3hE;BibLr*FNn_RwDc zwEe+Ghi>wpiOFqT@pl*`%Qg{j2t$}p`Hm@8g)M5BS)vZgKpimyYKw+A16mWaq6u0T zb7CIU5ewoh=!{qt=Rg}`Nt_2gBQA)GptE9GTmo&174Z`2oVYAr2AvnLh*v=ugezVH zJu9w=*FhKMIdN57d&>7#QeBl&&pY@pTL%K4KnA zc-I&+%*@0}?0}xt66+z+y2i&C`i+J=`P~}$%I~=c(Y}}LDH-qeLvhP}(pMsOp9IOC zyZ&!C*$ZN~=SN3wr+*M7v=D~w1Bu%x4un)vxEkxVV%v7hNt=~T(=(lZ7|Ko(;H9*_ zMLS8A*0L?r8TiWk--jQ(-_EB0Y2SJOi?l(CycI-39;P!i_rf4fT1INcGTc?<&hMF; z1$lOD>(+g6?8gJYBk$jBx8wWm&CRXt`;QCw)&@tZ<9R_8B%W6z)Eo@VuJH!`!*eG< zujO1nYs5(o{U{dc(ujYA$%u7TY6HA17)RCuTrn6*SzW5q;j&)k|;l)xE zmU#-Dem@Fnv%FC}!DDA87Y|J0Q~D`__bRECV=`QRfnsr?8Cg)I$1-j*v@im%izwRv zHxT9~5VM9mMHFr#%*iu15Ze3FjSrNhutv#=`U%DPZ$M8y5SRp6H7ce{oYRF{%wJxb+{;l!Ix|bmZ zU%KEx**Dg1s*5z`I=wO+sj56{tkcQR>GuX3|HQp^UX2ZJmNeI<1-nN@+O4gJI2*!v z1C#t^MFwE|OGL6HRx6lmQMp)M$Id>t-#>D}j1MpWn&9cskx$CdO+YjMAZUuJ%XHM# z6)dgiS_7^d)3hdRrb0#M?gXRMmZ9urRiuJavp|Wq`ByFr2$kzeg{guL)e{OX9asY4 zq!FTi4Z%eSO#(~%trA`ImJVc{9_-77li2!<;hZo|PKj5+3PT8}ln1%~WS-K&HM9Ou zM)Gi=Hr~TC9a|O~zF7iG$xi>VRC#EX;L6`@@zm<}gD3~A4x59HoC4YWsm(TWnE-9& z#WE|fRm;(&He}U+>uQk4_OwFa?D?@SN1G+d3iHD!*A+=-NkZMrLxkXcW#xD;W&s-IB=p62`Pv5k=y*h?%M>mD+Jp<=@1U zO6;7-sCyndCeQmXy-C`_m5U`YljVw}Uo$GT0yNLM1ABo)>T|*3dD|tCWgq72g%2sy zNMmi?yO;n->T`^5t{fYOD+DTft%YMVG0?yo*=hc0>2rpD%lw#q!oJAQ_W67MW0^Is zx^_&8PipN3p=`Iz)as*W%=$-rpn@pL+ee2uBEKh7zN_CleNtL6w*HVz(#MlA%V+tp z;xMciCJ_&vz1pPPn2g}AV}lGE^KVv<%|mx&!l>nAYh-0j-LX9~y2i-Xt>gu5oRJwf zGLS#!@cMS&&l(n&jB{71|JJQaNDJBLOlvLW4Opt)f>5-Ci3$el70g?;)a-<@`kQ|K zm_AL^J2bb#Ad=HfrhPvJPfXrnX7;tx;LBCRiHka33Q!Q;1D&VO9fH&I(uU{t`r;s@ zIkN51fgffoY7TGc-c`LsGwS@bu8PW{K6%yaG`mK_UugI%4L4|@m{)5wQ0t;7Z1ABM z;!~t2lfAxBbh?zBNG9l-swli|T>87-;f+S4=~#}5%fV>lI{6!z*Yq1LTTz^y_5HdP!9By&T_yzvsbZ7to literal 0 HcmV?d00001 diff --git a/NAACL/backoffnet.py b/NAACL/backoffnet.py index 09a2738..8eb1650 100644 --- a/NAACL/backoffnet.py +++ b/NAACL/backoffnet.py @@ -1,292 +1,916 @@ -import argparse -import collections -import glob -import json -import math -import numpy as np -import random -from ordered_set import OrderedSet -import os -import pickle -import shutil -from sklearn.metrics import average_precision_score -import sys -import termcolor -import time -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.nn.utils.rnn as rnn -import torch.optim as optim -from tqdm import tqdm - -from NAACL import vocabulary -from NAACL import settings -from NAACL import util - -WORD_VEC_FILE = 'wordvec/PubMed-and-PMC-w2v.txt' -WORD_VEC_NUM_LINES = 4087447 -EMB_SZIE = 200 # size of word embeddings -PARA_EMB_SIZE = 100 # size of paragraph index embeddings -PARA_EMB_MAX_SPAN = 1000 -MAX_ENTITIES_PER_TYPE = 200 -MAX_NUM_PARAGRAPHS = 200 -MAX_NUM_CANDIDATES = 10000 -ALL_ENTITY_TYPES = ['drug', 'gene', 'variant'] -ALL_ENTITY_TYPES_PAIRS = [('drug', 'gene'), ('drug', 'variant'), ('gene', 'variant')] -MAX_PARAGRAPH_LENGTH = 800 -CLIP_THRESH = 5 # Gradient clipping (on L2 norm) -JAX_DEV_PMIDS_FILE = 'jax/jax_dev_pmids.txt' -JAX_TEST_PMIDS_FILE = 'jax/jax_test_pmids.txt' - -log_file = None - -def log(msg): - print(msg, file=sys.stderr) - if log_file: - print(msg, file=log_file) - -ParaMention = collections.namedtuple( - 'ParaMention',['start', 'end', 'type', 'name']) - -class Candidate(object): - def __init__(self, drug=None, gene=None, variant=None, label=None): - self.drug = drug - self.gene = gene - self.variant = variant - self.label = label - - def remove_entity(self, 
i, new_label=None): - ''' - :param i: - :param new_label: - :return: Return new Candidate with entity |i| replaced with None. - ''' - triple = (self.drug, self.gene, self.variant) - new_triple = triple[:i] + (None,) + triple[i+1:] - return Candidate(*new_triple, label=new_label) - - def get_entities(self): - return (self.drug, self.gene, self.variant) - - def is_triple(self): - return self.drug and self.gene and self.variant - - def get_types(self): - out = [] - if self.drug: - out.append('drug') - if self.gene: - out.append('gene') - if self.variant: - out.append('variant') - return tuple(out) - - def __key(self): - return (self.drug, self.gene, self.variant, self.label) - - def __eq__(x, y): - return x.__key() == y.__key() - - def __hash__(self): - return hash(self.__key()) - - -class Example(object): - def __init__(self, pmid, paragraphs, mentions, triple_candidates, pair_candidates): - self.pmid = pmid - self.paragraphs = paragraphs - self.mentions = mentions - self.triple_candidates = triple_candidates - self.pair_candidates = pair_candidates - self.entities = collections.defaultdict(OrderedSet) - for m_list in mentions: - for m in m_list: - self.entities[m.type].add(m.name) - - @classmethod - def read_examples(cls, example_json_file): - results = [] - with open(os.path.join(settings.DATA_DIR, example_json_file)) as f: - for line in f: - ex = cls.read_examples(line) - results.append(ex) - - return results - - @classmethod - def read_examples(cls, example_json_str): - example_json = json.loads(example_json_str) - mentions = [[ParaMention(**mention) for mention in paragraph_mentions] - for paragraph_mentions in example_json['mentions']] - pair_candidates = {} - - for pair_key in example_json['pair_candidates']: - pair_key_tuple = tuple(json.loads(pair_key)) - pair_candidates[pair_key_tuple] = OrderedSet(Candidate(**x) - for x in example_json['pair_candidates'][pair_key]) - triple_candidates = {} - triple_candidates = [Candidate(**x) - for x in example_json['triple_candidates']] - - return cls(example_json['pmid'], - example_json['paragraphs'], - mentions, - triple_candidates, - pair_candidates) - -class Preprocessor(object): - - def __init__(self, entity_lists, vacab, device): - self.entity_lists = entity_lists - self.vocab = vacab - self.device = device - - def count_labels(self, ex, pair_only=None): - if pair_only: - candidates = ex.pair_candidates[pair_only] - else: - candidates = ex.triple_candidates - - num_pos = sum(c.label for c in candidates) - num_neg = sum(1 - c.label for c in candidates) - return num_neg, num_pos - - def shuffle_entities(self, ex): - entity_map = {} - for e_type in ex.entities: - cur_ents = ex.entities[e_type] - replacements = random.sample(self.entity_lists[e_type], len(cur_ents)) - for e_old, e_new in zip(cur_ents, replacements): - entity_map[(e_type, e_old)] = e_new - - new_paras = [] - new_mentions = [] - for p, m_list in zip(ex.paragraphs, ex.mentions): - new_para = [] - new_m_list =[] - mentions_at_loc = collections.defaultdict(list) - in_mention = [False] * len(p) - for m in m_list: - mentions_at_loc[m.start].append((m.type, m.name)) - for i in range(m.start, m.end): - in_mention[i] = True - for i in range(len(p)): - if mentions_at_loc[i]: - for e_type, name in mentions_at_loc[i]: - e_new = entity_map[(e_type, name)] - m = ParaMention(len(new_para), len(new_para)+1, e_type, name) - new_m_list.append(m) - new_para.append(e_new) - if not in_mention[i]: - new_paras.append(p[i]) - new_paras.append(new_para) - new_mentions.append(new_m_list) - return 
new_paras, new_mentions - - def preprocess(self, ex, pair_only): - new_paras, new_mentions = self.shuffle_entities(ex) - para_prep = [] - for para_idx, (para, m_list) in enumerate(zip(new_paras, new_mentions)): - word_idxs = torch.tensor(self.vocab.indexify_list(para), - dtype=torch.long, device=self.device) - - para_from_start = [ - para_idx / math.pow(PARA_EMB_MAX_SPAN, 2*i / (PARA_EMB_SIZE // 4)) - for i in range(PARA_EMB_SIZE // 4) - ] - - para_from_end = [ - (len(new_paras)- para_idx) / math.pow(PARA_EMB_MAX_SPAN, 2*i / (PARA_EMB_SIZE // 4)) - for i in range(PARA_EMB_SIZE // 4) - ] - - para_args = torch.cat([torch.tensor(x, dtype=torch.float, device=self.device) - for x in (para_from_start, para_from_end)]) - - para_vec = torch.cat([torch.sin(para_args), torch.cos(para_args)]) - para_prep.append((word_idxs ,para_vec, m_list)) - - # sort for pack_padded_sequence - para_prep.sort(key=lambda x:len(x[0]), reverse=True) - T, P = len(para_prep[0][0]), len(para_prep) - para_mat = torch.zeros((T, P), device=self.device, dtype=torch.long) - for i, x in enumerate(para_prep): - cur_words = x[0] - para_mat[:len(cur_words), i] = cur_words - - lenghts = torch.tensor([len(x[0]) for x in para_prep], device=self.device) - triple_labels = torch.tensor([c.label for c in ex.triple_candidates], - dtype=torch.float, device=self.device) - pair_labels = {k: torch.tensor([c.label for c in ex.pair_candidates[k]], - dtype=torch.float, device=self.device) - for k in ex.pair_candidates} - para_vecs = torch.stack([x[1] for x in para_prep], dim=0) - unlabeled_triple_cands = [Candidate(ex.drug, ex.gene, ex.variant) - for ex in ex.triple_candidates] - unlabeled_pair_cands = {k: [Candidate(ex.drug, ex.gene, ex.variant) - for ex in ex.pair_candidates[k]] - for k in ex.pair_candidates} - return (para_mat, lenghts, para_vecs, [x[2] for x in para_prep], - unlabeled_triple_cands, unlabeled_pair_cands, triple_labels, pair_labels) - -def logsumexp(inputs, dim=None, keepdim=False): - ''' - - :param inputs: A variable with any shape. - :param dim: An integer. - :param keepdim: A boolean. - :return: Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)). - ''' - if dim is None: - inputs = inputs.view(-1) - dim = 0 - s, _ = torch.max(inputs, dim=dim, keepdim=True) - outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() - if not keepdim: - outputs = outputs.squeeze(dim) - return outputs - -class BackoffModel(nn.Module): - ''' - Combine triple and pairwise information. 
- ''' - - def __init__(self, emb_mat, lstm_size, lstm_layers, device, use_lstm=True, - use_position=True, pool_method='max', dropout_prob=0.5, vocab=None, - pair_only=None): - - super(BackoffModel, self).__init__() - self.device = device - self.use_lstm = use_lstm - self.use_position = use_position - self.pool_method - pool_method - self.embs = nn.Embedding.from_pretrained(emb_mat, freeze=False) - self.vocab = vocab - self.pair_only =pair_only - self.dropout = nn.Dropout(p=dropout_prob) - para_emb_size = PARA_EMB_SIZE if use_position else 0 - if use_lstm: - self.lstm_layers = lstm_layers - self.lstm = nn.LSTM(EMB_SZIE + para_emb_size, lstm_size, - bidirectional=True, num_layers=lstm_layers) - else: - self.emb_linear = nn.Linear(EMB_SZIE + para_emb_size, 2 * lstm_size) - for t1 ,t2 in ALL_ENTITY_TYPES_PAIRS: - setattr(self, 'hidden_%s_%s' % - (t1, t2), nn.Linear(4 * lstm_size, 2 * lstm_size)) - setattr(self, 'out_%s_%s' % (t1, t2), nn.Linear(2 * lstm_size, 1)) - setattr(self, 'backoff_%s_%s' % (t1, t2), nn.Parameter( - torch.zeros(1, 2 * lstm_size))) - self.hidden_triple = nn.Linear(3 * 2 * lstm_size, 2 * lstm_size) - self.backoff_triple = nn.Parameter(torch.zeros(1, 2 * lstm_size)) - self.hidden_all = nn.Linear(4 * 2 * lstm_size, 2 * lstm_size) - self.out_triple = nn.Linear(2 * lstm_size, 1) - - def pool(self, grouped_vecs): - ''' - - :param grouped_vecs: - :return: - ''' - +import argparse +import collections +import glob +import json +import math +import numpy as np +import random +from ordered_set import OrderedSet +import os +import pickle +import shutil +from sklearn.metrics import average_precision_score +import sys +import termcolor +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.rnn as rnn +import torch.optim as optim +from tqdm import tqdm + +from NAACL import vocabulary +from NAACL import settings +from NAACL import util + +WORD_VEC_FILE = 'wordvec/PubMed-and-PMC-w2v.txt' +WORD_VEC_NUM_LINES = 4087447 +EMB_SZIE = 200 # size of word embeddings +PARA_EMB_SIZE = 100 # size of paragraph index embeddings +PARA_EMB_MAX_SPAN = 1000 +MAX_ENTITIES_PER_TYPE = 200 +MAX_NUM_PARAGRAPHS = 200 +MAX_NUM_CANDIDATES = 10000 +ALL_ENTITY_TYPES = ['drug', 'gene', 'variant'] +ALL_ENTITY_TYPES_PAIRS = [('drug', 'gene'), ('drug', 'variant'), ('gene', 'variant')] +MAX_PARAGRAPH_LENGTH = 800 +CLIP_THRESH = 5 # Gradient clipping (on L2 norm) +JAX_DEV_PMIDS_FILE = 'jax/jax_dev_pmids.txt' +JAX_TEST_PMIDS_FILE = 'jax/jax_test_pmids.txt' + +log_file = None + +def log(msg): + print(msg, file=sys.stderr) + if log_file: + print(msg, file=log_file) + +ParaMention = collections.namedtuple( + 'ParaMention',['start', 'end', 'type', 'name']) + +class Candidate(object): + def __init__(self, drug=None, gene=None, variant=None, label=None): + self.drug = drug + self.gene = gene + self.variant = variant + self.label = label + + def remove_entity(self, i, new_label=None): + ''' + :param i: + :param new_label: + :return: Return new Candidate with entity |i| replaced with None. 
+ ''' + triple = (self.drug, self.gene, self.variant) + new_triple = triple[:i] + (None,) + triple[i+1:] + return Candidate(*new_triple, label=new_label) + + def get_entities(self): + return (self.drug, self.gene, self.variant) + + def is_triple(self): + return self.drug and self.gene and self.variant + + def get_types(self): + out = [] + if self.drug: + out.append('drug') + if self.gene: + out.append('gene') + if self.variant: + out.append('variant') + return tuple(out) + + def __key(self): + return (self.drug, self.gene, self.variant, self.label) + + def __eq__(x, y): + return x.__key() == y.__key() + + def __hash__(self): + return hash(self.__key()) + + +class Example(object): + def __init__(self, pmid, paragraphs, mentions, triple_candidates, pair_candidates): + self.pmid = pmid + self.paragraphs = paragraphs + self.mentions = mentions + self.triple_candidates = triple_candidates + self.pair_candidates = pair_candidates + self.entities = collections.defaultdict(OrderedSet) + for m_list in mentions: + for m in m_list: + self.entities[m.type].add(m.name) + + @classmethod + def read_examples(cls, example_json_file): + results = [] + with open(os.path.join(settings.DATA_DIR, example_json_file)) as f: + for line in f: + ex = cls.read_examples(line) + results.append(ex) + + return results + + @classmethod + def read_examples(cls, example_json_str): + example_json = json.loads(example_json_str) + mentions = [[ParaMention(**mention) for mention in paragraph_mentions] + for paragraph_mentions in example_json['mentions']] + pair_candidates = {} + + for pair_key in example_json['pair_candidates']: + pair_key_tuple = tuple(json.loads(pair_key)) + pair_candidates[pair_key_tuple] = OrderedSet(Candidate(**x) + for x in example_json['pair_candidates'][pair_key]) + triple_candidates = {} + triple_candidates = [Candidate(**x) + for x in example_json['triple_candidates']] + + return cls(example_json['pmid'], + example_json['paragraphs'], + mentions, + triple_candidates, + pair_candidates) + +class Preprocessor(object): + + def __init__(self, entity_lists, vacab, device): + self.entity_lists = entity_lists + self.vocab = vacab + self.device = device + + def count_labels(self, ex, pair_only=None): + if pair_only: + candidates = ex.pair_candidates[pair_only] + else: + candidates = ex.triple_candidates + + num_pos = sum(c.label for c in candidates) + num_neg = sum(1 - c.label for c in candidates) + return num_neg, num_pos + + def shuffle_entities(self, ex): + entity_map = {} + for e_type in ex.entities: + cur_ents = ex.entities[e_type] + replacements = random.sample(self.entity_lists[e_type], len(cur_ents)) + for e_old, e_new in zip(cur_ents, replacements): + entity_map[(e_type, e_old)] = e_new + + new_paras = [] + new_mentions = [] + for p, m_list in zip(ex.paragraphs, ex.mentions): + new_para = [] + new_m_list =[] + mentions_at_loc = collections.defaultdict(list) + in_mention = [False] * len(p) + for m in m_list: + mentions_at_loc[m.start].append((m.type, m.name)) + for i in range(m.start, m.end): + in_mention[i] = True + for i in range(len(p)): + if mentions_at_loc[i]: + for e_type, name in mentions_at_loc[i]: + e_new = entity_map[(e_type, name)] + m = ParaMention(len(new_para), len(new_para)+1, e_type, name) + new_m_list.append(m) + new_para.append(e_new) + if not in_mention[i]: + new_paras.append(p[i]) + new_paras.append(new_para) + new_mentions.append(new_m_list) + return new_paras, new_mentions + + def preprocess(self, ex, pair_only): + new_paras, new_mentions = self.shuffle_entities(ex) + para_prep = 
[] + for para_idx, (para, m_list) in enumerate(zip(new_paras, new_mentions)): + word_idxs = torch.tensor(self.vocab.indexify_list(para), + dtype=torch.long, device=self.device) + + para_from_start = [ + para_idx / math.pow(PARA_EMB_MAX_SPAN, 2*i / (PARA_EMB_SIZE // 4)) + for i in range(PARA_EMB_SIZE // 4) + ] + + para_from_end = [ + (len(new_paras)- para_idx) / math.pow(PARA_EMB_MAX_SPAN, 2*i / (PARA_EMB_SIZE // 4)) + for i in range(PARA_EMB_SIZE // 4) + ] + + para_args = torch.cat([torch.tensor(x, dtype=torch.float, device=self.device) + for x in (para_from_start, para_from_end)]) + + para_vec = torch.cat([torch.sin(para_args), torch.cos(para_args)]) + para_prep.append((word_idxs ,para_vec, m_list)) + + # sort for pack_padded_sequence + para_prep.sort(key=lambda x:len(x[0]), reverse=True) + T, P = len(para_prep[0][0]), len(para_prep) + para_mat = torch.zeros((T, P), device=self.device, dtype=torch.long) + for i, x in enumerate(para_prep): + cur_words = x[0] + para_mat[:len(cur_words), i] = cur_words + + lenghts = torch.tensor([len(x[0]) for x in para_prep], device=self.device) + triple_labels = torch.tensor([c.label for c in ex.triple_candidates], + dtype=torch.float, device=self.device) + pair_labels = {k: torch.tensor([c.label for c in ex.pair_candidates[k]], + dtype=torch.float, device=self.device) + for k in ex.pair_candidates} + para_vecs = torch.stack([x[1] for x in para_prep], dim=0) + unlabeled_triple_cands = [Candidate(ex.drug, ex.gene, ex.variant) + for ex in ex.triple_candidates] + unlabeled_pair_cands = {k: [Candidate(ex.drug, ex.gene, ex.variant) + for ex in ex.pair_candidates[k]] + for k in ex.pair_candidates} + return (para_mat, lenghts, para_vecs, [x[2] for x in para_prep], + unlabeled_triple_cands, unlabeled_pair_cands, triple_labels, pair_labels) + +def logsumexp(inputs, dim=None, keepdim=False): + ''' + + :param inputs: A variable with any shape. + :param dim: An integer. + :param keepdim: A boolean. + :return: Equivalent of log(sum(exp(inputs), dim=dim, keepdim=keepdim)). + ''' + if dim is None: + inputs = inputs.view(-1) + dim = 0 + s, _ = torch.max(inputs, dim=dim, keepdim=True) + outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() + if not keepdim: + outputs = outputs.squeeze(dim) + return outputs + +class BackoffModel(nn.Module): + ''' + Combine triple and pairwise information. 
+ ''' + + def __init__(self, emb_mat, lstm_size, lstm_layers, device, use_lstm=True, + use_position=True, pool_method='max', dropout_prob=0.5, vocab=None, + pair_only=None): + + super(BackoffModel, self).__init__() + self.device = device + self.use_lstm = use_lstm + self.use_position = use_position + self.pool_method - pool_method + self.embs = nn.Embedding.from_pretrained(emb_mat, freeze=False) + self.vocab = vocab + self.pair_only =pair_only + self.dropout = nn.Dropout(p=dropout_prob) + para_emb_size = PARA_EMB_SIZE if use_position else 0 + if use_lstm: + self.lstm_layers = lstm_layers + self.lstm = nn.LSTM(EMB_SZIE + para_emb_size, lstm_size, + bidirectional=True, num_layers=lstm_layers) + else: + self.emb_linear = nn.Linear(EMB_SZIE + para_emb_size, 2 * lstm_size) + for t1 ,t2 in ALL_ENTITY_TYPES_PAIRS: + setattr(self, 'hidden_%s_%s' % + (t1, t2), nn.Linear(4 * lstm_size, 2 * lstm_size)) + setattr(self, 'out_%s_%s' % (t1, t2), nn.Linear(2 * lstm_size, 1)) + setattr(self, 'backoff_%s_%s' % (t1, t2), nn.Parameter( + torch.zeros(1, 2 * lstm_size))) + self.hidden_triple = nn.Linear(3 * 2 * lstm_size, 2 * lstm_size) + self.backoff_triple = nn.Parameter(torch.zeros(1, 2 * lstm_size)) + self.hidden_all = nn.Linear(4 * 2 * lstm_size, 2 * lstm_size) + self.out_triple = nn.Linear(2 * lstm_size, 1) + + def pool(self, grouped_vecs): + ''' + + :param grouped_vecs: + :return: + ''' + if self.pool_method == 'mean': + return torch.stack([torch.mean(g, dim=0) for g in grouped_vecs]) + elif self.pool_method == 'sum': + return torch.stack([torch.sum(g, dim=0) for g in grouped_vecs]) + elif self.pool_method == 'max': + return torch.stack([torch.max(g, dim=0)[0] for g in grouped_vecs]) + elif self.pool_method == 'softmax': + return torch.stack([logsumexp(g, dim=0) for g in grouped_vecs]) + raise NotImplementedError + + def forward(self, word_idx_mat, lens, para_vecs, mentions, + triple_candidates, pair_candidates): + ''' + + :param word_idx_mat: list of word indices, size(T, P) + :param lens: list of paragraph lengths, size(P) + :param para_vecs: list of paragraph vectors, size(P, pe) + :param mentions: list of list of ParaMention + :param triple_candidates: list of unlabeled Candidate + :param pair_candidates: list of unlabeled Candidate + :return: + ''' + T, P = word_idx_mat.shape # T=num_toks, P=num_paras + + # Organize the candidate pairs and triples + pair_to_idx = {} + pair_sets = collections.defaultdict(set) + for(t1, t2), cands in pair_candidates.items(): + pair_to_idx[(t1, t2)] = {c: i for i, c, in enumerate(cands)} + for c in cands: + pair_sets[(t1, t2)].add(c) + triple_to_idx = {c: i for i, c in enumerate(triple_candidates)} + + # Build local embeddings of each word + embs = self.embs(word_idx_mat) # T, P, e + if self.use_position: + para_embs = para_vecs.unsqueeze(0).expand(T, -1,-1) # T, P, pe + embs = torch.cat([embs, para_embs], dim=2) # T, P, e + pe + if self.use_lstm: + lstm_in = rnn.pack_padded_sequence(embs, lens) # T, P, e + pe + lstm_out_packed, _ = self.lstm(lstm_in) + embs, _ = rnn.pad_packed_sequence(lstm_out_packed) # T, P, 2*h + else: + embs = self.emb_linear(embs) # T, P, 2*h + + # Gather co-occurring mention pairs and triples + pair_inputs = {(t1, t2):[[] for i in range(len(cands))] + for(t1, t2), cands in pair_candidates.items()} + triple_inputs = [[] for i in range(len(triple_candidates))] + + for para_idx, m_list in enumerate(mentions): + typed_mentions = collections.defaultdict(list) + for m in m_list: + typed_mentions[m.type].append(m) + for t1, t2 in 
ALL_ENTITY_TYPES_PAIRS: + if self.pair_only and self.pair_only !=(t1 ,t2): + continue + for m1 in typed_mentions[t1]: + for m2 in typed_mentions[t2]: + query_cand = Candidate(**{t1: m1.name, t2: m2.name}) + if query_cand in pair_to_idx[(t1, t2)]: + idx = pair_to_idx[(t1, t2)][query_cand] + cur_vecs = torch.cat([embs[m1.start, para_idx, :], + embs[m2.start, para_idx, :]]) # 4*h + pair_inputs[(t1, t2)][idx].append(cur_vecs) + if self.pair_only: + continue + for m1 in typed_mentions['drug']: + for m2 in typed_mentions['gene']: + for m3 in typed_mentions['variant']: + query_cand = Candidate(m1.name, m2.name, m3.name) + if query_cand in triple_to_idx: + idx = triple_to_idx[query_cand] + cur_vecs = torch.cat( + [embs[m1.start, para_idx, :], + embs[m2.start, para_idx, :], + embs[m3.start, para_idx, :]]) # 6*h + triple_inputs[idx].append(cur_vecs) + + # Compute local mention pair/triple representations + pair_vecs = {} + for t1, t2 in ALL_ENTITY_TYPES_PAIRS: + if self.pair_only and self.pair_only != (t1, t2): + continue + cur_group_sizes = [len(vecs) for vecs in pair_inputs[(t1, t2)]] + if sum(cur_group_sizes) > 0: + cur_stack = torch.stack([ + v for vecs in pair_inputs[(t1, t2)] for v in vecs]) # M, 4*h + cur_m_reps = getattr(self, 'hidden_%s_%s' % + (t1, t2))(cur_stack) # M, 2*h + cur_pair_grouped_vecs = list(torch.split(cur_m_reps, cur_group_sizes)) + for i in range(len(cur_pair_grouped_vecs)): + if cur_pair_grouped_vecs[i].shape[0] == 0: # Back off + cur_pair_grouped_vecs[i] = getattr(self, + 'backoff_%s_%s' % (t1, t2)) + else: + cur_pair_grouped_vecs = [getattr(self, 'backoff_%s_%s' % (t1, t2)) + for vecs in pair_inputs[(t1, t2)]] + pair_vecs[(t1, t2)] = torch.tanh( + self.pool(cur_pair_grouped_vecs)) # P, 2*h + + if not self.pair_only: + triple_group_sizes = [len(vecs) for vecs in triple_inputs] + if sum(triple_group_sizes) > 0: + triple_stack = torch.stack([ + v for vecs in triple_inputs for v in vecs]) # M, 6*h + triple_m_reps = self.hidden_triple(triple_stack) # M, 2*h + triple_grouped_vecs = list( + torch.split(triple_m_reps, triple_group_sizes)) + for i in range(len(triple_grouped_vecs)): + if triple_grouped_vecs[i].shape[0] == 0: # back off + triple_grouped_vecs[i] = self.backoff_triple + + else: + triple_grouped_vecs = [self.backoff_triple for vecs in triple_inputs] + triple_vecs = torch.tanh(self.pool(triple_grouped_vecs)) # C, 2*h + + # Score candidate pairs + pair_logits = {} + for t1, t2 in ALL_ENTITY_TYPES_PAIRS: + if self.pair_only and self.pair_only != (t1, t2): + continue + pair_logits[(t1, t2)] = getattr(self, 'out_%s_%s' % (t1, t2))( + pair_vecs[(t1, t2)])[:, 0] #M + if self.pair_only: + return None, pair_logits + + # Score candidate triples + pair_feats_per_triple = [[], [], []] + for c in triple_candidates: + for i in range(3): + pair = c.remove_entity(i) + t1, t2 = pair.get_types() + pair_idx = pair_to_idx[(t1, t2)](pair) + pair_feats_per_triple[i].append( + pair_vecs[(t1, t2)][pair_idx, :]) # 2*h + triple_feats = torch.cat( + [torch.stack(pair_feats_per_triple[0]), + torch.stack(pair_feats_per_triple[1]), + torch.stack(pair_feats_per_triple[2]), + triple_vecs], + dim=1) # C, 8*h + final_hidden = F.relu(self.hidden_all(triple_feats)) # C, 2*h + triple_logits = self.out_triple(final_hidden)[:, 0] # C + return triple_logits, pair_logits + + +def get_entity_lists(): + entity_lists = {} + for et in ALL_ENTITY_TYPES: + entity_lists[et] = ['__%s__' % et + for i in range(MAX_ENTITIES_PER_TYPE)] + # Can streamline, since we're just using single placeholder per entity type + 
return entity_lists + +def count_labels(name, data, preprocessor, pair_only=None): + num_neg, num_pos = 0, 0 + for ex in data: + cur_neg, cur_pos = preprocessor.count_labels(ex, pair_only=pair_only) + num_neg += cur_neg + num_pos += cur_pos + log('%s data: +%d, -%d' % (name, num_pos, num_neg)) + return num_neg, num_pos + +def print_data_stats(data, name): + print(name) + print(' Max num paragraphs: %d' % max(len(ex.paragraphs) for ex in data)) + print(' Max num triple candidates: %d' % max( + len(ex.triple_candidates) for ex in data)) + +def init_word_vecs(device, vocab, all_zero=False): + num_pretrained = 0 + embs = torch.zeros((len(vocab), EMB_SZIE), dtype=torch.float, device=device) + if not all_zero: + with open(os.path.join(settings.DATA_DIR,WORD_VEC_FILE)) as f: + for line in tqdm(f, total=WORD_VEC_NUM_LINES): + toks = line.strip().split(' ') + if len(toks) != EMB_SZIE + 1: + continue + word = toks[0] + if word in vocab: + idx = vocab.get_index(word) + embs[idx, :] = torch.tensor([float(x) for x in toks[1:]], + dtype=torch.float, device=device) + num_pretrained += 1 + log('Found pre-trained vectors for %d/%d = %.2f%% words' % ( + num_pretrained, len(vocab), 100*0 * num_pretrained /len(vocab))) + + return embs + +def train(model, train_data, dev_data, preprocessor, num_epochs, lr, ckpt_iters, + downsample_to, out_dir, lr_decay=1.0, pos_weight=None, use_pair_loss=True, + pair_only=None): + model.train() + if ckpt_iters > len(train_data): + ckpt_iters = len(train_data) # Checkpoint at least once per epoch + loss_func = nn.BCEWithLogitsLoss(pos_weight=pos_weight) + params = [p for p in model.paraments() if p.requires_grad] + optimizer = optim.Adam(params, lr=lr) + scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_decay) + train_data = list(train_data) # Copy before shuffling + num_iters = 0 + best_ap = 0.0 # Choose checkpoint based on dev average precision + train_loss = 0.0 + for t in range(num_epochs): + t0 = time.time() + random.shuffle(train_data) + if not downsample_to: + cur_train = tqdm(train_data) + else: + cur_train = train_data # tqdm is annoyingn on downsampled data + for ex in cur_train: + model.zero_grad() + ex_torch = preprocessor.preprocess(ex, pair_only) + triple_labels, pair_labels = ex_torch[-2:] + triple_logits, pair_logits = model(*ex_torch[:-2]) + if pair_only: + loss = loss_func(pair_logits[pair_only], pair_labels[pair_only]) + else: + loss = loss_func(triple_logits, triple_labels) + if use_pair_loss: + for t1, t2 in ALL_ENTITY_TYPES_PAIRS: + loss += loss_func(pair_logits[(t1, t2)], pair_labels[(t1, t2)]) + train_loss += loss.item() + loss.backward() + torch.nn.utils.clip_grad_norm(model.paraments(),CLIP_THRESH) + optimizer.step() + num_iters += 1 + if num_iters % ckpt_iters == 0: + model.eval() + dev_preds, dev_loss = predict( + model, dev_data, preprocessor, loss_func=loss_func, + use_pair_loss=use_pair_loss, pair_only=pair_only) + log('Iter %d: train loss = %.6f, dev loss = %.6f' % ( + num_iters, train_loss / ckpt_iters, dev_loss)) + + train_loss = 0.0 + p_doc, r_doc, f1_doc, ap_doc = evaluate(dev_data, dev_preds, + pair_only=pair_only) + log(' Document-level : p=%.2f%% r=%.2f%% f1=%.2f%% ap=%.2f%%' % ( + 100 * p_doc, 100 * r_doc, 100 * f1_doc, 100 * ap_doc)) + if out_dir: + save_model(model, num_iters, out_dir) + model.train() + scheduler.step() + t1 = time.time() + log('Epoch %s: took %s' % (str(t).rjust(3), util.secs_to_str(t1 - t0))) + +def predict(model, data, preprocessor, loss_func=None, use_pair_loss=True, pair_only=None): + loss = 0.0 + 
preds = [] + with torch.no_grad(): + for ex in data: + all_logits = [] + ex_torch = preprocessor.preprocess(ex, pair_only) + triple_labels, pair_labels = ex_torch[-2:] + triple_logits, pair_logits = model(*ex_torch[:-2]) + if loss_func: + if pair_only: + loss += loss_func(pair_logits[pair_only], pair_labels[pair_only]) + else: + loss += loss_func(triple_logits, triple_labels) + if use_pair_loss: + for t1 ,t2 in ALL_ENTITY_TYPES_PAIRS: + loss += loss_func(pair_logits[(t1 ,t2)], pair_labels[(t1, t2)]) + if pair_only: + cur_pred = [1 / (1 + np.exp(-z.item())) for z in pair_logits[pair_only]] + else: + cur_pred = [1 / (1 + np.exp(-z.item())) for z in pair_logits[pair_only]] + preds.append(cur_pred) + out = [preds] + if loss_func: + out.append(loss / len(data)) + if len(out) == 1: + return out[0] + return out + + +COLORS = {'drug': 'red', 'variant': 'cyan', 'gene': 'green'} + +def pprint_example(ex, f=sys.stdout): + print('PMID %s' % ex.pmid, file=f) + for para_idx, (paragraph, m_list) in enumerate(zip(ex.paragraphs, ex.mentions)): + word_to_type = {} + for m in m_list: + for i in range(m.start, m.end): + word_to_type[i] = m.type + para_toks = [] + for i in range(len(paragraph)): + if i in word_to_type: + para_toks.append(termcolor.colored( + paragraph[i], COLORS[word_to_type[i]])) + else: + para_toks.append(paragraph[i]) + print(' Paragraph %d: %s' % (para_idx, ' '.join(para_toks)), file=f) + +def evaluate(data, probs, name=None, threshold=0.5, pair_only=None): + def get_candidates(ex): + if pair_only: + return ex.pair_candidates[pair_only] + else: + return ex.triple_candidates + if name: + log('== %s, document-level: %d documents, %d candidates (+%d, -%d) ==' % ( + name, len(data), sum(len(get_candidates(ex)) for ex in data), + sum(1 for ex in data for c in get_candidates(ex) if c.label == 1), + sum(1 for ex in data for c in get_candidates(ex) if c.label == 0))) + tp = fp = fn = 0 + y_true = [] + y_pred = [] + for ex, prob_list in zip(data, probs): + for c, prob in zip(get_candidates(ex), prob_list): + y_true.append(c.label) + y_pred.append(prob) + pred = int(prob > threshold) + if pred == 1: + if c.label == 1: + tp += 1 + else: + fp += 1 + else: + if c.label == 1: + fn += 1 + ap = average_precision_score(y_true, y_pred) + if name: + log(util.get_prf(tp, fp, fn, get_str=True)) + log('AvgPrec : %.2f%%' % (100.0 * ap)) + p, r, f = util.get_prf(tp, fp, fn) + return p, r, f, ap + + +def predict_write(model, data, preprocessor, out_dir, ckpt, data_name, pair_only): + if out_dir: + if ckpt: + out_path = os.path.join(out_dir, 'pred_%s_%07d.tsv' % (data_name, ckpt)) + else: + out_path = os.path.join(out_dir, 'pred_%s.tsv' % data_name) + # Only one pprint necessary + pprint_out = os.path.join(out_dir, 'dev_pprint.txt') + else: + pprint_out = None + pred = predict(model, tqdm(data), preprocessor, pair_only=pair_only) + pprint_predictions(data, pred, preprocessor, fn=pprint_out) + if out_path: + write_predictions(data, pred, out_path, pair_only=pair_only) + + +def pprint_predictions(data, preds, preprocessor, threshold=0.5, fn=None): + if fn: + f = open(fn, 'w') + else: + f = sys.stdout + for i, (ex, pred_list) in enumerate(zip(data, preds)): + pprint_example(ex, f=f) + new_paras, new_mentions = ex.paragraphs, ex.mentions + for j, (c, pred) in enumerate(zip(ex.triple_candidates, pred_list)): + pred_label = pred > threshold + print(' (%s, %s, %s): pred=%s (p=%.4f), gold=%s, correct=%s' % ( + c.drug, c.gene, c.variant, pred_label, pred, + c.label == 1, pred_label == (c.label == 1)), file=f) + print('', 
file=f) + if fn: + f.close() + +def write_predictions(data, preds, fn, pair_only=None): + i = 0 + with open(fn, 'w') as f: + for ex, pred_list in zip(data, preds): + if pair_only: + candidates = ex.pair_candidates[pair_only] + else: + candidates = ex.triple_candidates + for c, pred in zip(candidates, pred_list): + print('%d\t%s\t%s\t%s\t%s\t%.6f' % ( + i, ex.pmid, c.drug, c.gene, c.variant, pred), file=f) + i += 1 + + +def make_vocab(train_data, entity_lists, unk_thresh): + vocab = vocabulary.Vocabulary(unk_threshold=unk_thresh) + for ents in list(entity_lists.values()): + for e in ents: + vocab.add_word_hard(e) + for ex in tqdm(train_data): + for p, m_list in zip(ex.paragraphs, ex.mentions): + in_mention = [False] * len(p) + for m in m_list: + for i in range(m.start, m.end): + in_mention[i] = True + for i, w in enumerate(p): + if not in_mention[i]: + vocab.add_word(w) + return vocab + + +def save_model(model, num_iters, out_dir): + fn = os.path.join(out_dir, 'model.%07d.pth' % num_iters) + torch.save(model.state_dict(), fn) + +def load_model(model, load_dir, device, load_ckpt): + # if not load_ckpt: + # with open(os.path.join(load_dir, 'best_model.txt')) as f: + # load_ckpt = int(f.read().strip().split('\t')[0]) + fn = os.path.join(load_dir, 'model.%07d.pth' % load_ckpt) + log('Loading model from %s' % fn) + model.load_state_dict(torch.load(fn, map_location=device)) + +def predict_write(model, data, preprocessor, out_dir, ckpt, data_name, pair_only): + if out_dir: + if ckpt: + out_path = os.path.join(out_dir, 'pred_%s_%07d.tsv' % (data_name, ckpt)) + else: + out_path = os.path.join(out_dir, 'pred_%s.tsv' % data_name) + # Only one pprint necessary + pprint_out = os.path.join(out_dir, 'dev_pprint.txt') + else: + pprint_out = None + pred = predict(model, tqdm(data), preprocessor, pair_only=pair_only) + pprint_predictions(data, pred, preprocessor, fn=pprint_out) + if out_path: + write_predictions(data, pred, out_path, pair_only=pair_only) + + +def get_ds_train_dev_pmids(pmid_file): + with open(os.path.join(settings.DATA_DIR, pmid_file)) as f: + pmids = sorted([pmid.strip() for pmid in f if pmid.strip()]) + random.shuffle(pmids) + num_train = int(round(len(pmids) * 0.7)) + num_train_dev = int(round(len(pmids) * 0.8)) + train_pmids = set(pmids[:num_train]) + dev_pmids = set(pmids[num_train:num_train_dev]) + return train_pmids, dev_pmids + + +def parse_args(args): + parser = argparse.ArgumentParser() + # Required params + # parser.add_argument('para_file', help='JSON object storing paragraph text') + # parser.add_argument('mention_file', help='List of mentions for relevant paragraphs') + parser.add_argument('--ds-train-dev-file', help='Training examples') + parser.add_argument('--jax-dev-test-file', help='Dev examples') + parser.add_argument('--init-pmid-file', default='pmid_lists/init_pmid_list.txt', + help='Dev examples') + + # Model architecture + parser.add_argument('--lstm-size', '-c', default=200, + type=int, help='LSTM hidden state size.') + parser.add_argument('--lstm-layers', '-l', default=1, + type=int, help='LSTM number of layers.') + parser.add_argument('--pool', '-p', choices=['softmax', 'max', 'mean', 'sum'], default='softmax', + help='How to pool across mentions') + parser.add_argument('--no-position', action='store_true', + help='Ablate paragraph index encodings') + parser.add_argument('--no-lstm', action='store_true', help='Ablate LSTM') + # Training + parser.add_argument('--num-epochs', '-T', type=int, + default=10, help='Training epochs') + 
parser.add_argument('--learning-rate', '-r', type=float, + default=1e-5, help='Learning rate.') + parser.add_argument('--dropout-prob', '-d', type=float, + default=0.5, help='Dropout probability') + parser.add_argument('--lr-decay', '-g', type=float, default=1.0, + help='Decay learning rate by this much each epoch.') + parser.add_argument('--balanced', '-b', action='store_true', + help='Upweight positive examples to balance dataset') + parser.add_argument('--pos-weight', type=float, default=None, + help='Upweight postiive examples by this much') + parser.add_argument('--use-pair-loss', action='store_true', + help="Multi-task on pair objective") + # Data + #parser.add_argument('--data-cache', default=DEFAULT_CACHE) + parser.add_argument('--data-cache', default=None) + parser.add_argument('--rng-seed', default=0, type=int, help='RNG seed') + parser.add_argument('--torch-seed', default=0, + type=int, help='torch RNG seed') + parser.add_argument('--downsample-to', default=None, type=int, + help='Downsample to this many examples per split') + parser.add_argument('--unk-thresh', '-u', default=5, type=int, + help='Treat words with fewer than this many counts as .') + parser.add_argument('--print-dev', action='store_true', + help='Test on dev data') + parser.add_argument('--jax', action='store_true', help='Test on JAX data') + parser.add_argument('--jax-out', default='pred_jax.tsv') + parser.add_argument('--text-level', choices=['document', 'paragraph', 'sentence'], + default='document', help='Split documents paragraph-wise or sentence-wise') + parser.add_argument('--pair-only', default=None, + help='Comma-separated pair of entities to focus on only') + # CPU vs. GPU + parser.add_argument('--cpu-only', action='store_true', + help='Run on CPU only') + parser.add_argument('--gpu-id', type=int, default=0, + help='GPU ID (default=0)') + # Saving and loading + parser.add_argument('--out-dir', '-o', default=None, + help='Where to write all output') + parser.add_argument('--ckpt-iters', '-i', default=10000, type=int, + help='Checkpoint after this many training steps.') + parser.add_argument( + '--load', '-L', help='Directory to load model parameters and vocabulary') + parser.add_argument('--load-ckpt', type=int, default=None, + help='Which checkpoint to use (default: use best_model.txt)') + parser.add_argument('--try-all-checkpoints', action='store_true', + help='Make predictions for every checkpoint') + parser.add_argument('--data-dir', help='root data directory') + # Other + parser.add_argument('--no-w2v', action='store_true', + help='No pre-trained word vectors') + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + return parser.parse_args(args) + + +def get_all_checkpoints(out_dir): + fns = glob.glob(os.path.join(out_dir, 'model.*.pth')) + return sorted([int(os.path.basename(x).split('.')[1]) for x in fns]) + +def run(OPTS, device): + # Process pair-only mode + pair_only = None + if OPTS.pair_only: + pair_only = tuple(OPTS.pair_only.split(',')) + if pair_only not in ALL_ENTITY_TYPES_PAIRS: + raise ValueError('Bad value for pair_only: %s' % OPTS.pair_only) + entity_lists = get_entity_lists() + # Read data + train_pmids_set, dev_ds_pmids_set = get_ds_train_dev_pmids( + OPTS.init_pmid_file) + ds_train_dev_data = Example.read_examples(OPTS.ds_train_dev_file) + # Filter out examples that doesn't contain pair or triple candidates + if pair_only: + ds_train_dev_data = [x for x in ds_train_dev_data if pair_only in + x.pair_candidates and x.pair_candidates[pair_only]] + else: + ds_train_dev_data = 
[x for x in ds_train_dev_data if x.triple_candidates] + train_data = [x for x in ds_train_dev_data if x.pmid in train_pmids_set] + dev_ds_data = [x for x in ds_train_dev_data if x.pmid in dev_ds_pmids_set] + random.shuffle(train_data) + random.shuffle(dev_ds_data) + + jax_dev_test_data = Example.read_examples(OPTS.jax_dev_test_file) + if pair_only: + jax_dev_test_data = [x for x in jax_dev_test_data if pair_only in + x.pair_candidates and x.pair_candidates[pair_only]] + else: + jax_dev_test_data = [x for x in jax_dev_test_data if x.triple_candidates] + random.shuffle(jax_dev_test_data) + + with open(os.path.join(settings.DATA_DIR, JAX_DEV_PMIDS_FILE)) as f: + dev_jax_pmids_set = set(x.strip() for x in f if x.strip()) + with open(os.path.join(settings.DATA_DIR, JAX_TEST_PMIDS_FILE)) as f: + test_pmids_set = set(x.strip() for x in f if x.strip()) + + dev_jax_data = [x for x in jax_dev_test_data if x.pmid in dev_jax_pmids_set] + test_data = [x for x in jax_dev_test_data if x.pmid in test_pmids_set] + log('Read %d train, %d dev dist sup, %d dev jax, %d test examples' % + (len(train_data), len(dev_ds_data), len(dev_jax_data), len(test_data))) + + vocab = make_vocab(train_data, entity_lists, OPTS.unk_thresh) + log('Vocab size = %d.' % len(vocab)) + preprocessor = Preprocessor(entity_lists, vocab, device) + num_neg, num_pos = count_labels('train', train_data, preprocessor, + pair_only=pair_only) + word_vecs = init_word_vecs(device, vocab, all_zero=OPTS.load or OPTS.no_w2v) + log('Finished reading data.') + + # Run model + model = BackoffModel( + word_vecs, OPTS.lstm_size, OPTS.lstm_layers, device, + use_lstm=not OPTS.no_lstm, use_position=not OPTS.no_position, + pool_method=OPTS.pool, dropout_prob=OPTS.dropout_prob, + vocab=vocab, pair_only=pair_only).to(device=device) + if OPTS.load: + load_model(model, OPTS.load, device, OPTS.load_ckpt) + if OPTS.num_epochs > 0: + log('Starting training.') + pos_weight = None + if OPTS.balanced: + pos_weight = torch.tensor(float(num_neg) / num_pos, device=device) + elif OPTS.pos_weight: + pos_weight = torch.tensor(OPTS.pos_weight, device=device) + train(model, train_data, dev_ds_data, preprocessor, OPTS.num_epochs, + OPTS.learning_rate, OPTS.ckpt_iters, OPTS.downsample_to, OPTS.out_dir, + pos_weight=pos_weight, lr_decay=OPTS.lr_decay, + use_pair_loss=OPTS.use_pair_loss, pair_only=pair_only) + log('Finished training.') + model.eval() + if OPTS.try_all_checkpoints: + ckpts = get_all_checkpoints(OPTS.out_dir) + else: + ckpts = [None] + for ckpt in ckpts: + if ckpt: + print('== Checkpoint %s == ' % ckpt, file=sys.stderr) + load_model(model, OPTS.out_dir, device, ckpt) + predict_write(model, dev_jax_data, preprocessor, + OPTS.out_dir, ckpt, 'dev', pair_only) + predict_write(model, test_data, preprocessor, + OPTS.out_dir, ckpt, 'test', pair_only) + + +def main(OPTS): + if OPTS.out_dir: + if os.path.exists(OPTS.out_dir): + shutil.rmtree(OPTS.out_dir) + os.makedirs(OPTS.out_dir) + global log_file + log_file = open(os.path.join(OPTS.out_dir, 'log.txt'), 'w') + log(OPTS) + random.seed(OPTS.rng_seed) + torch.manual_seed(OPTS.torch_seed) + if OPTS.cpu_only: + device = torch.device('cpu') + else: + device = torch.device('cuda:%d' % OPTS.gpu_id) + try: + run(OPTS, device) + finally: + if log_file: + log_file.close() + + +if __name__ == '__main__': + OPTS = parse_args(sys.argv[1:]) + main(OPTS) diff --git a/NAACL/ensemble.py b/NAACL/ensemble.py index 89e3133..790a03d 100644 --- a/NAACL/ensemble.py +++ b/NAACL/ensemble.py @@ -1,86 +1,86 @@ -'''Ensemble some predictions. 
''' -import argparse -import collections -import math -from scipy.special import logsumexp -import sys - -MODES = ['mean', 'max', 'logsumexp', 'noisy_or', 'log_noisy_or', 'odds_ratio'] - -def parse_args(args): - parser = argparse.ArgumentParser() - parser.add_argument('mode', choices=MODES) - parser.add_argument('files', nargs='+') - parser.add_argument('--weights', '-w', type=lambda x:[float(t) for t in x.split(',')], - help='Comma-separated lit of multiplizer per file') - parser.add_argument('--out-file', '-o', default=None, help='Where to write all output') - - if len(sys.argv) == 1: - parser.print_help() - sys.exit(1) - return parser.parse_args(args) - -def read_preds(fn): - preds = [] - with open(fn) as f: - for line in f: - idx, pmid, drug, gene, variant, prob = line.strip().split('\t') - prob = float(prob) - preds.append((pmid, drug, gene, variant, prob)) - - return preds - -def main(OPTS): - preds_all = [read_preds(fn) for fn in OPTS.files] - groups = collections.defaultdict(list) - for i, preds in enumerate(preds_all): - if OPTS.weights: - weight = OPTS.weights[i] - else: - weight = 1.0 - for pmid, drug, gene, variant, prob in preds: - groups[(pmid, drug, gene, variant)].append(weight * prob) - - results = [] - for i , ((pmid, drug, gene, variant), prob_list) in enumerate(groups.items()): - if OPTS.mode == 'mean': - prob = sum(prob_list) / len(prob_list) - elif OPTS.mode == 'max': - prob = max(prob_list) - elif OPTS.mode == 'logsumexp': - prob = logsumexp(prob_list) - elif OPTS.mode == 'noisy_or': - prob_no_rel = 1.0 - for p in prob_list: - prob_no_rel *= 1.0 - p - prob =1.0 - prob_no_rel - elif OPTS.mode == 'log_noisy_or': - log_prob_no_rel = 0.0 - for p in prob_list: - if p < 1.0: - log_prob_no_rel += math.log(1.0 - p) - else: - log_prob_no_rel -= 1000000 - prob = -log_prob_no_rel - elif OPTS.mode == 'odds_ratio': - cur_log_odds = 0.0 - for p in prob_list: - cur_log_odds += 10 + 0.001 * p #math.log(p / (1.0 - p) * 100000000) - prob = cur_log_odds - else: - raise ValueError(OPTS.mode) - results.append((i, pmid, drug, gene, variant, prob)) - - with open(OPTS.out_file, 'w') as f: - for item in results: - f.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(*item)) - -if __name__ == '__main__': - OPTS = parse_args(sys.argv[1:]) - main(OPTS) - - - - - - +'''Ensemble some predictions. 
''' +import argparse +import collections +import math +from scipy.special import logsumexp +import sys + +MODES = ['mean', 'max', 'logsumexp', 'noisy_or', 'log_noisy_or', 'odds_ratio'] + +def parse_args(args): + parser = argparse.ArgumentParser() + parser.add_argument('mode', choices=MODES) + parser.add_argument('files', nargs='+') + parser.add_argument('--weights', '-w', type=lambda x:[float(t) for t in x.split(',')], + help='Comma-separated lit of multiplizer per file') + parser.add_argument('--out-file', '-o', default=None, help='Where to write all output') + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) + return parser.parse_args(args) + +def read_preds(fn): + preds = [] + with open(fn) as f: + for line in f: + idx, pmid, drug, gene, variant, prob = line.strip().split('\t') + prob = float(prob) + preds.append((pmid, drug, gene, variant, prob)) + + return preds + +def main(OPTS): + preds_all = [read_preds(fn) for fn in OPTS.files] + groups = collections.defaultdict(list) + for i, preds in enumerate(preds_all): + if OPTS.weights: + weight = OPTS.weights[i] + else: + weight = 1.0 + for pmid, drug, gene, variant, prob in preds: + groups[(pmid, drug, gene, variant)].append(weight * prob) + + results = [] + for i , ((pmid, drug, gene, variant), prob_list) in enumerate(groups.items()): + if OPTS.mode == 'mean': + prob = sum(prob_list) / len(prob_list) + elif OPTS.mode == 'max': + prob = max(prob_list) + elif OPTS.mode == 'logsumexp': + prob = logsumexp(prob_list) + elif OPTS.mode == 'noisy_or': + prob_no_rel = 1.0 + for p in prob_list: + prob_no_rel *= 1.0 - p + prob =1.0 - prob_no_rel + elif OPTS.mode == 'log_noisy_or': + log_prob_no_rel = 0.0 + for p in prob_list: + if p < 1.0: + log_prob_no_rel += math.log(1.0 - p) + else: + log_prob_no_rel -= 1000000 + prob = -log_prob_no_rel + elif OPTS.mode == 'odds_ratio': + cur_log_odds = 0.0 + for p in prob_list: + cur_log_odds += 10 + 0.001 * p #math.log(p / (1.0 - p) * 100000000) + prob = cur_log_odds + else: + raise ValueError(OPTS.mode) + results.append((i, pmid, drug, gene, variant, prob)) + + with open(OPTS.out_file, 'w') as f: + for item in results: + f.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(*item)) + +if __name__ == '__main__': + OPTS = parse_args(sys.argv[1:]) + main(OPTS) + + + + + + diff --git a/NAACL/out/log.txt b/NAACL/out/log.txt new file mode 100644 index 0000000..97e48ca --- /dev/null +++ b/NAACL/out/log.txt @@ -0,0 +1 @@ +Namespace(balanced=False, ckpt_iters=10000, cpu_only=False, data_cache=None, data_dir=None, downsample_to=None, dropout_prob=0.5, ds_train_dev_file='C:\\workspace\\GNNs\\data\\examples\\sentence\\ds_train_dev.txt', gpu_id=0, init_pmid_file='pmid_lists/init_pmid_list.txt', jax=False, jax_dev_test_file='C:\\workspace\\GNNs\\data\\examples\\sentence\\jax_dev_test.txt', jax_out='pred_jax.tsv', learning_rate=1e-05, load=None, load_ckpt=None, lr_decay=1.0, lstm_layers=1, lstm_size=200, no_lstm=False, no_position=False, no_w2v=False, num_epochs=10, out_dir='.\\out\\', pair_only=None, pool='softmax', pos_weight=None, print_dev=False, rng_seed=0, text_level='sentence', torch_seed=0, try_all_checkpoints=False, unk_thresh=5, use_pair_loss=False) diff --git a/NAACL/prune_pred_gv_map.py b/NAACL/prune_pred_gv_map.py index 5b15f6c..b5d6ae2 100644 --- a/NAACL/prune_pred_gv_map.py +++ b/NAACL/prune_pred_gv_map.py @@ -1,51 +1,51 @@ -"""Prune model predictions with rule-based G-V linker.""" -import argparse -import collections -import os -import sys - -from NAACL import settings - -OPTS = None - -GV_MAP_FILE = 
'gene_var/gene_to_var.tsv'
-
-def prep_gv_mapping():
-    var_to_gene= {}
-    gene_to_var= collections.defaultdict(set)
-    pmid_to_gv = collections.defaultdict(set)
-    pmid_gv_map = {}
-    with open(os.path.join(settings.DATA_DIR, GV_MAP_FILE)) as f:
-        for line in f:
-            pmid, variant, gene = line.strip().strip()
-            gene = gene.lower()
-            var_to_gene[(pmid, variant)] = gene
-            gene_to_var[(pmid, gene)].add(variant)
-            pmid_to_gv[pmid].add((gene, variant))
-
-    return var_to_gene, gene_to_var, pmid_to_gv
-
-def parse_args(args):
-    parser = argparse.ArgumentParser()
-    parser.add_argument('pred_file')
-    parser.add_argument('out_file')
-    if len(args) == 0:
-        parser.print_help()
-        sys.exit(1)
-    return parser.parse_args(args)
-
-def main(OPTS):
-    var_to_gene, gene_to_var, pmid_to_gv = prep_gv_mapping()
-    with open(OPTS.pred_file) as fin:
-        with open(OPTS.out_file) as fout:
-            for line in fin:
-                idx, pmid, d, g, v, prob = line.strip().split('\t')
-                if(pmid, v) not in var_to_gene:
-                    continue
-                g_linked = var_to_gene[(pmid, v)]
-                if g_linked == g:
-                    fout.write(line)
-
-if __name__ == '__main__':
-    OPTS = parse_args(sys.argv[1:])
+"""Prune model predictions with rule-based G-V linker."""
+import argparse
+import collections
+import os
+import sys
+
+from NAACL import settings
+
+OPTS = None
+
+GV_MAP_FILE = 'gene_var/gene_to_var.tsv'
+
+def prep_gv_mapping():
+    var_to_gene = {}
+    gene_to_var = collections.defaultdict(set)
+    pmid_to_gv = collections.defaultdict(set)
+    pmid_gv_map = {}
+    with open(os.path.join(settings.DATA_DIR, GV_MAP_FILE)) as f:
+        for line in f:
+            pmid, variant, gene = line.strip().split('\t')
+            gene = gene.lower()
+            var_to_gene[(pmid, variant)] = gene
+            gene_to_var[(pmid, gene)].add(variant)
+            pmid_to_gv[pmid].add((gene, variant))
+
+    return var_to_gene, gene_to_var, pmid_to_gv
+
+def parse_args(args):
+    parser = argparse.ArgumentParser()
+    parser.add_argument('pred_file')
+    parser.add_argument('out_file')
+    if len(args) == 0:
+        parser.print_help()
+        sys.exit(1)
+    return parser.parse_args(args)
+
+def main(OPTS):
+    var_to_gene, gene_to_var, pmid_to_gv = prep_gv_mapping()
+    with open(OPTS.pred_file) as fin:
+        with open(OPTS.out_file, 'w') as fout:
+            for line in fin:
+                idx, pmid, d, g, v, prob = line.strip().split('\t')
+                if (pmid, v) not in var_to_gene:
+                    continue
+                g_linked = var_to_gene[(pmid, v)]
+                if g_linked == g:
+                    fout.write(line)
+
+if __name__ == '__main__':
+    OPTS = parse_args(sys.argv[1:])
     main(OPTS)
\ No newline at end of file
diff --git a/NAACL/settings.py b/NAACL/settings.py
index 54c9bba..a387caa 100644
--- a/NAACL/settings.py
+++ b/NAACL/settings.py
@@ -1 +1 @@
-DATA_DIR = 'data'
+DATA_DIR = 'data'
diff --git a/NAACL/util.py b/NAACL/util.py
index a98022c..c1e79dc 100644
--- a/NAACL/util.py
+++ b/NAACL/util.py
@@ -1,41 +1,41 @@
-SECS_PER_MIN = 60
-SECS_PER_HOUR = SECS_PER_MIN * 60
-SECS_PER_DAY = SECS_PER_HOUR * 24
-
-def secs_to_str(secs):
-    days = int(secs) // SECS_PER_DAY
-    secs -= days * SECS_PER_DAY
-    hours = int(secs) // SECS_PER_HOUR
-    secs -= hours * SECS_PER_HOUR
-    mins = int(secs) // SECS_PER_MIN
-    secs -= mins * SECS_PER_MIN
-    if days > 0:
-        return '%dd%02dh%02dm' % (days, hours, mins)
-    elif hours > 0:
-        return '%dh%02dm%02ds' % (hours, mins, int(secs))
-    elif mins > 0:
-        return '%dm%02ds' % (mins, int(secs))
-    elif secs >= 1:
-        return '%.1fs' % secs
-    return '%.2fs' % secs
-
-def get_prf(tp, fp, fn, get_str=False):
-    """Get precision, recall, f1 from true pos, false pos, false neg."""
-    if tp + fp == 0:
-        precision = 0
-    else:
-        precision = float(tp) / (tp + fp)
-    if tp + fn == 0:
-    
recall = 0 - else: - recall = float(tp) / (tp + fn) - if precision + recall == 0: - f1 = 0 - else: - f1 = 2 * precision * recall / (precision + recall) - if get_str: - return '\n'.join([ - 'Precision: %.2f%%' % (100.0 * precision), - 'Recall : %.2f%%' % (100.0 * recall), - 'F1 : %.2f%%' % (100.0 * f1)]) +SECS_PER_MIN = 60 +SECS_PER_HOUR = SECS_PER_MIN * 60 +SECS_PER_DAY = SECS_PER_HOUR * 24 + +def secs_to_str(secs): + days = int(secs) // SECS_PER_DAY + secs -= days * SECS_PER_DAY + hours = int(secs) // SECS_PER_HOUR + secs -= hours * SECS_PER_HOUR + mins = int(secs) // SECS_PER_MIN + secs -= mins * SECS_PER_MIN + if days > 0: + return '%dd%02dh%02dm' % (days, hours, mins) + elif hours > 0: + return '%dh%02dm%02ds' % (hours, mins, int(secs)) + elif mins > 0: + return '%dm%02ds' % (mins, int(secs)) + elif secs >= 1: + return '%.1fs' % secs + return '%.2fs' % secs + +def get_prf(tp, fp, fn, get_str=False): + """Get precision, recall, f1 from true pos, false pos, false neg.""" + if tp + fp == 0: + precision = 0 + else: + precision = float(tp) / (tp + fp) + if tp + fn == 0: + recall = 0 + else: + recall = float(tp) / (tp + fn) + if precision + recall == 0: + f1 = 0 + else: + f1 = 2 * precision * recall / (precision + recall) + if get_str: + return '\n'.join([ + 'Precision: %.2f%%' % (100.0 * precision), + 'Recall : %.2f%%' % (100.0 * recall), + 'F1 : %.2f%%' % (100.0 * f1)]) return precision, recall, f1 \ No newline at end of file diff --git a/NAACL/vocabulary.py b/NAACL/vocabulary.py index 1f493da..8f4ed23 100644 --- a/NAACL/vocabulary.py +++ b/NAACL/vocabulary.py @@ -1,93 +1,93 @@ -import collections -UNK_TOKEN = '' -UNK_INDEX = 0 - -class Vocabulary(object): - def __init__(self, unk_threshold=0): - ''' - - :param unk_threshold: words with <= this many counts will be considered . - ''' - self.unk_threshold = unk_threshold - self.counts = collections.Counter() - self.word2index = {UNK_TOKEN: UNK_INDEX} - self.word_list = [UNK_TOKEN] - - def add_word(self, word, count=1): - ''' - Add a word (may still map to UNK if it doesn't pass unk_threshold). - :param word: - :param count: - :return: - ''' - self.counts[word] += count - if word not in self.word2index and self.counts[word] > self.unk_threshold: - index = len(self.word_list) - self.word2index[word] = index - self.word_list.append(word) - - def add_word(self, words): - for w in words: - self.add_word(w) - - def add_sentence(self, sentence): - self.add_word(sentence.split(' ')) - - def add_sentences(self, sentences): - for s in sentences: - self.add_sentences(s) - - def add_word_hard(self, word): - ''' - Add word, make sure it is not UNK. 
-        :param word:
-        :return:
-        '''
-        self.add_word(word, count=(self.unk_threshold+1))
-
-    def get_word(self, index):
-        return self.word_list[index]
-
-    def get_index(self, word):
-        if word in self.word2index:
-            return self.word2index[word]
-        return UNK_INDEX
-
-    def indexify_sentence(self, sentence):
-        return [self.get_index(w) for w in sentence.split(' ')]
-
-    def indexify_list(self, elems):
-        return [self.get_index(w) for w in elems]
-
-    def recover_sentenc(self, indices):
-        return ' '.join(self.get_word(i) for i in indices)
-
-    def has_word(self, word):
-        return word in self.word2index
-
-    def __contains__(self, word):
-        return word in self.word2index
-
-    def size(self):
-        return len(self.word2index)
-
-    def __len__(self):
-        return self.size()
-    def __iter__(self):
-        return iter(self.word_list)
-
-    def save(self, filename):
-        '''Save word list.'''
-        with open(filename, 'w') as f:
-            for w in self.word_list:
-                print(w, file=f)
-
-    @classmethod
-    def load(cls, filename):
-        '''Load word list (does not load counts).'''
-        vocab = cls()
-        with open(filename) as f:
-            for line in f:
-                w = line.strip('\n')
-                vocab.add_word_hard(w)
+import collections
+UNK_TOKEN = ''
+UNK_INDEX = 0
+
+class Vocabulary(object):
+    def __init__(self, unk_threshold=0):
+        '''
+
+        :param unk_threshold: words with <= this many counts will be mapped to UNK.
+        '''
+        self.unk_threshold = unk_threshold
+        self.counts = collections.Counter()
+        self.word2index = {UNK_TOKEN: UNK_INDEX}
+        self.word_list = [UNK_TOKEN]
+
+    def add_word(self, word, count=1):
+        '''
+        Add a word (may still map to UNK if it doesn't pass unk_threshold).
+        :param word:
+        :param count:
+        :return:
+        '''
+        self.counts[word] += count
+        if word not in self.word2index and self.counts[word] > self.unk_threshold:
+            index = len(self.word_list)
+            self.word2index[word] = index
+            self.word_list.append(word)
+
+    def add_words(self, words):
+        for w in words:
+            self.add_word(w)
+
+    def add_sentence(self, sentence):
+        self.add_words(sentence.split(' '))
+
+    def add_sentences(self, sentences):
+        for s in sentences:
+            self.add_sentence(s)
+
+    def add_word_hard(self, word):
+        '''
+        Add word, make sure it is not UNK.
+        :param word:
+        :return:
+        '''
+        self.add_word(word, count=(self.unk_threshold+1))
+
+    def get_word(self, index):
+        return self.word_list[index]
+
+    def get_index(self, word):
+        if word in self.word2index:
+            return self.word2index[word]
+        return UNK_INDEX
+
+    def indexify_sentence(self, sentence):
+        return [self.get_index(w) for w in sentence.split(' ')]
+
+    def indexify_list(self, elems):
+        return [self.get_index(w) for w in elems]
+
+    def recover_sentence(self, indices):
+        return ' '.join(self.get_word(i) for i in indices)
+
+    def has_word(self, word):
+        return word in self.word2index
+
+    def __contains__(self, word):
+        return word in self.word2index
+
+    def size(self):
+        return len(self.word2index)
+
+    def __len__(self):
+        return self.size()
+    def __iter__(self):
+        return iter(self.word_list)
+
+    def save(self, filename):
+        '''Save word list.'''
+        with open(filename, 'w') as f:
+            for w in self.word_list:
+                print(w, file=f)
+
+    @classmethod
+    def load(cls, filename):
+        '''Load word list (does not load counts).'''
+        vocab = cls()
+        with open(filename) as f:
+            for line in f:
+                w = line.strip('\n')
+                vocab.add_word_hard(w)
         return vocab
\ No newline at end of file
diff --git a/maths.md b/maths.md
new file mode 100644
index 0000000..608a8c6
--- /dev/null
+++ b/maths.md
@@ -0,0 +1,17 @@
+In the `N`-ary Tree-LSTM, each unit at node $j$ maintains a hidden
+representation $h_j$ and a memory cell $c_j$. The unit $j$ takes the
+input vector $x_j$ and the hidden representations of its child units,
+$h_{jl}, 1\leq l\leq N$, as input, then updates its hidden representation
+$h_j$ and memory cell $c_j$ by:
+
+### Tree-LSTM math
+$$
+\begin{aligned}
+ i_j &= \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right), && (1)\\
+ f_{jk} &= \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \right), && (2)\\
+ o_j &= \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \right), && (3)\\
+ u_j &= \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \right), && (4)\\
+ c_j &= i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, && (5)\\
+ h_j &= o_j \cdot \textrm{tanh}(c_j), && (6)
+\end{aligned}
+$$
\ No newline at end of file
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..6a0c9b3
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+set -eu -o pipefail
+if [ $# -lt 3 ]
+then
+    echo "Usage: $0 <out_dir> <text_level> <in_dir> [extra flags...]" 1>&2
+    exit 1
+fi
+out_dir="$1"
+text_level="$2"
+in_dir="$3"
+shift;
+shift;
+shift;
+flags="$@"
+time python3 -m NAACL.backoffnet --ds-train-dev-file ${in_dir}/${text_level}/ds_train_dev.txt --jax-dev-test-file ${in_dir}/${text_level}/jax_dev_test.txt --text-level ${text_level} -o ${out_dir} --try-all-checkpoints ${flags}
+dev_preds=`ls ${out_dir}/pred_dev*.tsv`
+test_preds=`ls ${out_dir}/pred_test*.tsv`
+
+for fn in ${dev_preds} ${test_preds}
+do
+    echo "$fn"
+    python3 -m NAACL.prune_pred_gv_map entity_tokens_eval/entities0.tsv entity_tokens_eval/tokens.json ${fn} ${fn}.pruned
+#    python3 -m machinereading.evaluation.eval ${fn}
+done
+
+for fn in ${dev_preds}
+do
+    echo "$fn"
+#    python3 -m machinereading.evaluation.eval ${fn}.pruned
+done
+
+for fn in ${test_preds}
+do
+    echo "$fn"
+#    python3 -m machinereading.evaluation.eval --test ${fn}.pruned
+done
\ No newline at end of file
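
Editor's note on the ensemble modes in NAACL/ensemble.py: the mode dispatch inside main() is easier to sanity-check in isolation. The sketch below is not part of the patch; the helper name combine_probs is mine, only a subset of the modes is shown, and it assumes the per-file probabilities have already been multiplied by their --weights values, as in the script above.

import math

def combine_probs(probs, mode):
    # Combine the per-file probabilities for one (pmid, drug, gene, variant) group.
    if mode == 'mean':
        return sum(probs) / len(probs)
    if mode == 'max':
        return max(probs)
    if mode == 'logsumexp':
        # log(sum_i exp(p_i)); ensemble.py delegates this to scipy.special.logsumexp.
        m = max(probs)
        return m + math.log(sum(math.exp(p - m) for p in probs))
    if mode == 'noisy_or':
        # P(at least one model fires) = 1 - prod_i (1 - p_i)
        prob_no_rel = 1.0
        for p in probs:
            prob_no_rel *= 1.0 - p
        return 1.0 - prob_no_rel
    raise ValueError(mode)

# Example: three models score the same candidate triple.
print(combine_probs([0.2, 0.7, 0.4], 'noisy_or'))   # -> 0.856 (1 - 0.8*0.3*0.6)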
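Editor's note on NAACL/vocabulary.py: a short usage sketch of the class as patched. It assumes the de-duplicated method names from the edit above (add_words for the bulk helper, add_word for a single token); the example sentences are illustrative only.

from NAACL.vocabulary import Vocabulary, UNK_INDEX

# Words seen more than unk_threshold times get their own index; everything else maps to UNK_INDEX.
vocab = Vocabulary(unk_threshold=1)
vocab.add_sentence('EGFR mutation confers sensitivity to gefitinib')
vocab.add_sentence('EGFR mutation confers resistance to erlotinib')
vocab.add_word_hard('gefitinib')   # force a word into the vocab regardless of its count

ids = vocab.indexify_sentence('EGFR mutation and gefitinib')
print(ids)                          # 'and' was never added, so it maps to UNK_INDEX
assert vocab.get_index('and') == UNK_INDEX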
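Editor's note on the equations in maths.md: they are easier to verify against code. Below is a minimal PyTorch sketch of a binary (N = 2) Tree-LSTM cell following equations (1)-(6); the class and argument names are mine, and this is a sketch rather than the TreeLSTM module added elsewhere in this patch.

import torch
import torch.nn as nn

class BinaryTreeLSTMCell(nn.Module):
    """Illustrative N=2 Tree-LSTM cell for equations (1)-(6) in maths.md."""
    def __init__(self, x_size, h_size):
        super().__init__()
        self.h_size = h_size
        # W^{(i)}, W^{(o)}, W^{(u)} and their biases, stacked into one projection (eqs. 1, 3, 4).
        self.W_iou = nn.Linear(x_size, 3 * h_size)
        # U^{(i)}_l, U^{(o)}_l, U^{(u)}_l applied to the concatenated child states.
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
        # Forget gates: one f_{jk} per child, each seeing both children via U^{(f)}_{kl} (eq. 2).
        self.W_f = nn.Linear(x_size, 2 * h_size)
        self.U_f = nn.Linear(2 * h_size, 2 * h_size, bias=False)

    def forward(self, x, h_children, c_children):
        # x: (batch, x_size); h_children, c_children: (batch, 2, h_size)
        batch = x.shape[0]
        h_cat = h_children.reshape(batch, 2 * self.h_size)
        i, o, u = (self.W_iou(x) + self.U_iou(h_cat)).chunk(3, dim=-1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)       # eqs. (1), (3), (4)
        f = torch.sigmoid(self.W_f(x) + self.U_f(h_cat))                  # eq. (2)
        f = f.view(batch, 2, self.h_size)
        c = i * u + (f * c_children).sum(dim=1)                           # eq. (5)
        h = o * torch.tanh(c)                                             # eq. (6)
        return h, c

# Leaf nodes have no children, so zero child states are passed in.
cell = BinaryTreeLSTMCell(x_size=16, h_size=8)
h, c = cell(torch.randn(4, 16), torch.zeros(4, 2, 8), torch.zeros(4, 2, 8))
print(h.shape, c.shape)   # torch.Size([4, 8]) torch.Size([4, 8])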