More efficient embeddings_to_torch.py (OpenNMT#1372)
* Update embeddings_to_torch.py to be more memory efficient by loading into memory only those vectors that are present in the vocab.

* remove dead code and flake8 violations introduced with 57cefb7

* update docs on using GloVe embeddings; fix spelling error
DNGros authored and vince62s committed Apr 3, 2019
1 parent 9809c4c · commit ad50970
Showing 3 changed files with 86 additions and 56 deletions.
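
The memory-saving idea in the first bullet, condensed into a minimal sketch (function and variable names here are illustrative, not the exact code in the diff): the reader takes the vocabulary up front and skips any vector whose word is not in it, so memory scales with the vocabulary rather than with the size of the embedding file.

```
# Minimal sketch of the memory-saving idea: only vectors whose word is in
# the vocab are kept. Names are illustrative, not the exact patch code.
def read_embeddings_filtered(path, vocab_words, skip_lines=0):
    embs = {}
    total_vectors_in_file = 0
    with open(path, 'rb') as f:
        for i, line in enumerate(f):
            if i < skip_lines:
                continue
            parts = line.decode('utf8').strip().split(' ')
            if len(parts) == 2:
                # word2vec-style header line ("<count> <dim>"): skip it
                continue
            total_vectors_in_file += 1
            if parts[0] not in vocab_words:
                continue  # not in the vocab, so never held in memory
            embs[parts[0]] = [float(x) for x in parts[1:]]
    return embs, total_vectors_in_file
```
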
README.md: 2 changes (1 addition & 1 deletion)
@@ -120,7 +120,7 @@ Click this button to open a Workspace on [FloydHub](https://www.floydhub.com/?ut

## Pretrained embeddings (e.g. GloVe)

Go to tutorial: [How to use GloVe pre-trained embeddings in OpenNMT-py](http://forum.opennmt.net/t/how-to-use-glove-pre-trained-embeddings-in-opennmt-py/1011)
Please see the FAQ: [How to use GloVe pre-trained embeddings in OpenNMT-py](http://opennmt.net/OpenNMT-py/FAQ.html#how-do-i-use-pretrained-embeddings-e-g-glove)

## Pretrained Models

docs/source/FAQ.md: 16 changes (8 additions & 8 deletions)
@@ -9,14 +9,14 @@ the script is a slightly modified version of ylhsieh’s one2.
Usage:

```
embeddings_to_torch.py [-h] -emb_file EMB_FILE -output_file OUTPUT_FILE -dict_file DICT_FILE [-verbose]
emb_file: GloVe like embedding file i.e. CSV [word] [dim1] ... [dim_d]
output_file: a filename to save the output as PyTorch serialized tensors2
dict_file: dict output from OpenNMT-py preprocessing
embeddings_to_torch.py [-h] [-emb_file_both EMB_FILE_BOTH]
[-emb_file_enc EMB_FILE_ENC]
[-emb_file_dec EMB_FILE_DEC] -output_file
OUTPUT_FILE -dict_file DICT_FILE [-verbose]
[-skip_lines SKIP_LINES]
[-type {GloVe,word2vec}]
```
Run embeddings_to_torch.py -h for more complete usage info.

Example

@@ -43,7 +43,7 @@ python preprocess.py \
3) prepare embeddings:

```
./tools/embeddings_to_torch.py -emb_file "glove_dir/glove.6B.100d.txt" \
./tools/embeddings_to_torch.py -emb_file_both "glove_dir/glove.6B.100d.txt" \
-dict_file "data/data.vocab.pt" \
-output_file "data/embeddings"
```
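
As a quick sanity check, the two files the script writes, `<output_file>.enc.pt` and `<output_file>.dec.pt`, are plain serialized tensors and can be inspected with `torch.load`; a minimal sketch, assuming the paths from the example command above and the 100-dimensional GloVe vectors.

```
import torch

# Load the tensors written by embeddings_to_torch.py (paths assume the example above).
enc = torch.load("data/embeddings.enc.pt")
dec = torch.load("data/embeddings.dec.pt")
print(enc.size(), dec.size())  # expected: (src vocab size, 100) and (tgt vocab size, 100)
```
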
tools/embeddings_to_torch.py: 124 changes (77 additions & 47 deletions)
@@ -2,8 +2,6 @@
# -*- coding: utf-8 -*-
from __future__ import division
import six
import sys
import numpy as np
import argparse
import torch
from onmt.utils.logging import init_logger, logger
@@ -32,8 +30,9 @@ def get_vocabs(dict_path):
return enc_vocab, dec_vocab


def read_embeddings(file_enc, skip_lines=0):
def read_embeddings(file_enc, skip_lines=0, filter_set=None):
embs = dict()
total_vectors_in_file = 0
with open(file_enc, 'rb') as f:
for i, line in enumerate(f):
if i < skip_lines:
@@ -47,32 +46,37 @@ l_split = line.decode('utf8').strip().split(' ')
l_split = line.decode('utf8').strip().split(' ')
if len(l_split) == 2:
continue
total_vectors_in_file += 1
if filter_set is not None and l_split[0] not in filter_set:
continue
embs[l_split[0]] = [float(em) for em in l_split[1:]]
return embs
return embs, total_vectors_in_file


def match_embeddings(vocab, emb, opt):
dim = len(six.next(six.itervalues(emb)))
filtered_embeddings = np.zeros((len(vocab), dim))
count = {"match": 0, "miss": 0}
for w, w_id in vocab.stoi.items():
if w in emb:
filtered_embeddings[w_id] = emb[w]
count['match'] += 1
else:
if opt.verbose:
logger.info(u"not found:\t{}".format(w), file=sys.stderr)
count['miss'] += 1
def convert_to_torch_tensor(word_to_float_list_dict, vocab):
dim = len(six.next(six.itervalues(word_to_float_list_dict)))
tensor = torch.zeros((len(vocab), dim))
for word, values in word_to_float_list_dict.items():
tensor[vocab.stoi[word]] = torch.Tensor(values)
return tensor

return torch.Tensor(filtered_embeddings), count

def calc_vocab_load_stats(vocab, loaded_embed_dict):
matching_count = len(
set(vocab.stoi.keys()) & set(loaded_embed_dict.keys()))
missing_count = len(vocab) - matching_count
percent_matching = matching_count / len(vocab) * 100
return matching_count, missing_count, percent_matching

def main():

def main():
parser = argparse.ArgumentParser(description='embeddings_to_torch.py')
parser.add_argument('-emb_file_enc', required=True,
parser.add_argument('-emb_file_both', required=False,
help="loads Embeddings for both source and target "
"from this file.")
parser.add_argument('-emb_file_enc', required=False,
help="source Embeddings from this file")
parser.add_argument('-emb_file_dec', required=True,
parser.add_argument('-emb_file_dec', required=False,
help="target Embeddings from this file")
parser.add_argument('-output_file', required=True,
help="Output file for the prepared data")
@@ -87,41 +91,67 @@ def main():

enc_vocab, dec_vocab = get_vocabs(opt.dict_file)

# Read in embeddings
skip_lines = 1 if opt.type == "word2vec" else opt.skip_lines
src_vectors = read_embeddings(opt.emb_file_enc, skip_lines)
logger.info("Got {} encoder embeddings from {}".format(
len(src_vectors), opt.emb_file_enc))

tgt_vectors = read_embeddings(opt.emb_file_dec)
logger.info("Got {} decoder embeddings from {}".format(
len(tgt_vectors), opt.emb_file_dec))

filtered_enc_embeddings, enc_count = match_embeddings(
enc_vocab, src_vectors, opt)
filtered_dec_embeddings, dec_count = match_embeddings(
dec_vocab, tgt_vectors, opt)
logger.info("\nMatching: ")
match_percent = [_['match'] / (_['match'] + _['miss']) * 100
for _ in [enc_count, dec_count]]
if opt.emb_file_both is not None:
if opt.emb_file_enc is not None:
raise ValueError("If --emb_file_both is passed in, you should not"
"set --emb_file_enc.")
if opt.emb_file_dec is not None:
raise ValueError("If --emb_file_both is passed in, you should not"
"set --emb_file_dec.")
set_of_src_and_tgt_vocab = \
set(enc_vocab.stoi.keys()) | set(dec_vocab.stoi.keys())
logger.info("Reading encoder and decoder embeddings from {}".format(
opt.emb_file_both))
src_vectors, total_vec_count = \
read_embeddings(opt.emb_file_both, skip_lines,
set_of_src_and_tgt_vocab)
tgt_vectors = src_vectors
logger.info("\tFound {} total vectors in file".format(total_vec_count))
else:
if opt.emb_file_enc is None:
raise ValueError("If --emb_file_enc not provided. Please specify "
"the file with encoder embeddings, or pass in "
"--emb_file_both")
if opt.emb_file_dec is None:
raise ValueError("If --emb_file_dec not provided. Please specify "
"the file with encoder embeddings, or pass in "
"--emb_file_both")
logger.info("Reading encoder embeddings from {}".format(
opt.emb_file_enc))
src_vectors, total_vec_count = read_embeddings(
opt.emb_file_enc, skip_lines,
filter_set=enc_vocab.stoi
)
logger.info("\tFound {} total vectors in file.".format(
total_vec_count))
logger.info("Reading decoder embeddings from {}".format(
opt.emb_file_dec))
tgt_vectors, total_vec_count = read_embeddings(
opt.emb_file_dec, skip_lines,
filter_set=dec_vocab.stoi
)
logger.info("\tFound {} total vectors in file".format(total_vec_count))
logger.info("After filtering to vectors in vocab:")
logger.info("\t* enc: %d match, %d missing, (%.2f%%)"
% (enc_count['match'],
enc_count['miss'],
match_percent[0]))
% calc_vocab_load_stats(enc_vocab, src_vectors))
logger.info("\t* dec: %d match, %d missing, (%.2f%%)"
% (dec_count['match'],
dec_count['miss'],
match_percent[1]))

logger.info("\nFiltered embeddings:")
logger.info("\t* enc: %s" % str(filtered_enc_embeddings.size()))
logger.info("\t* dec: %s" % str(filtered_dec_embeddings.size()))
% calc_vocab_load_stats(dec_vocab, src_vectors))

# Write to file
enc_output_file = opt.output_file + ".enc.pt"
dec_output_file = opt.output_file + ".dec.pt"
logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s"
% (enc_output_file, dec_output_file))
torch.save(filtered_enc_embeddings, enc_output_file)
torch.save(filtered_dec_embeddings, dec_output_file)
torch.save(
convert_to_torch_tensor(src_vectors, enc_vocab),
enc_output_file
)
torch.save(
convert_to_torch_tensor(tgt_vectors, dec_vocab),
dec_output_file
)
logger.info("\nDone.")

