Skip to content

Commit

Permalink
add NLP mode
Browse files Browse the repository at this point in the history
  • Loading branch information
acproejct committed Nov 25, 2020
1 parent 5fb01eb commit fe5c0d2
Show file tree
Hide file tree
Showing 10 changed files with 434 additions and 136 deletions.
2 changes: 1 addition & 1 deletion .idea/GNNs.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

268 changes: 134 additions & 134 deletions Captum_BERT.py

Large diffs are not rendered by default.

Empty file added NAACL/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions NAACL/backoffnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import argparse
import collections
import glob
import json
import math
import numpy as np
import random
from ordered_set import OrderedSet
import os
import pickle
import shutil
from sklearn.metrics import average_precision_score
import sys
import termcolor
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn
import torch.optim as optim
from tqdm import tqdm

from NAACL import vocabulary
from NAACL import settings
from NAACL import util

86 changes: 86 additions & 0 deletions NAACL/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
'''Ensemble some predictions. '''
import argparse
import collections
import math
from scipy.special import logsumexp
import sys

# Supported ensembling strategies for combining per-file probabilities.
MODES = ['mean', 'max', 'logsumexp', 'noisy_or', 'log_noisy_or', 'odds_ratio']

def parse_args(args):
    """Parse command-line arguments for the ensembling script.

    :param args: list of argument strings (typically sys.argv[1:]).
    :return: argparse.Namespace with mode, files, weights, out_file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices=MODES)
    parser.add_argument('files', nargs='+')
    parser.add_argument('--weights', '-w',
                        type=lambda x: [float(t) for t in x.split(',')],
                        help='Comma-separated list of multipliers per file')
    parser.add_argument('--out-file', '-o', default=None,
                        help='Where to write all output')
    # Bug fix: check the args we were actually given, not sys.argv — the
    # original `len(sys.argv) == 1` test broke when callers passed an
    # explicit list (and is how the sibling parse_args in this commit works).
    if not args:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args(args)

def read_preds(fn):
    """Read predictions from a TSV file.

    Each line is (idx, pmid, drug, gene, variant, prob); the leading index
    column is dropped and prob is parsed as a float.
    """
    with open(fn) as f:
        rows = [line.strip().split('\t') for line in f]
    return [(pmid, drug, gene, variant, float(prob))
            for _idx, pmid, drug, gene, variant, prob in rows]

def main(OPTS):
preds_all = [read_preds(fn) for fn in OPTS.files]
groups = collections.defaultdict(list)
for i, preds in enumerate(preds_all):
if OPTS.weights:
weight = OPTS.weights[i]
else:
weight = 1.0
for pmid, drug, gene, variant, prob in preds:
groups[(pmid, drug, gene, variant)].append(weight * prob)

results = []
for i , ((pmid, drug, gene, variant), prob_list) in enumerate(groups.items()):
if OPTS.mode == 'mean':
prob = sum(prob_list) / len(prob_list)
elif OPTS.mode == 'max':
prob = max(prob_list)
elif OPTS.mode == 'logsumexp':
prob = logsumexp(prob_list)
elif OPTS.mode == 'noisy_or':
prob_no_rel = 1.0
for p in prob_list:
prob_no_rel *= 1.0 - p
prob =1.0 - prob_no_rel
elif OPTS.mode == 'log_noisy_or':
log_prob_no_rel = 0.0
for p in prob_list:
if p < 1.0:
log_prob_no_rel += math.log(1.0 - p)
else:
log_prob_no_rel -= 1000000
prob = -log_prob_no_rel
elif OPTS.mode == 'odds_ratio':
cur_log_odds = 0.0
for p in prob_list:
cur_log_odds += 10 + 0.001 * p #math.log(p / (1.0 - p) * 100000000)
prob = cur_log_odds
else:
raise ValueError(OPTS.mode)
results.append((i, pmid, drug, gene, variant, prob))

with open(OPTS.out_file, 'w') as f:
for item in results:
f.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(*item))

# Script entry point: parse CLI args and run the ensembler.
if __name__ == '__main__':
    OPTS = parse_args(sys.argv[1:])
    main(OPTS)






51 changes: 51 additions & 0 deletions NAACL/prune_pred_gv_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Prune model predictions with rule-based G-V linker."""
import argparse
import collections
import os
import sys

from NAACL import settings

# Populated with parsed CLI arguments in the __main__ block.
OPTS = None

# Rule-based variant->gene linking map, relative to settings.DATA_DIR.
GV_MAP_FILE = 'gene_var/gene_to_var.tsv'

def prep_gv_mapping(map_path=None):
    """Load the rule-based gene/variant linking map.

    :param map_path: optional explicit path to the TSV map file; defaults to
        GV_MAP_FILE under settings.DATA_DIR (backward compatible).
    :return: tuple (var_to_gene, gene_to_var, pmid_to_gv) where
        var_to_gene maps (pmid, variant) -> lowercased gene,
        gene_to_var maps (pmid, gene) -> set of variants,
        pmid_to_gv maps pmid -> set of (gene, variant) pairs.
    """
    if map_path is None:
        map_path = os.path.join(settings.DATA_DIR, GV_MAP_FILE)
    var_to_gene = {}
    gene_to_var = collections.defaultdict(set)
    pmid_to_gv = collections.defaultdict(set)
    with open(map_path) as f:
        for line in f:
            # Bug fix: was line.strip().strip(), which returns a single string
            # and raised ValueError on 3-way unpacking; split on tabs instead.
            pmid, variant, gene = line.strip().split('\t')
            gene = gene.lower()
            var_to_gene[(pmid, variant)] = gene
            gene_to_var[(pmid, gene)].add(variant)
            pmid_to_gv[pmid].add((gene, variant))
    # (Removed the pmid_gv_map local: it was never written to or returned.)
    return var_to_gene, gene_to_var, pmid_to_gv

def parse_args(args):
    """Parse the positional pred_file / out_file arguments.

    Prints help and exits with status 1 when no arguments are given.
    """
    parser = argparse.ArgumentParser()
    for arg_name in ('pred_file', 'out_file'):
        parser.add_argument(arg_name)
    if not args:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args(args)

def main(OPTS):
    """Keep only predictions whose variant links to the predicted gene.

    Reads OPTS.pred_file line by line and copies to OPTS.out_file the lines
    whose (pmid, variant) pair is mapped by the rule-based linker to the same
    gene that the model predicted.
    """
    var_to_gene, gene_to_var, pmid_to_gv = prep_gv_mapping()
    # Bug fix: the output file was opened without 'w', i.e. read-only, so
    # fout.write() raised io.UnsupportedOperation.
    with open(OPTS.pred_file) as fin, open(OPTS.out_file, 'w') as fout:
        for line in fin:
            idx, pmid, d, g, v, prob = line.strip().split('\t')
            if (pmid, v) not in var_to_gene:
                continue
            g_linked = var_to_gene[(pmid, v)]
            if g_linked == g:
                fout.write(line)

# Script entry point: parse CLI args and run the pruner.
if __name__ == '__main__':
    OPTS = parse_args(sys.argv[1:])
    main(OPTS)
1 change: 1 addition & 0 deletions NAACL/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Root directory for all data files (relative to the working directory).
DATA_DIR = 'data'
41 changes: 41 additions & 0 deletions NAACL/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Time-unit conversion constants used by secs_to_str.
SECS_PER_MIN = 60
SECS_PER_HOUR = SECS_PER_MIN * 60
SECS_PER_DAY = SECS_PER_HOUR * 24

def secs_to_str(secs):
days = int(secs) // SECS_PER_DAY
secs -= days * SECS_PER_DAY
hours = int(secs) // SECS_PER_HOUR
secs -= hours * SECS_PER_HOUR
mins = int(secs) // SECS_PER_MIN
secs -= mins * SECS_PER_MIN
if days > 0:
return '%dd%02dh%02dm' % (days, hours, mins)
elif hours > 0:
return '%dh%02dm%02ds' % (hours, mins, int(secs))
elif mins > 0:
return '%dm%02ds' % (mins, int(secs))
elif secs >= 1:
return '%.1fs' % secs
return '%.2fs' % secs

def get_prf(tp, fp, fn, get_str=False):
    """Get precision, recall, f1 from true pos, false pos, false neg.

    :param tp: true-positive count.
    :param fp: false-positive count.
    :param fn: false-negative count.
    :param get_str: if True, return a formatted multi-line summary instead
        of the numeric tuple.
    :return: (precision, recall, f1) floats, or a display string.
    """
    # Guard each ratio against a zero denominator, defaulting to 0.
    precision = float(tp) / (tp + fp) if tp + fp else 0
    recall = float(tp) / (tp + fn) if tp + fn else 0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0)
    if get_str:
        lines = ['Precision: %.2f%%' % (100.0 * precision),
                 'Recall : %.2f%%' % (100.0 * recall),
                 'F1 : %.2f%%' % (100.0 * f1)]
        return '\n'.join(lines)
    return precision, recall, f1
93 changes: 93 additions & 0 deletions NAACL/vocabulary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import collections
# Reserved out-of-vocabulary token; always mapped to index 0.
UNK_TOKEN = '<UNK>'
UNK_INDEX = 0

class Vocabulary(object):
    """Mutable word <-> index mapping with an <UNK> fallback.

    Index 0 is reserved for UNK_TOKEN. Words whose accumulated count does
    not exceed unk_threshold are counted but not assigned their own index,
    so lookups for them return UNK_INDEX.
    """

    def __init__(self, unk_threshold=0):
        """
        :param unk_threshold: words with <= this many counts will be considered <UNK>.
        """
        self.unk_threshold = unk_threshold
        self.counts = collections.Counter()
        self.word2index = {UNK_TOKEN: UNK_INDEX}
        self.word_list = [UNK_TOKEN]

    def add_word(self, word, count=1):
        """Add a word (may still map to UNK if it doesn't pass unk_threshold).

        :param word: word to add.
        :param count: how many occurrences to credit it with.
        """
        self.counts[word] += count
        if word not in self.word2index and self.counts[word] > self.unk_threshold:
            index = len(self.word_list)
            self.word2index[word] = index
            self.word_list.append(word)

    def add_words(self, words):
        """Add each word in an iterable.

        Bug fix: this used to be a second `def add_word` that silently
        shadowed the (word, count) overload above, which broke
        add_word_hard() and add_word(word, count=...) callers.
        """
        for w in words:
            self.add_word(w)

    def add_sentence(self, sentence):
        """Add every space-separated token of a sentence."""
        self.add_words(sentence.split(' '))

    def add_sentences(self, sentences):
        """Add every sentence in an iterable.

        Bug fix: previously recursed into add_sentences(s) per element,
        causing infinite recursion; it now delegates to add_sentence.
        """
        for s in sentences:
            self.add_sentence(s)

    def add_word_hard(self, word):
        """Add word with enough count to guarantee it is not UNK."""
        self.add_word(word, count=(self.unk_threshold + 1))

    def get_word(self, index):
        """Return the word stored at the given index."""
        return self.word_list[index]

    def get_index(self, word):
        """Return the index of a word, or UNK_INDEX for unknown words."""
        if word in self.word2index:
            return self.word2index[word]
        return UNK_INDEX

    def indexify_sentence(self, sentence):
        """Map a space-separated sentence to a list of indices."""
        return [self.get_index(w) for w in sentence.split(' ')]

    def indexify_list(self, elems):
        """Map a list of words to a list of indices."""
        return [self.get_index(w) for w in elems]

    def recover_sentence(self, indices):
        """Map a list of indices back to a space-joined sentence."""
        return ' '.join(self.get_word(i) for i in indices)

    # Backward-compatible alias for the original (misspelled) method name.
    recover_sentenc = recover_sentence

    def has_word(self, word):
        """Return True if the word has its own (non-UNK) index."""
        return word in self.word2index

    def __contains__(self, word):
        return word in self.word2index

    def size(self):
        """Vocabulary size including the UNK entry."""
        return len(self.word2index)

    def __len__(self):
        return self.size()

    def __iter__(self):
        return iter(self.word_list)

    def save(self, filename):
        """Save the word list, one word per line (counts are not saved)."""
        with open(filename, 'w') as f:
            for w in self.word_list:
                print(w, file=f)

    @classmethod
    def load(cls, filename):
        """Load a word list written by save() (does not restore counts)."""
        vocab = cls()
        with open(filename) as f:
            for line in f:
                w = line.strip('\n')
                vocab.add_word_hard(w)
        return vocab

0 comments on commit fe5c0d2

Please sign in to comment.