diff --git a/.idea/GNNs.iml b/.idea/GNNs.iml
index 2cbb52b..7f807f0 100644
--- a/.idea/GNNs.iml
+++ b/.idea/GNNs.iml
@@ -2,7 +2,7 @@
-
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
index b534b89..5417b68 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,4 +1,4 @@
-
+
\ No newline at end of file
diff --git a/Captum_BERT.py b/Captum_BERT.py
index 6c8c231..c495296 100644
--- a/Captum_BERT.py
+++ b/Captum_BERT.py
@@ -15,143 +15,143 @@
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients, InterpretableEmbeddingBase
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
-from transformers import AutoTokenizer, AutoModelForMaskedLM
-tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased')
-model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')
+# from transformers import AutoTokenizer, AutoModelForMaskedLM
+# tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased')
+# model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')
+import torch.nn as nn
+from transformers import BertForQuestionAnswering, BertTokenizer, PreTrainedTokenizerBase
# print(torch.__version__) # hx_pc -> 1.6.0 + cu101
# tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
#
# model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
-# device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-#
-# model_path: str = 'saved_models/'
-# # load model
-# model: nn.Module = BertForQuestionAnswering.from_pretrained(model_path)
-# model.to(device)
-# model.eval()
-# model.zero_grad()
-#
-# # load tokenizer
-# tokenizer: PreTrainedTokenizerBase = BertTokenizer.from_pretrained(model_path)
-#
-# def predict(inputs: list, token_type_ids: list=None, position_ids: list=None, attention_mask: any=None) -> nn.Module:
-# return model(inputs, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask)
-#
-# def squad_pos_forward_func(inputs:list, token_type_ids:list=None, position_ids:list=None, attention_mask: any=None,\
-# position:int=0) -> torch.Tensor:
-# pred: torch.Tensor = predict(inputs,
-# token_type_ids=token_type_ids,
-# position_ids=position_ids,
-# attention_mask=attention_mask)
-# pred = pred[position]
-# return pred.max(1).values
-#
-# # Optional[int]
-# ref_token_id = tokenizer.pad_token_id
-#
-# # Optional[int]
-# sep_token_id = tokenizer.sep_token_id
-#
-# # Optional[int]
-# cls_token_id = tokenizer.cls_token_id
-#
-# interpretable_embedding: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \
-# 'bert.embeddings')
-# interpretable_embedding1: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \
-# 'bert.embeddings.word_embeddings')
-# interpretable_embedding2: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \
-# 'bert.embeddings.token_type_embeddings')
-# interpretable_embedding3: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \
-# 'bert.embeddings.position_embeddings')
-#
-#
-#
-#
-# def construct_input_ref_pair(question: str, text: str, ref_token_id: int | str, sep_token_id: int | str, \
-# cls_token_id: int | str) \
-# -> (torch.Tensor, torch.Tensor, int):
-# question_ids: list = tokenizer.encode(question, add_special_tokens=False)
-# text_ids: list = tokenizer.encode(text, add_special_tokens=False)
-#
-# input_ids: list = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]
-#
-# ref_input_ids: list = [cls_token_id] + [ref_token_id] + len(question_ids) + [sep_token_id] + \
-# [ref_token_id] * len(text_ids) + [sep_token_id]
-#
-# return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(question_ids)
-#
-# def construct_input_ref_token_type_pair(input_ids: torch.Tensor, sep_ind:int = 0) -> (torch.Tensor, torch.Tensor):
-# seq_len: int = input_ids.size(1)
-# token_type_ids: torch.Tensor = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
-# ref_token_type_ids: torch.Tensor = torch.zeros_like(token_type_ids, device=device) # * -1
-# return token_type_ids, ref_token_type_ids
-#
-# def construct_input_ref_pos_id_pair(input_ids: torch.Tensor) -> (torch.Tensor, torch.Tensor):
-# seq_length: int = input_ids.size(1)
-# position_ids: torch.Tensor = torch.arange(seq_length, dtype=torch.long, device=device)
-# ref_position_ids: torch.Tensor = torch.zeros(seq_length, dtype=torch.long, device=device)
-#
-# position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
-# ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
-#
-# return position_ids, ref_position_ids
-#
-# def construct_attention_mask(input_ids: torch.Tensor) -> torch.Tensor:
-# return torch.ones_like(input_ids)
-#
-# def construct_bert_sub_embedding(input_ids: any, ref_input_ids: any, \
-# torken_type_ids: any, ref_token_type_ids: any, \
-# position_ids: any, ref_position_ids: any) \
-# -> ((torch.Tensor, torch.Tensor),(torch.Tensor, torch.Tensor),(torch.Tensor, torch.Tensor)):
-# input_embeddings: torch.Tensor = interpretable_embedding1.indices_to_embeddings(input_ids)
-# ref_input_embeddings: torch.Tensor = interpretable_embedding1.indices_to_embeddings(ref_input_ids)
-#
-# input_embeddings_token_type: torch.Tensor = interpretable_embedding2.indices_to_embeddings(torken_type_ids)
-# ref_input_embeddings_token_type: torch.Tensor = interpretable_embedding2.indices_to_embeddings(ref_token_type_ids)
-#
-# input_embeddings_position_ids: torch.Tensor = interpretable_embedding3.indices_to_embeddings(position_ids)
-# ref_input_embeddings_position_ids: torch.Tensor = interpretable_embedding3.indices_to_embeddings(ref_position_ids)
-# return (input_embeddings, ref_input_embeddings), (input_embeddings_token_type, ref_input_embeddings_token_type), \
-# (input_embeddings_position_ids, ref_input_embeddings_position_ids)
-#
-# def construct_whole_bert_embeddings(input_ids: any, ref_input_ids: any, \
-# token_type_ids: any=None, ref_token_type_ids: any=None, \
-# position_ids: any=None, ref_position_ids:any=None)\
-# -> (torch.Tensor, torch.Tensor):
-# input_embeddings: torch.Tensor = interpretable_embedding.indices_to_embeddings(input_ids, \
-# token_type_ids=token_type_ids, \
-# position_ids=position_ids)
-# ref_input_embeddings: torch.Tensor = interpretable_embedding.indices_to_embeddings(ref_input_ids, \
-# token_type_ids=token_type_ids,\
-# position_ids=position_ids)
-# return input_embeddings, ref_input_embeddings
-#
-# question, text = "What is important to us?", "It is important to us to include, empower and support humans of all kinds."
-#
-# input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
-# token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
-# position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
-# attention_mask = construct_attention_mask(input_ids)
-#
-# indices = input_ids[0].detach().tolist()
-# all_tokens = tokenizer.convert_ids_to_tokens(indices)
-#
-# ground_truth = 'to include, empower and support humans of all kinds'
-#
-# ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
-# ground_truth_end_ind = indices.index(ground_truth_tokens[-1])
-# ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1
-#
-# start_scores, end_scores = predict(input_ids, \
-# token_type_ids=token_type_ids, \
-# position_ids=position_ids, \
-# attention_mask=attention_mask)
-#
-#
-# print('Question: ', question)
-# print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))
-#
-#
-#
+device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+model_path: str = 'saved_models/'
+# load model
+model: nn.Module = BertForQuestionAnswering.from_pretrained(model_path)
+model.to(device)
+model.eval()
+model.zero_grad()
+
+# load tokenizer
+tokenizer: PreTrainedTokenizerBase = BertTokenizer.from_pretrained(model_path)
+
+def predict(inputs: torch.Tensor, token_type_ids: torch.Tensor = None, position_ids: torch.Tensor = None,
+            attention_mask: torch.Tensor = None):
+    # returns (start_scores, end_scores) from BertForQuestionAnswering
+ return model(inputs, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask)
+
+def squad_pos_forward_func(inputs:list, token_type_ids:list=None, position_ids:list=None, attention_mask: any=None,\
+ position:int=0) -> torch.Tensor:
+ pred: torch.Tensor = predict(inputs,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask)
+ pred = pred[position]
+ return pred.max(1).values
+
+# Optional[int]
+ref_token_id = tokenizer.pad_token_id
+
+# Optional[int]
+sep_token_id = tokenizer.sep_token_id
+
+# Optional[int]
+cls_token_id = tokenizer.cls_token_id
+
+interpretable_embedding: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \
+ 'bert.embeddings')
+interpretable_embedding1: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \
+ 'bert.embeddings.word_embeddings')
+interpretable_embedding2: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \
+ 'bert.embeddings.token_type_embeddings')
+interpretable_embedding3: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \
+ 'bert.embeddings.position_embeddings')
+
+
+
+
+def construct_input_ref_pair(question: str, text: str, ref_token_id: int, sep_token_id: int, \
+                             cls_token_id: int) \
+                             -> (torch.Tensor, torch.Tensor, int):
+ question_ids: list = tokenizer.encode(question, add_special_tokens=False)
+ text_ids: list = tokenizer.encode(text, add_special_tokens=False)
+
+ input_ids: list = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]
+
+    ref_input_ids: list = [cls_token_id] + [ref_token_id] * len(question_ids) + [sep_token_id] + \
+                          [ref_token_id] * len(text_ids) + [sep_token_id]
+
+ return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(question_ids)
+
+def construct_input_ref_token_type_pair(input_ids: torch.Tensor, sep_ind:int = 0) -> (torch.Tensor, torch.Tensor):
+ seq_len: int = input_ids.size(1)
+ token_type_ids: torch.Tensor = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
+ ref_token_type_ids: torch.Tensor = torch.zeros_like(token_type_ids, device=device) # * -1
+ return token_type_ids, ref_token_type_ids
+
+def construct_input_ref_pos_id_pair(input_ids: torch.Tensor) -> (torch.Tensor, torch.Tensor):
+ seq_length: int = input_ids.size(1)
+ position_ids: torch.Tensor = torch.arange(seq_length, dtype=torch.long, device=device)
+ ref_position_ids: torch.Tensor = torch.zeros(seq_length, dtype=torch.long, device=device)
+
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+ ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
+
+ return position_ids, ref_position_ids
+
+def construct_attention_mask(input_ids: torch.Tensor) -> torch.Tensor:
+ return torch.ones_like(input_ids)
+
+def construct_bert_sub_embedding(input_ids: any, ref_input_ids: any, \
+                                 token_type_ids: any, ref_token_type_ids: any, \
+ position_ids: any, ref_position_ids: any) \
+ -> ((torch.Tensor, torch.Tensor),(torch.Tensor, torch.Tensor),(torch.Tensor, torch.Tensor)):
+ input_embeddings: torch.Tensor = interpretable_embedding1.indices_to_embeddings(input_ids)
+ ref_input_embeddings: torch.Tensor = interpretable_embedding1.indices_to_embeddings(ref_input_ids)
+
+    input_embeddings_token_type: torch.Tensor = interpretable_embedding2.indices_to_embeddings(token_type_ids)
+ ref_input_embeddings_token_type: torch.Tensor = interpretable_embedding2.indices_to_embeddings(ref_token_type_ids)
+
+ input_embeddings_position_ids: torch.Tensor = interpretable_embedding3.indices_to_embeddings(position_ids)
+ ref_input_embeddings_position_ids: torch.Tensor = interpretable_embedding3.indices_to_embeddings(ref_position_ids)
+ return (input_embeddings, ref_input_embeddings), (input_embeddings_token_type, ref_input_embeddings_token_type), \
+ (input_embeddings_position_ids, ref_input_embeddings_position_ids)
+
+def construct_whole_bert_embeddings(input_ids: any, ref_input_ids: any, \
+ token_type_ids: any=None, ref_token_type_ids: any=None, \
+ position_ids: any=None, ref_position_ids:any=None)\
+ -> (torch.Tensor, torch.Tensor):
+ input_embeddings: torch.Tensor = interpretable_embedding.indices_to_embeddings(input_ids, \
+ token_type_ids=token_type_ids, \
+ position_ids=position_ids)
+ ref_input_embeddings: torch.Tensor = interpretable_embedding.indices_to_embeddings(ref_input_ids, \
+ token_type_ids=token_type_ids,\
+ position_ids=position_ids)
+ return input_embeddings, ref_input_embeddings
+
+question, text = "What is important to us?", "It is important to us to include, empower and support humans of all kinds."
+
+input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
+token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
+position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
+attention_mask = construct_attention_mask(input_ids)
+
+indices = input_ids[0].detach().tolist()
+all_tokens = tokenizer.convert_ids_to_tokens(indices)
+
+ground_truth = 'to include, empower and support humans of all kinds'
+
+ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
+ground_truth_end_ind = indices.index(ground_truth_tokens[-1])
+ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1
+
+start_scores, end_scores = predict(input_ids, \
+ token_type_ids=token_type_ids, \
+ position_ids=position_ids, \
+ attention_mask=attention_mask)
+
+
+print('Question: ', question)
+print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))
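+
+# Illustrative continuation (a sketch): attribute the start/end position scores to the
+# BERT embedding output with Integrated Gradients. This path assumes only the top-level
+# 'bert.embeddings' wrapper is used; the word/token-type/position wrappers above serve the
+# alternative per-embedding attribution via construct_bert_sub_embedding.
+input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(input_ids, ref_input_ids,
+                                                                         token_type_ids=token_type_ids,
+                                                                         ref_token_type_ids=ref_token_type_ids,
+                                                                         position_ids=position_ids,
+                                                                         ref_position_ids=ref_position_ids)
+ig = IntegratedGradients(squad_pos_forward_func)
+# position=0 attributes the start-position logits, position=1 the end-position logits
+attributions_start = ig.attribute(inputs=input_embeddings, baselines=ref_input_embeddings,
+                                  additional_forward_args=(token_type_ids, position_ids, attention_mask, 0))
+attributions_end = ig.attribute(inputs=input_embeddings, baselines=ref_input_embeddings,
+                                additional_forward_args=(token_type_ids, position_ids, attention_mask, 1))
+# collapse the hidden dimension to get one attribution score per token
+attributions_start_sum = attributions_start.sum(dim=-1).squeeze(0)
+attributions_end_sum = attributions_end.sum(dim=-1).squeeze(0)
+# restore the original embedding layer once attribution is done
+remove_interpretable_embedding_layer(model, interpretable_embedding)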
+
+
+
diff --git a/NAACL/__init__.py b/NAACL/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/NAACL/backoffnet.py b/NAACL/backoffnet.py
new file mode 100644
index 0000000..18d16f2
--- /dev/null
+++ b/NAACL/backoffnet.py
@@ -0,0 +1,26 @@
+import argparse
+import collections
+import glob
+import json
+import math
+import numpy as np
+import random
+from ordered_set import OrderedSet
+import os
+import pickle
+import shutil
+from sklearn.metrics import average_precision_score
+import sys
+import termcolor
+import time
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.utils.rnn as rnn
+import torch.optim as optim
+from tqdm import tqdm
+
+from NAACL import vocabulary
+from NAACL import settings
+from NAACL import util
+
diff --git a/NAACL/ensemble.py b/NAACL/ensemble.py
new file mode 100644
index 0000000..89e3133
--- /dev/null
+++ b/NAACL/ensemble.py
@@ -0,0 +1,86 @@
+'''Ensemble some predictions. '''
+import argparse
+import collections
+import math
+from scipy.special import logsumexp
+import sys
+
+MODES = ['mean', 'max', 'logsumexp', 'noisy_or', 'log_noisy_or', 'odds_ratio']
+
+def parse_args(args):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('mode', choices=MODES)
+ parser.add_argument('files', nargs='+')
+    parser.add_argument('--weights', '-w', type=lambda x: [float(t) for t in x.split(',')],
+                        help='Comma-separated list of multipliers, one per file')
+ parser.add_argument('--out-file', '-o', default=None, help='Where to write all output')
+
+    if len(args) == 0:
+ parser.print_help()
+ sys.exit(1)
+ return parser.parse_args(args)
+
+def read_preds(fn):
+ preds = []
+ with open(fn) as f:
+ for line in f:
+ idx, pmid, drug, gene, variant, prob = line.strip().split('\t')
+ prob = float(prob)
+ preds.append((pmid, drug, gene, variant, prob))
+
+ return preds
+
+def main(OPTS):
+ preds_all = [read_preds(fn) for fn in OPTS.files]
+ groups = collections.defaultdict(list)
+ for i, preds in enumerate(preds_all):
+ if OPTS.weights:
+ weight = OPTS.weights[i]
+ else:
+ weight = 1.0
+ for pmid, drug, gene, variant, prob in preds:
+ groups[(pmid, drug, gene, variant)].append(weight * prob)
+
+ results = []
+    for i, ((pmid, drug, gene, variant), prob_list) in enumerate(groups.items()):
+ if OPTS.mode == 'mean':
+ prob = sum(prob_list) / len(prob_list)
+ elif OPTS.mode == 'max':
+ prob = max(prob_list)
+ elif OPTS.mode == 'logsumexp':
+ prob = logsumexp(prob_list)
+ elif OPTS.mode == 'noisy_or':
+ prob_no_rel = 1.0
+ for p in prob_list:
+ prob_no_rel *= 1.0 - p
+            prob = 1.0 - prob_no_rel
+ elif OPTS.mode == 'log_noisy_or':
+ log_prob_no_rel = 0.0
+ for p in prob_list:
+ if p < 1.0:
+ log_prob_no_rel += math.log(1.0 - p)
+ else:
+ log_prob_no_rel -= 1000000
+ prob = -log_prob_no_rel
+ elif OPTS.mode == 'odds_ratio':
+ cur_log_odds = 0.0
+ for p in prob_list:
+ cur_log_odds += 10 + 0.001 * p #math.log(p / (1.0 - p) * 100000000)
+ prob = cur_log_odds
+ else:
+ raise ValueError(OPTS.mode)
+ results.append((i, pmid, drug, gene, variant, prob))
+
+ with open(OPTS.out_file, 'w') as f:
+ for item in results:
+ f.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(*item))
+
+if __name__ == '__main__':
+ OPTS = parse_args(sys.argv[1:])
+ main(OPTS)
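+
+# Illustrative usage (hypothetical file names):
+#   python -m NAACL.ensemble noisy_or preds_run1.tsv preds_run2.tsv -o preds_ensembled.tsv
+# Each input line is expected as: idx<TAB>pmid<TAB>drug<TAB>gene<TAB>variant<TAB>prob.
+# 'noisy_or' combines the per-file probabilities as 1 - prod(1 - p_i);
+# e.g. probabilities [0.6, 0.5] give 1 - 0.4 * 0.5 = 0.8.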
+
+
+
+
+
+
diff --git a/NAACL/prune_pred_gv_map.py b/NAACL/prune_pred_gv_map.py
new file mode 100644
index 0000000..5b15f6c
--- /dev/null
+++ b/NAACL/prune_pred_gv_map.py
@@ -0,0 +1,51 @@
+"""Prune model predictions with rule-based G-V linker."""
+import argparse
+import collections
+import os
+import sys
+
+from NAACL import settings
+
+OPTS = None
+
+GV_MAP_FILE = 'gene_var/gene_to_var.tsv'
+
+def prep_gv_mapping():
+    var_to_gene = {}
+    gene_to_var = collections.defaultdict(set)
+ pmid_to_gv = collections.defaultdict(set)
+ pmid_gv_map = {}
+ with open(os.path.join(settings.DATA_DIR, GV_MAP_FILE)) as f:
+ for line in f:
+            pmid, variant, gene = line.strip().split('\t')
+ gene = gene.lower()
+ var_to_gene[(pmid, variant)] = gene
+ gene_to_var[(pmid, gene)].add(variant)
+ pmid_to_gv[pmid].add((gene, variant))
+
+ return var_to_gene, gene_to_var, pmid_to_gv
+
+def parse_args(args):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('pred_file')
+ parser.add_argument('out_file')
+ if len(args) == 0:
+ parser.print_help()
+ sys.exit(1)
+ return parser.parse_args(args)
+
+def main(OPTS):
+ var_to_gene, gene_to_var, pmid_to_gv = prep_gv_mapping()
+ with open(OPTS.pred_file) as fin:
+        with open(OPTS.out_file, 'w') as fout:
+ for line in fin:
+ idx, pmid, d, g, v, prob = line.strip().split('\t')
+                if (pmid, v) not in var_to_gene:
+                    continue
+                g_linked = var_to_gene[(pmid, v)]
+                if g_linked == g:
+                    fout.write(line)
+
+if __name__ == '__main__':
+ OPTS = parse_args(sys.argv[1:])
+ main(OPTS)
\ No newline at end of file
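+
+# Illustrative usage (hypothetical file names):
+#   python -m NAACL.prune_pred_gv_map preds.tsv preds_pruned.tsv
+# data/gene_var/gene_to_var.tsv is expected to hold one pmid<TAB>variant<TAB>gene triple per line.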
diff --git a/NAACL/settings.py b/NAACL/settings.py
new file mode 100644
index 0000000..54c9bba
--- /dev/null
+++ b/NAACL/settings.py
@@ -0,0 +1 @@
+DATA_DIR = 'data'
diff --git a/NAACL/util.py b/NAACL/util.py
new file mode 100644
index 0000000..a98022c
--- /dev/null
+++ b/NAACL/util.py
@@ -0,0 +1,41 @@
+SECS_PER_MIN = 60
+SECS_PER_HOUR = SECS_PER_MIN * 60
+SECS_PER_DAY = SECS_PER_HOUR * 24
+
+def secs_to_str(secs):
+ days = int(secs) // SECS_PER_DAY
+ secs -= days * SECS_PER_DAY
+ hours = int(secs) // SECS_PER_HOUR
+ secs -= hours * SECS_PER_HOUR
+ mins = int(secs) // SECS_PER_MIN
+ secs -= mins * SECS_PER_MIN
+ if days > 0:
+ return '%dd%02dh%02dm' % (days, hours, mins)
+ elif hours > 0:
+ return '%dh%02dm%02ds' % (hours, mins, int(secs))
+ elif mins > 0:
+ return '%dm%02ds' % (mins, int(secs))
+ elif secs >= 1:
+ return '%.1fs' % secs
+ return '%.2fs' % secs
+
+def get_prf(tp, fp, fn, get_str=False):
+ """Get precision, recall, f1 from true pos, false pos, false neg."""
+ if tp + fp == 0:
+ precision = 0
+ else:
+ precision = float(tp) / (tp + fp)
+ if tp + fn == 0:
+ recall = 0
+ else:
+ recall = float(tp) / (tp + fn)
+ if precision + recall == 0:
+ f1 = 0
+ else:
+ f1 = 2 * precision * recall / (precision + recall)
+ if get_str:
+ return '\n'.join([
+ 'Precision: %.2f%%' % (100.0 * precision),
+ 'Recall : %.2f%%' % (100.0 * recall),
+ 'F1 : %.2f%%' % (100.0 * f1)])
+ return precision, recall, f1
\ No newline at end of file
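+
+# Worked examples (for reference):
+#   secs_to_str(3725)  -> '1h02m05s'
+#   get_prf(8, 2, 4)   -> precision 0.8, recall ~0.667, F1 ~0.727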
diff --git a/NAACL/vocabulary.py b/NAACL/vocabulary.py
new file mode 100644
index 0000000..1f493da
--- /dev/null
+++ b/NAACL/vocabulary.py
@@ -0,0 +1,93 @@
+import collections
+
+UNK_TOKEN = '<UNK>'
+UNK_INDEX = 0
+
+class Vocabulary(object):
+ def __init__(self, unk_threshold=0):
+        '''
+        :param unk_threshold: words with <= this many counts will be considered <UNK>.
+        '''
+ self.unk_threshold = unk_threshold
+ self.counts = collections.Counter()
+ self.word2index = {UNK_TOKEN: UNK_INDEX}
+ self.word_list = [UNK_TOKEN]
+
+ def add_word(self, word, count=1):
+ '''
+ Add a word (may still map to UNK if it doesn't pass unk_threshold).
+ :param word:
+ :param count:
+ :return:
+ '''
+ self.counts[word] += count
+ if word not in self.word2index and self.counts[word] > self.unk_threshold:
+ index = len(self.word_list)
+ self.word2index[word] = index
+ self.word_list.append(word)
+
+    def add_words(self, words):
+ for w in words:
+ self.add_word(w)
+
+ def add_sentence(self, sentence):
+        self.add_words(sentence.split(' '))
+
+ def add_sentences(self, sentences):
+ for s in sentences:
+            self.add_sentence(s)
+
+ def add_word_hard(self, word):
+ '''
+ Add word, make sure it is not UNK.
+ :param word:
+ :return:
+ '''
+ self.add_word(word, count=(self.unk_threshold+1))
+
+ def get_word(self, index):
+ return self.word_list[index]
+
+ def get_index(self, word):
+ if word in self.word2index:
+ return self.word2index[word]
+ return UNK_INDEX
+
+ def indexify_sentence(self, sentence):
+ return [self.get_index(w) for w in sentence.split(' ')]
+
+ def indexify_list(self, elems):
+ return [self.get_index(w) for w in elems]
+
+    def recover_sentence(self, indices):
+ return ' '.join(self.get_word(i) for i in indices)
+
+ def has_word(self, word):
+ return word in self.word2index
+
+ def __contains__(self, word):
+ return word in self.word2index
+
+ def size(self):
+ return len(self.word2index)
+
+ def __len__(self):
+ return self.size()
+
+    def __iter__(self):
+ return iter(self.word_list)
+
+ def save(self, filename):
+ '''Save word list.'''
+ with open(filename, 'w') as f:
+ for w in self.word_list:
+ print(w, file=f)
+
+ @classmethod
+ def load(cls, filename):
+ '''Load word list (does not load counts).'''
+ vocab = cls()
+ with open(filename) as f:
+ for line in f:
+ w = line.strip('\n')
+ vocab.add_word_hard(w)
+ return vocab
\ No newline at end of file
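+
+# Minimal usage sketch (illustrative only):
+if __name__ == '__main__':
+    vocab = Vocabulary(unk_threshold=1)
+    vocab.add_sentence('the drug inhibits the gene')
+    vocab.add_sentence('the gene is mutated')
+    # words seen more than unk_threshold times get their own index; the rest map to UNK_INDEX
+    print(vocab.indexify_sentence('the drug is mutated'))
+    print(vocab.recover_sentence(vocab.indexify_sentence('the gene')))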