Update embeddings_to_torch.py to be more memory efficient by only loading vectors which are present in the vocab into memory.
DNGros committed Mar 28, 2019
1 parent 6bc8efe commit 57cefb7
Showing 2 changed files with 76 additions and 33 deletions.
2 changes: 1 addition & 1 deletion docs/source/FAQ.md
@@ -43,7 +43,7 @@ python preprocess.py \
 3) prepare embeddings:
 
 ```
-./tools/embeddings_to_torch.py -emb_file "glove_dir/glove.6B.100d.txt" \
+./tools/embeddings_to_torch.py -emb_file_both "glove_dir/glove.6B.100d.txt" \
 -dict_file "data/data.vocab.pt" \
 -output_file "data/embeddings"
 ```
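With this commit the FAQ example switches from the removed `-emb_file` flag to `-emb_file_both`. Separate encoder and decoder files remain supported through the now-optional flags; a hypothetical invocation that reuses the same GloVe file for both sides:

```
./tools/embeddings_to_torch.py -emb_file_enc "glove_dir/glove.6B.100d.txt" \
-emb_file_dec "glove_dir/glove.6B.100d.txt" \
-dict_file "data/data.vocab.pt" \
-output_file "data/embeddings"
```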
107 changes: 75 additions & 32 deletions tools/embeddings_to_torch.py
@@ -32,8 +32,9 @@ def get_vocabs(dict_path):
     return enc_vocab, dec_vocab
 
 
-def read_embeddings(file_enc, skip_lines=0):
+def read_embeddings(file_enc, skip_lines=0, filter_set=None):
     embs = dict()
+    total_vectors_in_file = 0
     with open(file_enc, 'rb') as f:
         for i, line in enumerate(f):
             if i < skip_lines:
@@ -47,14 +48,16 @@ def read_embeddings(file_enc, skip_lines=0):
             l_split = line.decode('utf8').strip().split(' ')
             if len(l_split) == 2:
                 continue
+            total_vectors_in_file += 1
+            if filter_set is not None and l_split[0] not in filter_set:
+                continue
             embs[l_split[0]] = [float(em) for em in l_split[1:]]
-    return embs
+    return embs, total_vectors_in_file
 
 
 def match_embeddings(vocab, emb, opt):
     dim = len(six.next(six.itervalues(emb)))
     filtered_embeddings = np.zeros((len(vocab), dim))
     count = {"match": 0, "miss": 0}
     for w, w_id in vocab.stoi.items():
         if w in emb:
             filtered_embeddings[w_id] = emb[w]
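Taken together, the `read_embeddings` changes above count every vector in the file but only keep in-vocab words, so memory stays proportional to the vocab rather than to the embedding file. A minimal sketch of that behaviour, assuming the repository root is on `sys.path` so the script is importable as a module:

```
import tempfile

from tools.embeddings_to_torch import read_embeddings  # assumes repo root on sys.path

# Toy GloVe-style file: three 2-d vectors, whitespace-separated.
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("the 0.1 0.2\ncat 0.3 0.4\nxyzzy 0.5 0.6\n")
    path = f.name

# "xyzzy" is counted toward the total but never stored in the dict.
embs, total = read_embeddings(path, skip_lines=0, filter_set={"the", "cat"})
assert total == 3 and sorted(embs) == ["cat", "the"]
```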
@@ -67,12 +70,30 @@ def match_embeddings(vocab, emb, opt):
     return torch.Tensor(filtered_embeddings), count
 
 
+def convert_to_torch_tensor(word_to_float_list_dict, vocab):
+    dim = len(six.next(six.itervalues(word_to_float_list_dict)))
+    tensor = torch.zeros((len(vocab), dim))
+    for word, values in word_to_float_list_dict.items():
+        tensor[vocab.stoi[word]] = torch.Tensor(values)
+    return tensor
+
+
+def calc_vocab_load_stats(vocab, loaded_embed_dict):
+    matching_count = len(
+        set(vocab.stoi.keys()) & set(loaded_embed_dict.keys()))
+    missing_count = len(vocab) - matching_count
+    percent_matching = matching_count / len(vocab) * 100
+    return matching_count, missing_count, percent_matching
+
+
 def main():
 
     parser = argparse.ArgumentParser(description='embeddings_to_torch.py')
-    parser.add_argument('-emb_file_enc', required=True,
+    parser.add_argument('-emb_file_both', required=False,
+                        help="loads Embeddings for both source and target "
+                             "from this file.")
+    parser.add_argument('-emb_file_enc', required=False,
                         help="source Embeddings from this file")
-    parser.add_argument('-emb_file_dec', required=True,
+    parser.add_argument('-emb_file_dec', required=False,
                         help="target Embeddings from this file")
     parser.add_argument('-output_file', required=True,
                         help="Output file for the prepared data")
@@ -87,41 +108,63 @@ def main():
 
     enc_vocab, dec_vocab = get_vocabs(opt.dict_file)
 
     # Read in embeddings
     skip_lines = 1 if opt.type == "word2vec" else opt.skip_lines
-    src_vectors = read_embeddings(opt.emb_file_enc, skip_lines)
-    logger.info("Got {} encoder embeddings from {}".format(
-        len(src_vectors), opt.emb_file_enc))
-
-    tgt_vectors = read_embeddings(opt.emb_file_dec)
-    logger.info("Got {} decoder embeddings from {}".format(
-        len(tgt_vectors), opt.emb_file_dec))
-
-    filtered_enc_embeddings, enc_count = match_embeddings(
-        enc_vocab, src_vectors, opt)
-    filtered_dec_embeddings, dec_count = match_embeddings(
-        dec_vocab, tgt_vectors, opt)
-    logger.info("\nMatching: ")
-    match_percent = [_['match'] / (_['match'] + _['miss']) * 100
-                     for _ in [enc_count, dec_count]]
+    if opt.emb_file_both is not None:
+        if opt.emb_file_enc is not None:
+            raise ValueError("If --emb_file_both is passed in, you should "
+                             "not set --emb_file_enc.")
+        if opt.emb_file_dec is not None:
+            raise ValueError("If --emb_file_both is passed in, you should "
+                             "not set --emb_file_dec.")
+        set_of_src_and_tgt_vocab = set(enc_vocab.stoi.keys()) | \
+            set(dec_vocab.stoi.keys())
+        logger.info("Reading encoder and decoder embeddings from {}".format(
+            opt.emb_file_both))
+        src_vectors, total_vec_count = read_embeddings(
+            opt.emb_file_both, skip_lines, set_of_src_and_tgt_vocab)
+        tgt_vectors = src_vectors
+        logger.info("\tFound {} total vectors in file.".format(
+            total_vec_count))
+    else:
+        if opt.emb_file_enc is None:
+            raise ValueError("--emb_file_enc not provided. Please specify "
+                             "the file with encoder embeddings, or pass in "
+                             "--emb_file_both.")
+        if opt.emb_file_dec is None:
+            raise ValueError("--emb_file_dec not provided. Please specify "
+                             "the file with decoder embeddings, or pass in "
+                             "--emb_file_both.")
+        logger.info("Reading encoder embeddings from {}".format(
+            opt.emb_file_enc))
+        src_vectors, total_vec_count = read_embeddings(
+            opt.emb_file_enc, skip_lines,
+            filter_set=enc_vocab.stoi
+        )
+        logger.info("\tFound {} total vectors in file.".format(
+            total_vec_count))
+        logger.info("Reading decoder embeddings from {}".format(
+            opt.emb_file_dec))
+        tgt_vectors, total_vec_count = read_embeddings(
+            opt.emb_file_dec, skip_lines,
+            filter_set=dec_vocab.stoi
+        )
+        logger.info("\tFound {} total vectors in file.".format(
+            total_vec_count))
+    logger.info("After filtering to vectors in vocab:")
     logger.info("\t* enc: %d match, %d missing, (%.2f%%)"
-                % (enc_count['match'],
-                   enc_count['miss'],
-                   match_percent[0]))
+                % calc_vocab_load_stats(enc_vocab, src_vectors))
     logger.info("\t* dec: %d match, %d missing, (%.2f%%)"
-                % (dec_count['match'],
-                   dec_count['miss'],
-                   match_percent[1]))
-
-    logger.info("\nFiltered embeddings:")
-    logger.info("\t* enc: %s" % str(filtered_enc_embeddings.size()))
-    logger.info("\t* dec: %s" % str(filtered_dec_embeddings.size()))
+                % calc_vocab_load_stats(dec_vocab, tgt_vectors))
 
     # Write to file
     enc_output_file = opt.output_file + ".enc.pt"
     dec_output_file = opt.output_file + ".dec.pt"
     logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s"
                 % (enc_output_file, dec_output_file))
-    torch.save(filtered_enc_embeddings, enc_output_file)
-    torch.save(filtered_dec_embeddings, dec_output_file)
+    torch.save(
+        convert_to_torch_tensor(src_vectors, enc_vocab),
+        enc_output_file
+    )
+    torch.save(
+        convert_to_torch_tensor(tgt_vectors, dec_vocab),
+        dec_output_file
+    )
     logger.info("\nDone.")


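As before, the script writes one plain tensor per side, shaped `(vocab_size, dim)`, so the change is invisible to downstream consumers. One possible way to load the saved files, using the paths produced by the `-output_file "data/embeddings"` setting in the FAQ example:

```
import torch
import torch.nn as nn

# File names follow the "-output_file data/embeddings" convention above.
enc_weights = torch.load("data/embeddings.enc.pt")
dec_weights = torch.load("data/embeddings.dec.pt")

# One way to use them: seed embedding layers with the pretrained vectors.
enc_emb = nn.Embedding.from_pretrained(enc_weights, freeze=False)
dec_emb = nn.Embedding.from_pretrained(dec_weights, freeze=False)
```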
