From ba5a8f478784906fddc96ab36d0e63f9f4e34082 Mon Sep 17 00:00:00 2001 From: Daniel Hershcovich Date: Tue, 28 Jul 2015 14:02:32 +0300 Subject: [PATCH] Fix clab/lstm-parser#4: support reading zipped word vectors --- CMakeLists.txt | 2 +- parser/lstm-parse.cc | 48 ++++++++++++++++++++++++++++++-------------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1687a20..12ac42e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,7 @@ if(DEFINED ENV{BOOST_ROOT}) set(Boost_NO_SYSTEM_PATHS ON) endif() set(Boost_REALPATH ON) -find_package(Boost COMPONENTS program_options serialization REQUIRED) +find_package(Boost COMPONENTS program_options serialization iostreams REQUIRED) include_directories(${Boost_INCLUDE_DIR}) set(LIBS ${LIBS} ${Boost_LIBRARIES}) diff --git a/parser/lstm-parse.cc b/parser/lstm-parse.cc index 1eb0548..4bb9f51 100644 --- a/parser/lstm-parse.cc +++ b/parser/lstm-parse.cc @@ -15,8 +15,11 @@ #include #include +#include #include #include +#include +#include #include #include "cnn/training.h" @@ -481,6 +484,21 @@ void output_conll(const vector& sentence, const vector& pos, cout << endl; } +void init_pretrained(istream &in) { + string line; + vector v(PRETRAINED_DIM, 0); + string word; + while (getline(in, line)) { + if (word.empty() && line.find('.') == std::string::npos) + continue; // first line contains vocabulary size and dimensions + istringstream lin(line); + lin >> word; + for (unsigned i = 0; i < PRETRAINED_DIM; ++i) lin >> v[i]; + unsigned id = corpus.get_or_add_word(word); + pretrained[id] = v; + } +} + int main(int argc, char** argv) { cnn::Initialize(argc, argv); @@ -525,24 +543,24 @@ int main(int argc, char** argv) { const string fname = os.str(); cerr << "Writing parameters to file: " << fname << endl; bool softlinkCreated = false; - corpus.load_correct_actions(conf["training_data"].as()); + corpus.load_correct_actions(conf["training_data"].as()); const unsigned kUNK = 
corpus.get_or_add_word(cpyp::Corpus::UNK); kROOT_SYMBOL = corpus.get_or_add_word(ROOT_SYMBOL); if (conf.count("words")) { pretrained[kUNK] = vector(PRETRAINED_DIM, 0); - cerr << "Loading from " << conf["words"].as() << " with " << PRETRAINED_DIM << " dimensions\n"; - ifstream in(conf["words"].as().c_str()); - string line; - getline(in, line); - vector v(PRETRAINED_DIM, 0); - string word; - while (getline(in, line)) { - istringstream lin(line); - lin >> word; - for (unsigned i = 0; i < PRETRAINED_DIM; ++i) lin >> v[i]; - unsigned id = corpus.get_or_add_word(word); - pretrained[id] = v; + const string& words_fname = conf["words"].as(); + cerr << "Loading from " << words_fname << " with " << PRETRAINED_DIM << " dimensions\n"; + if (boost::algorithm::ends_with(words_fname, ".gz")) { + ifstream file(words_fname.c_str(), ios_base::in | ios_base::binary); + boost::iostreams::filtering_streambuf zip; + zip.push(boost::iostreams::zlib_decompressor()); + zip.push(file); + istream in(&zip); + init_pretrained(in); + } else { + ifstream in(words_fname.c_str()); + init_pretrained(in); // read as normal text } } @@ -611,7 +629,7 @@ int main(int argc, char** argv) { for (auto& w : tsentence) if (singletons.count(w) && cnn::rand01() < unk_prob) w = kUNK; } - const vector& sentencePos=corpus.sentencesPos[order[si]]; + const vector& sentencePos=corpus.sentencesPos[order[si]]; const vector& actions=corpus.correct_act_sent[order[si]]; ComputationGraph hg; parser.log_prob_parser(&hg,sentence,tsentence,sentencePos,actions,corpus.actions,corpus.intToWords,&right); @@ -644,7 +662,7 @@ int main(int argc, char** argv) { auto t_start = std::chrono::high_resolution_clock::now(); for (unsigned sii = 0; sii < dev_size; ++sii) { const vector& sentence=corpus.sentencesDev[sii]; - const vector& sentencePos=corpus.sentencesPosDev[sii]; + const vector& sentencePos=corpus.sentencesPosDev[sii]; const vector& actions=corpus.correct_act_sentDev[sii]; vector tsentence=sentence; for (auto& w : tsentence)