
Commit 0642c53

eli7lmthang authored and committed
Refactoring internal and external eval to allow injection of placeholder tensors.
PiperOrigin-RevId: 183781004
1 parent 005fef0 commit 0642c53

1 file changed: nmt/train.py (176 additions, 32 deletions)
@@ -53,20 +53,49 @@ def run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                     infer_model.batch_size_placeholder, summary_writer)


-def run_internal_eval(
-    eval_model, eval_sess, model_dir, hparams, summary_writer,
-    use_test_set=True):
-  """Compute internal evaluation (perplexity) for both dev / test."""
+def run_internal_eval(eval_model,
+                      eval_sess,
+                      model_dir,
+                      hparams,
+                      summary_writer,
+                      use_test_set=True,
+                      dev_eval_iterator_feed_dict=None,
+                      test_eval_iterator_feed_dict=None):
+  """Compute internal evaluation (perplexity) for both dev / test.
+
+  Computes development and testing perplexities for the given model.
+
+  Args:
+    eval_model: Evaluation model for which to compute perplexities.
+    eval_sess: Evaluation TensorFlow session.
+    model_dir: Directory from which to load the evaluation model.
+    hparams: Model hyper-parameters.
+    summary_writer: Summary writer for logging metrics to TensorBoard.
+    use_test_set: Computes testing perplexity if true; does not otherwise.
+      Note that the development perplexity is always computed regardless of
+      the value of this parameter.
+    dev_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      development evaluation.
+    test_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      testing evaluation.
+  Returns:
+    Pair containing development perplexity and testing perplexity, in this
+    order.
+  """
+  if dev_eval_iterator_feed_dict is None:
+    dev_eval_iterator_feed_dict = {}
+  if test_eval_iterator_feed_dict is None:
+    test_eval_iterator_feed_dict = {}
   with eval_model.graph.as_default():
     loaded_eval_model, global_step = model_helper.create_or_load_model(
         eval_model.model, model_dir, eval_sess, "eval")

   dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
   dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
-  dev_eval_iterator_feed_dict = {
-      eval_model.src_file_placeholder: dev_src_file,
-      eval_model.tgt_file_placeholder: dev_tgt_file
-  }
+  dev_eval_iterator_feed_dict[eval_model.src_file_placeholder] = dev_src_file
+  dev_eval_iterator_feed_dict[eval_model.tgt_file_placeholder] = dev_tgt_file

   dev_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess,
                            eval_model.iterator, dev_eval_iterator_feed_dict,
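After this hunk, the feed dictionaries are created by the caller and merely extended with the file placeholders inside the function, so extra placeholder tensors can be injected from outside. A minimal caller sketch, assuming a hypothetical extra placeholder named `context_placeholder` on the evaluation model (illustrative only, not part of this commit):

from nmt import train  # assumes the nmt package is importable

# Seed the dev-side feed dict with an illustrative extra placeholder;
# run_internal_eval adds the src/tgt file placeholders on top of it.
dev_feed = {eval_model.context_placeholder: dev_context_data}
dev_ppl, test_ppl = train.run_internal_eval(
    eval_model, eval_sess, model_dir, hparams, summary_writer,
    dev_eval_iterator_feed_dict=dev_feed)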
@@ -75,30 +104,64 @@ def run_internal_eval(
   if use_test_set and hparams.test_prefix:
     test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
     test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
-    test_eval_iterator_feed_dict = {
-        eval_model.src_file_placeholder: test_src_file,
-        eval_model.tgt_file_placeholder: test_tgt_file
-    }
+    test_eval_iterator_feed_dict[
+        eval_model.src_file_placeholder] = test_src_file
+    test_eval_iterator_feed_dict[
+        eval_model.tgt_file_placeholder] = test_tgt_file
     test_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess,
                               eval_model.iterator, test_eval_iterator_feed_dict,
                               summary_writer, "test")
   return dev_ppl, test_ppl


-def run_external_eval(infer_model, infer_sess, model_dir, hparams,
-                      summary_writer, save_best_dev=True, use_test_set=True,
-                      avg_ckpts=False):
-  """Compute external evaluation (bleu, rouge, etc.) for both dev / test."""
+def run_external_eval(infer_model,
+                      infer_sess,
+                      model_dir,
+                      hparams,
+                      summary_writer,
+                      save_best_dev=True,
+                      use_test_set=True,
+                      avg_ckpts=False,
+                      dev_infer_iterator_feed_dict=None,
+                      test_infer_iterator_feed_dict=None):
+  """Compute external evaluation for both dev / test.
+
+  Computes development and testing external evaluation (e.g. bleu, rouge) for
+  the given model.
+
+  Args:
+    infer_model: Inference model for which to compute the external evaluation.
+    infer_sess: Inference TensorFlow session.
+    model_dir: Directory from which to load the inference model.
+    hparams: Model hyper-parameters.
+    summary_writer: Summary writer for logging metrics to TensorBoard.
+    use_test_set: Computes testing external evaluation if true; does not
+      otherwise. Note that the development external evaluation is always
+      computed regardless of the value of this parameter.
+    dev_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      development external evaluation.
+    test_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      testing external evaluation.
+  Returns:
+    Triple containing development scores, testing scores and the TensorFlow
+    Variable for the global step number, in this order.
+  """
+  if dev_infer_iterator_feed_dict is None:
+    dev_infer_iterator_feed_dict = {}
+  if test_infer_iterator_feed_dict is None:
+    test_infer_iterator_feed_dict = {}
   with infer_model.graph.as_default():
     loaded_infer_model, global_step = model_helper.create_or_load_model(
         infer_model.model, model_dir, infer_sess, "infer")

   dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
   dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
-  dev_infer_iterator_feed_dict = {
-      infer_model.src_placeholder: inference.load_data(dev_src_file),
-      infer_model.batch_size_placeholder: hparams.infer_batch_size,
-  }
+  dev_infer_iterator_feed_dict[
+      infer_model.src_placeholder] = inference.load_data(dev_src_file)
+  dev_infer_iterator_feed_dict[
+      infer_model.batch_size_placeholder] = hparams.infer_batch_size
   dev_scores = _external_eval(
       loaded_infer_model,
       global_step,
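The external eval follows the same pattern: the caller seeds the inference feed dict and run_external_eval layers the source data and batch size on top. A sketch with a hypothetical placeholder name (illustrative only, not part of this commit):

from nmt import train  # assumes the nmt package is importable

# `length_penalty_placeholder` is hypothetical, for illustration only.
dev_feed = {infer_model.length_penalty_placeholder: 1.0}
dev_scores, test_scores, global_step = train.run_external_eval(
    infer_model, infer_sess, model_dir, hparams, summary_writer,
    dev_infer_iterator_feed_dict=dev_feed)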
@@ -116,10 +179,10 @@ def run_external_eval(infer_model, infer_sess, model_dir, hparams,
   if use_test_set and hparams.test_prefix:
     test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
     test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
-    test_infer_iterator_feed_dict = {
-        infer_model.src_placeholder: inference.load_data(test_src_file),
-        infer_model.batch_size_placeholder: hparams.infer_batch_size,
-    }
+    test_infer_iterator_feed_dict[
+        infer_model.src_placeholder] = inference.load_data(test_src_file)
+    test_infer_iterator_feed_dict[
+        infer_model.batch_size_placeholder] = hparams.infer_batch_size
     test_scores = _external_eval(
         loaded_infer_model,
         global_step,
@@ -157,16 +220,63 @@ def run_avg_external_eval(infer_model, infer_sess, model_dir, hparams,
   return avg_dev_scores, avg_test_scores


-def run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
-                  hparams, summary_writer, sample_src_data, sample_tgt_data,
-                  avg_ckpts=False):
-  """Wrapper for running sample_decode, internal_eval and external_eval."""
-  run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
-                    sample_src_data, sample_tgt_data)
+def run_internal_and_external_eval(model_dir,
+                                   infer_model,
+                                   infer_sess,
+                                   eval_model,
+                                   eval_sess,
+                                   hparams,
+                                   summary_writer,
+                                   avg_ckpts=False,
+                                   dev_eval_iterator_feed_dict=None,
+                                   test_eval_iterator_feed_dict=None,
+                                   dev_infer_iterator_feed_dict=None,
+                                   test_infer_iterator_feed_dict=None):
+  """Compute internal and external evaluation for both dev / test.
+
+  Computes development and testing perplexities and external scores.
+
+  Args:
+    model_dir: Directory from which to load models.
+    infer_model: Inference model for which to compute the external evaluation.
+    infer_sess: Inference TensorFlow session.
+    eval_model: Evaluation model for which to compute perplexities.
+    eval_sess: Evaluation TensorFlow session.
+    hparams: Model hyper-parameters.
+    summary_writer: Summary writer for logging metrics to TensorBoard.
+    avg_ckpts: Whether to compute average external evaluation scores.
+    dev_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      internal development evaluation.
+    test_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      internal testing evaluation.
+    dev_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      external development evaluation.
+    test_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      external testing evaluation.
+  Returns:
+    Triple containing the results summary, the global step TensorFlow Variable
+    and metrics, in this order.
+  """
   dev_ppl, test_ppl = run_internal_eval(
-      eval_model, eval_sess, model_dir, hparams, summary_writer)
+      eval_model,
+      eval_sess,
+      model_dir,
+      hparams,
+      summary_writer,
+      dev_eval_iterator_feed_dict=dev_eval_iterator_feed_dict,
+      test_eval_iterator_feed_dict=test_eval_iterator_feed_dict)
   dev_scores, test_scores, global_step = run_external_eval(
-      infer_model, infer_sess, model_dir, hparams, summary_writer)
+      infer_model,
+      infer_sess,
+      model_dir,
+      hparams,
+      summary_writer,
+      dev_infer_iterator_feed_dict=dev_infer_iterator_feed_dict,
+      test_infer_iterator_feed_dict=test_infer_iterator_feed_dict)

   metrics = {
       "dev_ppl": dev_ppl,
@@ -197,6 +307,40 @@ def run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
   return result_summary, global_step, metrics


+def run_full_eval(model_dir,
+                  infer_model,
+                  infer_sess,
+                  eval_model,
+                  eval_sess,
+                  hparams,
+                  summary_writer,
+                  sample_src_data,
+                  sample_tgt_data,
+                  avg_ckpts=False):
+  """Wrapper for running sample_decode, internal_eval and external_eval.
+
+  Args:
+    model_dir: Directory from which to load models.
+    infer_model: Inference model for sample decoding and external evaluation.
+    infer_sess: Inference TensorFlow session.
+    eval_model: Evaluation model for which to compute perplexities.
+    eval_sess: Evaluation TensorFlow session.
+    hparams: Model hyper-parameters.
+    summary_writer: Summary writer for logging metrics to TensorBoard.
+    sample_src_data: Sample of source data for sample decoding.
+    sample_tgt_data: Sample of target data for sample decoding.
+    avg_ckpts: Whether to compute average external evaluation scores.
+  Returns:
+    Triple containing the results summary, the global step TensorFlow Variable
+    and metrics, in this order.
+  """
+  run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
+                    sample_src_data, sample_tgt_data)
+  return run_internal_and_external_eval(model_dir, infer_model, infer_sess,
+                                        eval_model, eval_sess, hparams,
+                                        summary_writer, avg_ckpts)
+
+
 def init_stats():
   """Initialize statistics that we want to accumulate."""
   return {"step_time": 0.0, "train_loss": 0.0,
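A note on the `=None` defaults introduced throughout this commit: the feed-dict arguments default to None and are replaced with fresh dicts inside each function, because a literal `{}` default would be created once at definition time and shared across every call. A standalone demonstration of the pitfall the commit avoids:

# Why the commit defaults the feed dicts to None instead of {}: a mutable
# default is created once and shared across calls, so entries written by
# one call leak into the next.
def bad(d={}):        # one shared dict for every call
  d["k"] = d.get("k", 0) + 1
  return d

def good(d=None):     # the pattern used in this commit
  if d is None:
    d = {}
  d["k"] = d.get("k", 0) + 1
  return d

assert bad() == {"k": 1}
assert bad() == {"k": 2}   # state leaked from the first call
assert good() == {"k": 1}
assert good() == {"k": 1}  # fresh dict on each call

Note also that run_internal_eval and run_external_eval write the file and batch-size entries into the caller-supplied dicts in place, so callers will observe those extra keys after the call returns.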
