diff --git a/.idea/GNNs.iml b/.idea/GNNs.iml index 2cbb52b..7f807f0 100644 --- a/.idea/GNNs.iml +++ b/.idea/GNNs.iml @@ -2,7 +2,7 @@ - + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index b534b89..5417b68 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ - + \ No newline at end of file diff --git a/Captum_BERT.py b/Captum_BERT.py index 6c8c231..c495296 100644 --- a/Captum_BERT.py +++ b/Captum_BERT.py @@ -15,143 +15,143 @@ from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients, InterpretableEmbeddingBase from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer -from transformers import AutoTokenizer, AutoModelForMaskedLM -tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased') -model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased') +# from transformers import AutoTokenizer, AutoModelForMaskedLM +# tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased') +# model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased') # print(torch.__version__) # hx_pc -> 1.6.0 + cu101 # tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") # # model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased") -# device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -# -# model_path: str = 'saved_models/' -# # load model -# model: nn.Module = BertForQuestionAnswering.from_pretrained(model_path) -# model.to(device) -# model.eval() -# model.zero_grad() -# -# # load tokenizer -# tokenizer: PreTrainedTokenizerBase = BertTokenizer.from_pretrained(model_path) -# -# def predict(inputs: list, token_type_ids: list=None, position_ids: list=None, attention_mask: any=None) -> nn.Module: -# return model(inputs, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask) -# -# def squad_pos_forward_func(inputs:list, token_type_ids:list=None, position_ids:list=None, attention_mask: any=None,\ -# position:int=0) -> torch.Tensor: -# pred: torch.Tensor = predict(inputs, -# token_type_ids=token_type_ids, -# position_ids=position_ids, -# attention_mask=attention_mask) -# pred = pred[position] -# return pred.max(1).values -# -# # Optional[int] -# ref_token_id = tokenizer.pad_token_id -# -# # Optional[int] -# sep_token_id = tokenizer.sep_token_id -# -# # Optional[int] -# cls_token_id = tokenizer.cls_token_id -# -# interpretable_embedding: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \ -# 'bert.embeddings') -# interpretable_embedding1: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \ -# 'bert.embeddings.word_embeddings') -# interpretable_embedding2: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \ -# 'bert.embeddings.token_type_embeddings') -# interpretable_embedding3: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \ -# 'bert.embeddings.position_embeddings') -# -# -# -# -# def construct_input_ref_pair(question: str, text: str, ref_token_id: int | str, sep_token_id: int | str, \ -# cls_token_id: int | str) \ -# -> (torch.Tensor, torch.Tensor, int): -# question_ids: list = tokenizer.encode(question, add_special_tokens=False) -# text_ids: list = tokenizer.encode(text, add_special_tokens=False) -# -# input_ids: list = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id] 
-# -# ref_input_ids: list = [cls_token_id] + [ref_token_id] + len(question_ids) + [sep_token_id] + \ -# [ref_token_id] * len(text_ids) + [sep_token_id] -# -# return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(question_ids) -# -# def construct_input_ref_token_type_pair(input_ids: torch.Tensor, sep_ind:int = 0) -> (torch.Tensor, torch.Tensor): -# seq_len: int = input_ids.size(1) -# token_type_ids: torch.Tensor = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device) -# ref_token_type_ids: torch.Tensor = torch.zeros_like(token_type_ids, device=device) # * -1 -# return token_type_ids, ref_token_type_ids -# -# def construct_input_ref_pos_id_pair(input_ids: torch.Tensor) -> (torch.Tensor, torch.Tensor): -# seq_length: int = input_ids.size(1) -# position_ids: torch.Tensor = torch.arange(seq_length, dtype=torch.long, device=device) -# ref_position_ids: torch.Tensor = torch.zeros(seq_length, dtype=torch.long, device=device) -# -# position_ids = position_ids.unsqueeze(0).expand_as(input_ids) -# ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids) -# -# return position_ids, ref_position_ids -# -# def construct_attention_mask(input_ids: torch.Tensor) -> torch.Tensor: -# return torch.ones_like(input_ids) -# -# def construct_bert_sub_embedding(input_ids: any, ref_input_ids: any, \ -# torken_type_ids: any, ref_token_type_ids: any, \ -# position_ids: any, ref_position_ids: any) \ -# -> ((torch.Tensor, torch.Tensor),(torch.Tensor, torch.Tensor),(torch.Tensor, torch.Tensor)): -# input_embeddings: torch.Tensor = interpretable_embedding1.indices_to_embeddings(input_ids) -# ref_input_embeddings: torch.Tensor = interpretable_embedding1.indices_to_embeddings(ref_input_ids) -# -# input_embeddings_token_type: torch.Tensor = interpretable_embedding2.indices_to_embeddings(torken_type_ids) -# ref_input_embeddings_token_type: torch.Tensor = interpretable_embedding2.indices_to_embeddings(ref_token_type_ids) -# -# input_embeddings_position_ids: torch.Tensor = interpretable_embedding3.indices_to_embeddings(position_ids) -# ref_input_embeddings_position_ids: torch.Tensor = interpretable_embedding3.indices_to_embeddings(ref_position_ids) -# return (input_embeddings, ref_input_embeddings), (input_embeddings_token_type, ref_input_embeddings_token_type), \ -# (input_embeddings_position_ids, ref_input_embeddings_position_ids) -# -# def construct_whole_bert_embeddings(input_ids: any, ref_input_ids: any, \ -# token_type_ids: any=None, ref_token_type_ids: any=None, \ -# position_ids: any=None, ref_position_ids:any=None)\ -# -> (torch.Tensor, torch.Tensor): -# input_embeddings: torch.Tensor = interpretable_embedding.indices_to_embeddings(input_ids, \ -# token_type_ids=token_type_ids, \ -# position_ids=position_ids) -# ref_input_embeddings: torch.Tensor = interpretable_embedding.indices_to_embeddings(ref_input_ids, \ -# token_type_ids=token_type_ids,\ -# position_ids=position_ids) -# return input_embeddings, ref_input_embeddings -# -# question, text = "What is important to us?", "It is important to us to include, empower and support humans of all kinds." 
-#
-# input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id)
-# token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
-# position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
-# attention_mask = construct_attention_mask(input_ids)
-#
-# indices = input_ids[0].detach().tolist()
-# all_tokens = tokenizer.convert_ids_to_tokens(indices)
-#
-# ground_truth = 'to include, empower and support humans of all kinds'
-#
-# ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False)
-# ground_truth_end_ind = indices.index(ground_truth_tokens[-1])
-# ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1
-#
-# start_scores, end_scores = predict(input_ids, \
-#                                    token_type_ids=token_type_ids, \
-#                                    position_ids=position_ids, \
-#                                    attention_mask=attention_mask)
-#
-#
-# print('Question: ', question)
-# print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]))
-#
-#
-#
+device: torch.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+model_path: str = 'saved_models/'
+# load model
+model: nn.Module = BertForQuestionAnswering.from_pretrained(model_path)
+model.to(device)
+model.eval()
+model.zero_grad()
+
+# load tokenizer
+tokenizer: PreTrainedTokenizerBase = BertTokenizer.from_pretrained(model_path)
+
+def predict(inputs: list, token_type_ids: list=None, position_ids: list=None, attention_mask: any=None) -> nn.Module:
+    return model(inputs, token_type_ids=token_type_ids, position_ids=position_ids, attention_mask=attention_mask)
+
+def squad_pos_forward_func(inputs: list, token_type_ids: list=None, position_ids: list=None, attention_mask: any=None, \
+                           position: int=0) -> torch.Tensor:
+    pred: torch.Tensor = predict(inputs,
+                                 token_type_ids=token_type_ids,
+                                 position_ids=position_ids,
+                                 attention_mask=attention_mask)
+    pred = pred[position]
+    return pred.max(1).values
+
+# Optional[int]
+ref_token_id = tokenizer.pad_token_id
+
+# Optional[int]
+sep_token_id = tokenizer.sep_token_id
+
+# Optional[int]
+cls_token_id = tokenizer.cls_token_id
+
+interpretable_embedding: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \
+    'bert.embeddings')
+interpretable_embedding1: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \
+    'bert.embeddings.word_embeddings')
+interpretable_embedding2: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \
+    'bert.embeddings.token_type_embeddings')
+interpretable_embedding3: InterpretableEmbeddingBase = configure_interpretable_embedding_layer(model, \
+    'bert.embeddings.position_embeddings')
+
+
+
+
+def construct_input_ref_pair(question: str, text: str, ref_token_id: int | str, sep_token_id: int | str, \
+                             cls_token_id: int | str) \
+        -> (torch.Tensor, torch.Tensor, int):
+    question_ids: list = tokenizer.encode(question, add_special_tokens=False)
+    text_ids: list = tokenizer.encode(text, add_special_tokens=False)
+
+    input_ids: list = [cls_token_id] + question_ids + [sep_token_id] + text_ids + [sep_token_id]
+
+    ref_input_ids: list = [cls_token_id] + [ref_token_id] * len(question_ids) + [sep_token_id] + \
+                          [ref_token_id] * len(text_ids) + [sep_token_id]
+
+    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(question_ids)
+
+def construct_input_ref_token_type_pair(input_ids: torch.Tensor, sep_ind: int = 0) -> (torch.Tensor, torch.Tensor):
+    seq_len: int = input_ids.size(1)
+    token_type_ids: torch.Tensor = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
+    ref_token_type_ids: torch.Tensor = torch.zeros_like(token_type_ids, device=device)  # * -1
+    return token_type_ids, ref_token_type_ids
+
+def construct_input_ref_pos_id_pair(input_ids: torch.Tensor) -> (torch.Tensor, torch.Tensor):
+    seq_length: int = input_ids.size(1)
+    position_ids: torch.Tensor = torch.arange(seq_length, dtype=torch.long, device=device)
+    ref_position_ids: torch.Tensor = torch.zeros(seq_length, dtype=torch.long, device=device)
+
+    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
+
+    return position_ids, ref_position_ids
+
+def construct_attention_mask(input_ids: torch.Tensor) -> torch.Tensor:
+    return torch.ones_like(input_ids)
+
+def construct_bert_sub_embedding(input_ids: any, ref_input_ids: any, \
+                                 token_type_ids: any, ref_token_type_ids: any, \
+                                 position_ids: any, ref_position_ids: any) \
+        -> ((torch.Tensor, torch.Tensor), (torch.Tensor, torch.Tensor), (torch.Tensor, torch.Tensor)):
+    input_embeddings: torch.Tensor = interpretable_embedding1.indices_to_embeddings(input_ids)
+    ref_input_embeddings: torch.Tensor = interpretable_embedding1.indices_to_embeddings(ref_input_ids)
+
+    input_embeddings_token_type: torch.Tensor = interpretable_embedding2.indices_to_embeddings(token_type_ids)
+    ref_input_embeddings_token_type: torch.Tensor = interpretable_embedding2.indices_to_embeddings(ref_token_type_ids)
+
+    input_embeddings_position_ids: torch.Tensor = interpretable_embedding3.indices_to_embeddings(position_ids)
+    ref_input_embeddings_position_ids: torch.Tensor = interpretable_embedding3.indices_to_embeddings(ref_position_ids)
+    return (input_embeddings, ref_input_embeddings), (input_embeddings_token_type, ref_input_embeddings_token_type), \
+           (input_embeddings_position_ids, ref_input_embeddings_position_ids)
+
+def construct_whole_bert_embeddings(input_ids: any, ref_input_ids: any, \
+                                    token_type_ids: any=None, ref_token_type_ids: any=None, \
+                                    position_ids: any=None, ref_position_ids: any=None) \
+        -> (torch.Tensor, torch.Tensor):
+    input_embeddings: torch.Tensor = interpretable_embedding.indices_to_embeddings(input_ids, \
+                                                                                   token_type_ids=token_type_ids, \
+                                                                                   position_ids=position_ids)
+    ref_input_embeddings: torch.Tensor = interpretable_embedding.indices_to_embeddings(ref_input_ids, \
+                                                                                       token_type_ids=token_type_ids, \
+                                                                                       position_ids=position_ids)
+    return input_embeddings, ref_input_embeddings
+
+question, text = "What is important to us?", "It is important to us to include, empower and support humans of all kinds."
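+# Build the model inputs for the example question/text below, plus a [PAD]-based reference
+# (baseline) pair that the Captum attribution methods imported above can be given as baselines.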
+ +input_ids, ref_input_ids, sep_id = construct_input_ref_pair(question, text, ref_token_id, sep_token_id, cls_token_id) +token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id) +position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids) +attention_mask = construct_attention_mask(input_ids) + +indices = input_ids[0].detach().tolist() +all_tokens = tokenizer.convert_ids_to_tokens(indices) + +ground_truth = 'to include, empower and support humans of all kinds' + +ground_truth_tokens = tokenizer.encode(ground_truth, add_special_tokens=False) +ground_truth_end_ind = indices.index(ground_truth_tokens[-1]) +ground_truth_start_ind = ground_truth_end_ind - len(ground_truth_tokens) + 1 + +start_scores, end_scores = predict(input_ids, \ + token_type_ids=token_type_ids, \ + position_ids=position_ids, \ + attention_mask=attention_mask) + + +print('Question: ', question) +print('Predicted Answer: ', ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])) + + + diff --git a/NAACL/__init__.py b/NAACL/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/NAACL/backoffnet.py b/NAACL/backoffnet.py new file mode 100644 index 0000000..18d16f2 --- /dev/null +++ b/NAACL/backoffnet.py @@ -0,0 +1,26 @@ +import argparse +import collections +import glob +import json +import math +import numpy as np +import random +from ordered_set import OrderedSet +import os +import pickle +import shutil +from sklearn.metrics import average_precision_score +import sys +import termcolor +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.utils.rnn as rnn +import torch.optim as optim +from tqdm import tqdm + +from NAACL import vocabulary +from NAACL import settings +from NAACL import util + diff --git a/NAACL/ensemble.py b/NAACL/ensemble.py new file mode 100644 index 0000000..89e3133 --- /dev/null +++ b/NAACL/ensemble.py @@ -0,0 +1,86 @@ +'''Ensemble some predictions. 
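+
+Each input file holds tab-separated predictions (index, pmid, drug, gene, variant, probability).
+Scores for the same (pmid, drug, gene, variant) tuple are pooled across files with the chosen mode
+and written to --out-file. Illustrative invocation (file names are placeholders):
+
+    python -m NAACL.ensemble noisy_or preds_run1.tsv preds_run2.tsv -o preds_ensembled.tsv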
+'''
+import argparse
+import collections
+import math
+from scipy.special import logsumexp
+import sys
+
+MODES = ['mean', 'max', 'logsumexp', 'noisy_or', 'log_noisy_or', 'odds_ratio']
+
+def parse_args(args):
+    parser = argparse.ArgumentParser()
+    parser.add_argument('mode', choices=MODES)
+    parser.add_argument('files', nargs='+')
+    parser.add_argument('--weights', '-w', type=lambda x: [float(t) for t in x.split(',')],
+                        help='Comma-separated list of multipliers, one per file')
+    parser.add_argument('--out-file', '-o', default=None, help='Where to write all output')
+
+    if len(sys.argv) == 1:
+        parser.print_help()
+        sys.exit(1)
+    return parser.parse_args(args)
+
+def read_preds(fn):
+    preds = []
+    with open(fn) as f:
+        for line in f:
+            idx, pmid, drug, gene, variant, prob = line.strip().split('\t')
+            prob = float(prob)
+            preds.append((pmid, drug, gene, variant, prob))
+
+    return preds
+
+def main(OPTS):
+    preds_all = [read_preds(fn) for fn in OPTS.files]
+    groups = collections.defaultdict(list)
+    for i, preds in enumerate(preds_all):
+        if OPTS.weights:
+            weight = OPTS.weights[i]
+        else:
+            weight = 1.0
+        for pmid, drug, gene, variant, prob in preds:
+            groups[(pmid, drug, gene, variant)].append(weight * prob)
+
+    results = []
+    for i, ((pmid, drug, gene, variant), prob_list) in enumerate(groups.items()):
+        if OPTS.mode == 'mean':
+            prob = sum(prob_list) / len(prob_list)
+        elif OPTS.mode == 'max':
+            prob = max(prob_list)
+        elif OPTS.mode == 'logsumexp':
+            prob = logsumexp(prob_list)
+        elif OPTS.mode == 'noisy_or':
+            prob_no_rel = 1.0
+            for p in prob_list:
+                prob_no_rel *= 1.0 - p
+            prob = 1.0 - prob_no_rel
+        elif OPTS.mode == 'log_noisy_or':
+            log_prob_no_rel = 0.0
+            for p in prob_list:
+                if p < 1.0:
+                    log_prob_no_rel += math.log(1.0 - p)
+                else:
+                    log_prob_no_rel -= 1000000
+            prob = -log_prob_no_rel
+        elif OPTS.mode == 'odds_ratio':
+            cur_log_odds = 0.0
+            for p in prob_list:
+                cur_log_odds += 10 + 0.001 * p  # math.log(p / (1.0 - p) * 100000000)
+            prob = cur_log_odds
+        else:
+            raise ValueError(OPTS.mode)
+        results.append((i, pmid, drug, gene, variant, prob))
+
+    with open(OPTS.out_file, 'w') as f:
+        for item in results:
+            f.write('{}\t{}\t{}\t{}\t{}\t{}\n'.format(*item))
+
+if __name__ == '__main__':
+    OPTS = parse_args(sys.argv[1:])
+    main(OPTS)
+
+
+
+
+
+
diff --git a/NAACL/prune_pred_gv_map.py b/NAACL/prune_pred_gv_map.py
new file mode 100644
index 0000000..5b15f6c
--- /dev/null
+++ b/NAACL/prune_pred_gv_map.py
@@ -0,0 +1,51 @@
+"""Prune model predictions with rule-based G-V linker."""
+import argparse
+import collections
+import os
+import sys
+
+from NAACL import settings
+
+OPTS = None
+
+GV_MAP_FILE = 'gene_var/gene_to_var.tsv'
+
+def prep_gv_mapping():
+    var_to_gene = {}
+    gene_to_var = collections.defaultdict(set)
+    pmid_to_gv = collections.defaultdict(set)
+    pmid_gv_map = {}
+    with open(os.path.join(settings.DATA_DIR, GV_MAP_FILE)) as f:
+        for line in f:
+            pmid, variant, gene = line.strip().split('\t')
+            gene = gene.lower()
+            var_to_gene[(pmid, variant)] = gene
+            gene_to_var[(pmid, gene)].add(variant)
+            pmid_to_gv[pmid].add((gene, variant))
+
+    return var_to_gene, gene_to_var, pmid_to_gv
+
+def parse_args(args):
+    parser = argparse.ArgumentParser()
+    parser.add_argument('pred_file')
+    parser.add_argument('out_file')
+    if len(args) == 0:
+        parser.print_help()
+        sys.exit(1)
+    return parser.parse_args(args)
+
+def main(OPTS):
+    var_to_gene, gene_to_var, pmid_to_gv = prep_gv_mapping()
+    with open(OPTS.pred_file) as fin:
+        with open(OPTS.out_file, 'w') as fout:
+            for line in fin:
+                idx, pmid, d, g, v, prob = line.strip().split('\t')
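+                # Keep a prediction only if the rule-based linker maps this variant to the same gene.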
+                if (pmid, v) not in var_to_gene:
+                    continue
+                g_linked = var_to_gene[(pmid, v)]
+                if g_linked == g:
+                    fout.write(line)
+
+if __name__ == '__main__':
+    OPTS = parse_args(sys.argv[1:])
+    main(OPTS)
\ No newline at end of file
diff --git a/NAACL/settings.py b/NAACL/settings.py
new file mode 100644
index 0000000..54c9bba
--- /dev/null
+++ b/NAACL/settings.py
@@ -0,0 +1 @@
+DATA_DIR = 'data'
diff --git a/NAACL/util.py b/NAACL/util.py
new file mode 100644
index 0000000..a98022c
--- /dev/null
+++ b/NAACL/util.py
@@ -0,0 +1,41 @@
+SECS_PER_MIN = 60
+SECS_PER_HOUR = SECS_PER_MIN * 60
+SECS_PER_DAY = SECS_PER_HOUR * 24
+
+def secs_to_str(secs):
+    days = int(secs) // SECS_PER_DAY
+    secs -= days * SECS_PER_DAY
+    hours = int(secs) // SECS_PER_HOUR
+    secs -= hours * SECS_PER_HOUR
+    mins = int(secs) // SECS_PER_MIN
+    secs -= mins * SECS_PER_MIN
+    if days > 0:
+        return '%dd%02dh%02dm' % (days, hours, mins)
+    elif hours > 0:
+        return '%dh%02dm%02ds' % (hours, mins, int(secs))
+    elif mins > 0:
+        return '%dm%02ds' % (mins, int(secs))
+    elif secs >= 1:
+        return '%.1fs' % secs
+    return '%.2fs' % secs
+
+def get_prf(tp, fp, fn, get_str=False):
+    """Get precision, recall, f1 from true pos, false pos, false neg."""
+    if tp + fp == 0:
+        precision = 0
+    else:
+        precision = float(tp) / (tp + fp)
+    if tp + fn == 0:
+        recall = 0
+    else:
+        recall = float(tp) / (tp + fn)
+    if precision + recall == 0:
+        f1 = 0
+    else:
+        f1 = 2 * precision * recall / (precision + recall)
+    if get_str:
+        return '\n'.join([
+            'Precision: %.2f%%' % (100.0 * precision),
+            'Recall   : %.2f%%' % (100.0 * recall),
+            'F1       : %.2f%%' % (100.0 * f1)])
+    return precision, recall, f1
\ No newline at end of file
diff --git a/NAACL/vocabulary.py b/NAACL/vocabulary.py
new file mode 100644
index 0000000..1f493da
--- /dev/null
+++ b/NAACL/vocabulary.py
@@ -0,0 +1,93 @@
+import collections
+UNK_TOKEN = '<UNK>'
+UNK_INDEX = 0
+
+class Vocabulary(object):
+    def __init__(self, unk_threshold=0):
+        '''
+        :param unk_threshold: words with <= this many counts will be considered <UNK>.
+        '''
+        self.unk_threshold = unk_threshold
+        self.counts = collections.Counter()
+        self.word2index = {UNK_TOKEN: UNK_INDEX}
+        self.word_list = [UNK_TOKEN]
+
+    def add_word(self, word, count=1):
+        '''
+        Add a word (may still map to UNK if it doesn't pass unk_threshold).
+        :param word:
+        :param count:
+        :return:
+        '''
+        self.counts[word] += count
+        if word not in self.word2index and self.counts[word] > self.unk_threshold:
+            index = len(self.word_list)
+            self.word2index[word] = index
+            self.word_list.append(word)
+
+    def add_words(self, words):
+        for w in words:
+            self.add_word(w)
+
+    def add_sentence(self, sentence):
+        self.add_words(sentence.split(' '))
+
+    def add_sentences(self, sentences):
+        for s in sentences:
+            self.add_sentence(s)
+
+    def add_word_hard(self, word):
+        '''
+        Add word, make sure it is not UNK.
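+        (A word added this way is given count unk_threshold + 1, so it always clears the threshold.)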
+        :param word:
+        :return:
+        '''
+        self.add_word(word, count=(self.unk_threshold + 1))
+
+    def get_word(self, index):
+        return self.word_list[index]
+
+    def get_index(self, word):
+        if word in self.word2index:
+            return self.word2index[word]
+        return UNK_INDEX
+
+    def indexify_sentence(self, sentence):
+        return [self.get_index(w) for w in sentence.split(' ')]
+
+    def indexify_list(self, elems):
+        return [self.get_index(w) for w in elems]
+
+    def recover_sentence(self, indices):
+        return ' '.join(self.get_word(i) for i in indices)
+
+    def has_word(self, word):
+        return word in self.word2index
+
+    def __contains__(self, word):
+        return word in self.word2index
+
+    def size(self):
+        return len(self.word2index)
+
+    def __len__(self):
+        return self.size()
+
+    def __iter__(self):
+        return iter(self.word_list)
+
+    def save(self, filename):
+        '''Save word list.'''
+        with open(filename, 'w') as f:
+            for w in self.word_list:
+                print(w, file=f)
+
+    @classmethod
+    def load(cls, filename):
+        '''Load word list (does not load counts).'''
+        vocab = cls()
+        with open(filename) as f:
+            for line in f:
+                w = line.strip('\n')
+                vocab.add_word_hard(w)
+        return vocab
\ No newline at end of file