diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md index 5f039ecb37..6164f61c9d 100644 --- a/docs/source/FAQ.md +++ b/docs/source/FAQ.md @@ -43,7 +43,7 @@ python preprocess.py \ 3) prepare embeddings: ``` -./tools/embeddings_to_torch.py -emb_file "glove_dir/glove.6B.100d.txt" \ +./tools/embeddings_to_torch.py -emb_file_both "glove_dir/glove.6B.100d.txt" \ -dict_file "data/data.vocab.pt" \ -output_file "data/embeddings" ``` diff --git a/tools/embeddings_to_torch.py b/tools/embeddings_to_torch.py index e74f866ef5..4dd7dce34f 100755 --- a/tools/embeddings_to_torch.py +++ b/tools/embeddings_to_torch.py @@ -32,8 +32,9 @@ def get_vocabs(dict_path): return enc_vocab, dec_vocab -def read_embeddings(file_enc, skip_lines=0): +def read_embeddings(file_enc, skip_lines=0, filter_set=None): embs = dict() + total_vectors_in_file = 0 with open(file_enc, 'rb') as f: for i, line in enumerate(f): if i < skip_lines: @@ -47,14 +48,16 @@ def read_embeddings(file_enc, skip_lines=0): l_split = line.decode('utf8').strip().split(' ') if len(l_split) == 2: continue + total_vectors_in_file += 1 + if l_split[0] not in filter_set: + continue embs[l_split[0]] = [float(em) for em in l_split[1:]] - return embs + return embs, total_vectors_in_file def match_embeddings(vocab, emb, opt): dim = len(six.next(six.itervalues(emb))) filtered_embeddings = np.zeros((len(vocab), dim)) - count = {"match": 0, "miss": 0} for w, w_id in vocab.stoi.items(): if w in emb: filtered_embeddings[w_id] = emb[w] @@ -67,12 +70,30 @@ def match_embeddings(vocab, emb, opt): return torch.Tensor(filtered_embeddings), count +def convert_to_torch_tensor(word_to_float_list_dict, vocab): + dim = len(six.next(six.itervalues(word_to_float_list_dict))) + tensor = torch.zeros((len(vocab), dim)) + for word, values in word_to_float_list_dict.items(): + tensor[vocab.stoi[word]] = torch.Tensor(values) + return tensor + + +def calc_vocab_load_stats(vocab, loaded_embed_dict): + matching_count = len(set(vocab.stoi.keys()) & set(loaded_embed_dict.keys())) + missing_count = len(vocab) - matching_count + percet_matching = matching_count / len(vocab) * 100 + return matching_count, missing_count, percet_matching + + def main(): parser = argparse.ArgumentParser(description='embeddings_to_torch.py') - parser.add_argument('-emb_file_enc', required=True, + parser.add_argument('-emb_file_both', required=False, + help="loads Embeddings for both source and target " + "from this file.") + parser.add_argument('-emb_file_enc', required=False, help="source Embeddings from this file") - parser.add_argument('-emb_file_dec', required=True, + parser.add_argument('-emb_file_dec', required=False, help="target Embeddings from this file") parser.add_argument('-output_file', required=True, help="Output file for the prepared data") @@ -87,41 +108,63 @@ def main(): enc_vocab, dec_vocab = get_vocabs(opt.dict_file) + # Read in embeddings skip_lines = 1 if opt.type == "word2vec" else opt.skip_lines - src_vectors = read_embeddings(opt.emb_file_enc, skip_lines) - logger.info("Got {} encoder embeddings from {}".format( - len(src_vectors), opt.emb_file_enc)) - - tgt_vectors = read_embeddings(opt.emb_file_dec) - logger.info("Got {} decoder embeddings from {}".format( - len(tgt_vectors), opt.emb_file_dec)) - - filtered_enc_embeddings, enc_count = match_embeddings( - enc_vocab, src_vectors, opt) - filtered_dec_embeddings, dec_count = match_embeddings( - dec_vocab, tgt_vectors, opt) - logger.info("\nMatching: ") - match_percent = [_['match'] / (_['match'] + _['miss']) * 100 - for _ in [enc_count, dec_count]] + if opt.emb_file_both is not None: + if opt.emb_file_enc is not None: + raise ValueError("If --emb_file_both is passed in, you should not" + "set --emb_file_enc.") + if opt.emb_file_dec is not None: + raise ValueError("If --emb_file_both is passed in, you should not" + "set --emb_file_dec.") + set_of_src_and_tgt_vocab = set(enc_vocab.stoi.keys()) | \ + set(dec_vocab.stoi.keys()) + logger.info("Reading encoder and decoder embeddings from {}".format( + opt.emb_file_both)) + src_vectors, total_vec_count = \ + read_embeddings(opt.emb_file_both, skip_lines, set_of_src_and_tgt_vocab) + tgt_vectors = src_vectors + logger.info("\tFound {} total vectors in file.".format(total_vec_count)) + else: + if opt.emb_file_enc is None: + raise ValueError("If --emb_file_enc not provided. Please specify " + "the file with encoder embeddings, or pass in " + "--emb_file_both") + if opt.emb_file_dec is None: + raise ValueError("If --emb_file_dec not provided. Please specify " + "the file with encoder embeddings, or pass in " + "--emb_file_both") + logger.info("Reading encoder embeddings from {}".format(opt.emb_file_enc)) + src_vectors, total_vec_count = read_embeddings( + opt.emb_file_enc, skip_lines, + filter_set=enc_vocab.stoi + ) + logger.info("\tFound {} total vectors in file.".format(total_vec_count)) + logger.info("Reading decoder embeddings from {}".format(opt.emb_file_dec)) + tgt_vectors, total_vec_count = read_embeddings( + opt.emb_file_dec, skip_lines, + filter_set=dec_vocab.stoi + ) + logger.info("\tFound {} total vectors in file.".format(total_vec_count)) + logger.info("After filtering to vectors in vocab:") logger.info("\t* enc: %d match, %d missing, (%.2f%%)" - % (enc_count['match'], - enc_count['miss'], - match_percent[0])) + % calc_vocab_load_stats(enc_vocab, src_vectors)) logger.info("\t* dec: %d match, %d missing, (%.2f%%)" - % (dec_count['match'], - dec_count['miss'], - match_percent[1])) - - logger.info("\nFiltered embeddings:") - logger.info("\t* enc: %s" % str(filtered_enc_embeddings.size())) - logger.info("\t* dec: %s" % str(filtered_dec_embeddings.size())) + % calc_vocab_load_stats(dec_vocab, src_vectors)) + # Write to file enc_output_file = opt.output_file + ".enc.pt" dec_output_file = opt.output_file + ".dec.pt" logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s" % (enc_output_file, dec_output_file)) - torch.save(filtered_enc_embeddings, enc_output_file) - torch.save(filtered_dec_embeddings, dec_output_file) + torch.save( + convert_to_torch_tensor(src_vectors, enc_vocab), + enc_output_file + ) + torch.save( + convert_to_torch_tensor(tgt_vectors, dec_vocab), + dec_output_file + ) logger.info("\nDone.")