-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
acproejct
committed
Nov 25, 2020
1 parent
5fb01eb
commit fe5c0d2
Showing
10 changed files
with
434 additions
and
136 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import argparse | ||
import collections | ||
import glob | ||
import json | ||
import math | ||
import numpy as np | ||
import random | ||
from ordered_set import OrderedSet | ||
import os | ||
import pickle | ||
import shutil | ||
from sklearn.metrics import average_precision_score | ||
import sys | ||
import termcolor | ||
import time | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.nn.utils.rnn as rnn | ||
import torch.optim as optim | ||
from tqdm import tqdm | ||
|
||
from NAACL import vocabulary | ||
from NAACL import settings | ||
from NAACL import util | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
'''Ensemble some predictions. ''' | ||
import argparse | ||
import collections | ||
import math | ||
from scipy.special import logsumexp | ||
import sys | ||
|
||
MODES = ['mean', 'max', 'logsumexp', 'noisy_or', 'log_noisy_or', 'odds_ratio']


def parse_args(args):
    """Parse command-line arguments for the ensembling script.

    :param args: list of argument strings (typically sys.argv[1:]).
    :return: argparse.Namespace with mode, files, weights, out_file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('mode', choices=MODES)
    parser.add_argument('files', nargs='+')
    parser.add_argument('--weights', '-w',
                        type=lambda x: [float(t) for t in x.split(',')],
                        help='Comma-separated list of multipliers per file')
    parser.add_argument('--out-file', '-o', default=None,
                        help='Where to write all output')
    # Bug fix: the original checked len(sys.argv) == 1, ignoring the `args`
    # parameter that was actually passed in; check the real argument list
    # (matching the other scripts' parse_args convention).
    if len(args) == 0:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args(args)
|
||
def read_preds(fn):
    """Load predictions from a tab-separated file.

    Each line holds: index, pmid, drug, gene, variant, probability.
    The leading index column is dropped and the probability is parsed
    as a float.

    :param fn: path to the predictions file.
    :return: list of (pmid, drug, gene, variant, prob) tuples.
    """
    with open(fn) as f:
        rows = [line.strip().split('\t') for line in f]
    return [(pmid, drug, gene, variant, float(prob))
            for _idx, pmid, drug, gene, variant, prob in rows]
|
||
def main(OPTS):
    """Ensemble predictions from several files into one score per tuple.

    Groups weighted probabilities by (pmid, drug, gene, variant), combines
    each group according to OPTS.mode, and writes one tab-separated line
    per group to OPTS.out_file.
    """
    # Collect weighted probabilities keyed by the relation tuple.
    groups = collections.defaultdict(list)
    all_preds = [read_preds(fn) for fn in OPTS.files]
    for file_idx, preds in enumerate(all_preds):
        weight = OPTS.weights[file_idx] if OPTS.weights else 1.0
        for pmid, drug, gene, variant, prob in preds:
            groups[(pmid, drug, gene, variant)].append(weight * prob)

    results = []
    for out_idx, (key, prob_list) in enumerate(groups.items()):
        pmid, drug, gene, variant = key
        if OPTS.mode == 'mean':
            score = sum(prob_list) / len(prob_list)
        elif OPTS.mode == 'max':
            score = max(prob_list)
        elif OPTS.mode == 'logsumexp':
            score = logsumexp(prob_list)
        elif OPTS.mode == 'noisy_or':
            # P(at least one relation) = 1 - prod(1 - p_i).
            no_rel = 1.0
            for p in prob_list:
                no_rel *= 1.0 - p
            score = 1.0 - no_rel
        elif OPTS.mode == 'log_noisy_or':
            log_no_rel = 0.0
            for p in prob_list:
                # Avoid log(0) when a probability saturates at 1.0:
                # apply a huge fixed bonus instead.
                if p < 1.0:
                    log_no_rel += math.log(1.0 - p)
                else:
                    log_no_rel -= 1000000
            score = -log_no_rel
        elif OPTS.mode == 'odds_ratio':
            cur_log_odds = 0.0
            for p in prob_list:
                cur_log_odds += 10 + 0.001 * p  # math.log(p / (1.0 - p) * 100000000)
            score = cur_log_odds
        else:
            raise ValueError(OPTS.mode)
        results.append((out_idx, pmid, drug, gene, variant, score))

    with open(OPTS.out_file, 'w') as f:
        for item in results:
            f.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(*item))
|
||
if __name__ == '__main__':
    # Script entry point: parse CLI arguments and run the ensembling pipeline.
    OPTS = parse_args(sys.argv[1:])
    main(OPTS)
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""Prune model predictions with rule-based G-V linker.""" | ||
import argparse | ||
import collections | ||
import os | ||
import sys | ||
|
||
from NAACL import settings | ||
|
||
OPTS = None | ||
|
||
GV_MAP_FILE = 'gene_var/gene_to_var.tsv' | ||
|
||
def prep_gv_mapping():
    """Load the rule-based gene-variant mapping from GV_MAP_FILE.

    :return: a 3-tuple of:
        var_to_gene: (pmid, variant) -> gene (lowercased)
        gene_to_var: (pmid, gene) -> set of variants
        pmid_to_gv: pmid -> set of (gene, variant) pairs
    """
    var_to_gene = {}
    gene_to_var = collections.defaultdict(set)
    pmid_to_gv = collections.defaultdict(set)
    with open(os.path.join(settings.DATA_DIR, GV_MAP_FILE)) as f:
        for line in f:
            # Bug fix: the original called .strip() twice instead of
            # splitting; the mapping file is a .tsv, so split on tabs to
            # unpack the three fields.
            pmid, variant, gene = line.strip().split('\t')
            gene = gene.lower()
            var_to_gene[(pmid, variant)] = gene
            gene_to_var[(pmid, gene)].add(variant)
            pmid_to_gv[pmid].add((gene, variant))
    return var_to_gene, gene_to_var, pmid_to_gv
|
||
def parse_args(args):
    """Parse CLI arguments; print help and exit when none are given.

    :param args: list of argument strings (typically sys.argv[1:]).
    :return: argparse.Namespace with pred_file and out_file.
    """
    parser = argparse.ArgumentParser()
    for name in ('pred_file', 'out_file'):
        parser.add_argument(name)
    if not args:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args(args)
|
||
def main(OPTS):
    """Prune predictions using the rule-based gene-variant linker.

    Reads tab-separated prediction lines, drops any line whose
    (pmid, variant) pair is missing from the mapping or whose linked gene
    disagrees with the predicted gene, and writes the survivors verbatim
    to OPTS.out_file.
    """
    var_to_gene, gene_to_var, pmid_to_gv = prep_gv_mapping()
    with open(OPTS.pred_file) as fin:
        # Bug fix: the output file must be opened for writing; the original
        # opened it in read mode, so fout.write() would raise.
        with open(OPTS.out_file, 'w') as fout:
            for line in fin:
                idx, pmid, d, g, v, prob = line.strip().split('\t')
                if (pmid, v) not in var_to_gene:
                    continue
                g_linked = var_to_gene[(pmid, v)]
                if g_linked == g:
                    fout.write(line)
|
||
if __name__ == '__main__':
    # Script entry point: parse CLI arguments and run the pruning pass.
    OPTS = parse_args(sys.argv[1:])
    main(OPTS)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Relative root directory for data files; other modules join this with
# file names, e.g. os.path.join(settings.DATA_DIR, ...).
DATA_DIR = 'data'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
SECS_PER_MIN = 60
SECS_PER_HOUR = SECS_PER_MIN * 60
SECS_PER_DAY = SECS_PER_HOUR * 24


def secs_to_str(secs):
    """Render a duration in seconds as a compact human-readable string.

    Uses the coarsest applicable format: 'DdHHhMMm' when days are present,
    'HhMMmSSs' for hours, 'MmSSs' for minutes, otherwise plain seconds with
    one decimal (>= 1s) or two decimals (< 1s).
    """
    remainder = secs
    days = int(remainder) // SECS_PER_DAY
    remainder -= days * SECS_PER_DAY
    hours = int(remainder) // SECS_PER_HOUR
    remainder -= hours * SECS_PER_HOUR
    mins = int(remainder) // SECS_PER_MIN
    remainder -= mins * SECS_PER_MIN
    if days > 0:
        return '%dd%02dh%02dm' % (days, hours, mins)
    if hours > 0:
        return '%dh%02dm%02ds' % (hours, mins, int(remainder))
    if mins > 0:
        return '%dm%02ds' % (mins, int(remainder))
    fmt = '%.1fs' if remainder >= 1 else '%.2fs'
    return fmt % remainder
|
||
def get_prf(tp, fp, fn, get_str=False):
    """Get precision, recall, f1 from true pos, false pos, false neg.

    Zero denominators yield 0 rather than raising. When get_str is True,
    returns a formatted multi-line percentage report instead of the tuple.
    """
    precision = float(tp) / (tp + fp) if tp + fp else 0
    recall = float(tp) / (tp + fn) if tp + fn else 0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0
    if not get_str:
        return precision, recall, f1
    return '\n'.join([
        'Precision: %.2f%%' % (100.0 * precision),
        'Recall : %.2f%%' % (100.0 * recall),
        'F1 : %.2f%%' % (100.0 * f1)])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import collections | ||
UNK_TOKEN = '<UNK>'
UNK_INDEX = 0


class Vocabulary(object):
    """Mutable word <-> index mapping with an UNK count threshold.

    Words whose cumulative count stays at or below unk_threshold are not
    assigned an index and map to UNK_INDEX on lookup.
    """

    def __init__(self, unk_threshold=0):
        """
        :param unk_threshold: words with <= this many counts will be
            considered <UNK>.
        """
        self.unk_threshold = unk_threshold
        self.counts = collections.Counter()
        self.word2index = {UNK_TOKEN: UNK_INDEX}
        self.word_list = [UNK_TOKEN]

    def add_word(self, word, count=1):
        """Add a word (may still map to UNK if it doesn't pass unk_threshold).

        :param word: the word to add.
        :param count: how many occurrences to credit it with.
        """
        self.counts[word] += count
        if word not in self.word2index and self.counts[word] > self.unk_threshold:
            index = len(self.word_list)
            self.word2index[word] = index
            self.word_list.append(word)

    def add_words(self, words):
        """Add each word in an iterable, one count apiece.

        Bug fix: this was a second ``add_word`` definition that shadowed the
        single-word version above, breaking every ``add_word(word, count=...)``
        call (including add_word_hard); renamed to add_words.
        """
        for w in words:
            self.add_word(w)

    def add_sentence(self, sentence):
        """Add every space-separated token of a sentence."""
        self.add_words(sentence.split(' '))

    def add_sentences(self, sentences):
        """Add every sentence in an iterable.

        Bug fix: the original recursed into add_sentences(s) for each
        sentence (infinite recursion); it should add each one via
        add_sentence.
        """
        for s in sentences:
            self.add_sentence(s)

    def add_word_hard(self, word):
        """Add word, making sure it is not UNK.

        Credits the word with enough counts to clear unk_threshold.
        """
        self.add_word(word, count=(self.unk_threshold + 1))

    def get_word(self, index):
        """Return the word stored at the given index."""
        return self.word_list[index]

    def get_index(self, word):
        """Return the word's index, or UNK_INDEX if out of vocabulary."""
        if word in self.word2index:
            return self.word2index[word]
        return UNK_INDEX

    def indexify_sentence(self, sentence):
        """Map a space-separated sentence to a list of indices."""
        return [self.get_index(w) for w in sentence.split(' ')]

    def indexify_list(self, elems):
        """Map a list of words to a list of indices."""
        return [self.get_index(w) for w in elems]

    def recover_sentence(self, indices):
        """Map a list of indices back to a space-joined sentence."""
        return ' '.join(self.get_word(i) for i in indices)

    # Backward-compatible alias for the original (misspelled) method name.
    recover_sentenc = recover_sentence

    def has_word(self, word):
        """Return True if the word has its own (non-UNK) index."""
        return word in self.word2index

    def __contains__(self, word):
        return word in self.word2index

    def size(self):
        """Return the vocabulary size, including the UNK entry."""
        return len(self.word2index)

    def __len__(self):
        return self.size()

    def __iter__(self):
        return iter(self.word_list)

    def save(self, filename):
        """Save the word list, one word per line (counts are not saved)."""
        with open(filename, 'w') as f:
            for w in self.word_list:
                print(w, file=f)

    @classmethod
    def load(cls, filename):
        """Load a word list saved by save() (does not load counts)."""
        vocab = cls()
        with open(filename) as f:
            for line in f:
                w = line.strip('\n')
                vocab.add_word_hard(w)
        return vocab