From 8cb45d4d13aa2f3adda88dc387039c28162b6766 Mon Sep 17 00:00:00 2001 From: Daniel Hershcovich Date: Tue, 28 Jul 2015 15:04:49 +0300 Subject: [PATCH] Fix clab/lstm-parser#3: allow limiting optimization by dev uas tolerance --- parser/lstm-parse.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/parser/lstm-parse.cc b/parser/lstm-parse.cc index bf2cc21..49672e0 100644 --- a/parser/lstm-parse.cc +++ b/parser/lstm-parse.cc @@ -71,6 +71,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("lstm_input_dim", po::value()->default_value(60), "LSTM input dimension") ("train,t", "Should training be run?") ("maxit,M", po::value()->default_value(8000), "Maximum number of training iterations") + ("tolerance", po::value()->default_value(0.0), "Tolerance on dev uas for stopping training") ("words,w", po::value(), "Pretrained word embeddings") ("help,h", "Help"); po::options_description dcmdline_options; @@ -525,6 +526,8 @@ int main(int argc, char** argv) { assert(unk_prob >= 0.); assert(unk_prob <= 1.); const unsigned maxit = conf["maxit"].as(); cerr << "Maximum number of iterations: " << maxit << "\n"; + const double tolerance = conf["tolerance"].as(); + cerr << "Optimization tolerance: " << tolerance << "\n"; ostringstream os; os << "parser_" << (USE_POS ? "pos" : "nopos") << '_' << LAYERS @@ -607,9 +610,12 @@ int main(int argc, char** argv) { double llh = 0; bool first = true; unsigned iter = 0; + double uas = -1; + double prev_uas = -1; time_t time_start = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); cerr << "TRAINING STARTED AT: " << put_time(localtime(&time_start), "%c %Z") << endl; - while(!requested_stop && iter < maxit) { + while(!requested_stop && iter < maxit && + (uas < 0 || prev_uas < 0 || abs(prev_uas - uas) > tolerance)) { for (unsigned sii = 0; sii < status_every_i_iterations; ++sii) { if (si == corpus.nsentences) { si = 0; @@ -675,7 +681,9 @@ int main(int argc, char** argv) { total_heads += sentence.size() - 1; } auto t_end = std::chrono::high_resolution_clock::now(); - cerr << " **dev (iter=" << iter << " epoch=" << (tot_seen / corpus.nsentences) << ")\tllh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << "\t[" << dev_size << " sents in " << std::chrono::duration(t_end-t_start).count() << " ms]" << endl; + prev_uas = uas; + uas = correct_heads / total_heads; + cerr << " **dev (iter=" << iter << " epoch=" << (tot_seen / corpus.nsentences) << ")\tllh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << uas << "\t[" << dev_size << " sents in " << std::chrono::duration(t_end-t_start).count() << " ms]" << endl; if (correct_heads > best_correct_heads) { best_correct_heads = correct_heads; ofstream out(fname); @@ -698,6 +706,8 @@ int main(int argc, char** argv) { } if (iter >= maxit) { cerr << "\nMaximum number of iterations reached (" << iter << "), terminating optimization...\n"; + } else if (!requested_stop) { + cerr << "\nScore tolerance reached (" << tolerance << "), terminating optimization...\n"; } } // should do training? if (true) { // do test evaluation