From daeca8b232fbe379db6b3a5fdd2fbd78b88ab67c Mon Sep 17 00:00:00 2001 From: Daniel Hershcovich Date: Tue, 28 Jul 2015 15:04:49 +0300 Subject: [PATCH] Fix clab/lstm-parser#3: allow limiting optimization by dev uas tolerance --- parser/lstm-parse.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/parser/lstm-parse.cc b/parser/lstm-parse.cc index d72ddcb..9887bc3 100644 --- a/parser/lstm-parse.cc +++ b/parser/lstm-parse.cc @@ -77,6 +77,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("lstm_input_dim", po::value()->default_value(60), "LSTM input dimension") ("train,t", "Should training be run?") ("maxit,M", po::value()->default_value(8000), "Maximum number of training iterations") + ("tolerance", po::value()->default_value(0.0), "Tolerance on dev uas for stopping training") ("words,w", po::value(), "Pretrained word embeddings") ("use_spelling,S", "Use spelling model") //Miguel. Spelling model ("help,h", "Help"); @@ -946,6 +947,8 @@ int main(int argc, char** argv) { assert(unk_prob >= 0.); assert(unk_prob <= 1.); const unsigned maxit = conf["maxit"].as(); cerr << "Maximum number of iterations: " << maxit << "\n"; + const double tolerance = conf["tolerance"].as(); + cerr << "Optimization tolerance: " << tolerance << "\n"; ostringstream os; os << "parser_" << (USE_POS ? "pos" : "nopos") << '_' << LAYERS @@ -1035,7 +1038,10 @@ int main(int argc, char** argv) { double llh = 0; bool first = true; unsigned iter = 0; - while(!requested_stop && iter < maxit) { + double uas = -1; + double prev_uas = -1; + while(!requested_stop && iter < maxit && + (uas < 0 || prev_uas < 0 || abs(prev_uas - uas) > tolerance)) { for (unsigned sii = 0; sii < status_every_i_iterations; ++sii) { if (si == corpus.nsentences) { si = 0; @@ -1103,7 +1109,9 @@ int main(int argc, char** argv) { total_heads += sentence.size() - 1; } auto t_end = std::chrono::high_resolution_clock::now(); - cerr << " **dev (iter=" << iter << " epoch=" << (tot_seen / corpus.nsentences) << ")\tllh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << (correct_heads / total_heads) << "\t[" << dev_size << " sents in " << std::chrono::duration(t_end-t_start).count() << " ms]" << endl; + prev_uas = uas; + uas = correct_heads / total_heads; + cerr << " **dev (iter=" << iter << " epoch=" << (tot_seen / corpus.nsentences) << ")\tllh=" << llh << " ppl: " << exp(llh / trs) << " err: " << (trs - right) / trs << " uas: " << uas << "\t[" << dev_size << " sents in " << std::chrono::duration(t_end-t_start).count() << " ms]" << endl; if (correct_heads > best_correct_heads) { best_correct_heads = correct_heads; ofstream out(fname); @@ -1126,6 +1134,8 @@ int main(int argc, char** argv) { } if (iter >= maxit) { cerr << "\nMaximum number of iterations reached (" << iter << "), terminating optimization...\n"; + } else if (!requested_stop) { + cerr << "\nScore tolerance reached (" << tolerance << "), terminating optimization...\n"; } } // should do training? if (true) { // do test evaluation