@@ -53,20 +53,49 @@ def run_sample_decode(infer_model, infer_sess, model_dir, hparams,
                     infer_model.batch_size_placeholder, summary_writer)


-def run_internal_eval(
-    eval_model, eval_sess, model_dir, hparams, summary_writer,
-    use_test_set=True):
-  """Compute internal evaluation (perplexity) for both dev / test."""
+def run_internal_eval(eval_model,
+                      eval_sess,
+                      model_dir,
+                      hparams,
+                      summary_writer,
+                      use_test_set=True,
+                      dev_eval_iterator_feed_dict=None,
+                      test_eval_iterator_feed_dict=None):
+  """Compute internal evaluation (perplexity) for both dev / test.
+
+  Computes development and testing perplexities for the given model.
+
+  Args:
+    eval_model: Evaluation model for which to compute perplexities.
+    eval_sess: Evaluation TensorFlow session.
+    model_dir: Directory from which to load the evaluation model.
+    hparams: Model hyper-parameters.
+    summary_writer: Summary writer for logging metrics to TensorBoard.
+    use_test_set: Computes testing perplexity if true; does not otherwise.
+      Note that the development perplexity is always computed regardless of
+      the value of this parameter.
+    dev_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      development evaluation.
+    test_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      testing evaluation.
+  Returns:
+    Pair containing the development perplexity and the testing perplexity,
+    in that order.
+  """
+  if dev_eval_iterator_feed_dict is None:
+    dev_eval_iterator_feed_dict = {}
+  if test_eval_iterator_feed_dict is None:
+    test_eval_iterator_feed_dict = {}
   with eval_model.graph.as_default():
     loaded_eval_model, global_step = model_helper.create_or_load_model(
         eval_model.model, model_dir, eval_sess, "eval")

   dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
   dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
-  dev_eval_iterator_feed_dict = {
-      eval_model.src_file_placeholder: dev_src_file,
-      eval_model.tgt_file_placeholder: dev_tgt_file
-  }
+  dev_eval_iterator_feed_dict[eval_model.src_file_placeholder] = dev_src_file
+  dev_eval_iterator_feed_dict[eval_model.tgt_file_placeholder] = dev_tgt_file

   dev_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess,
                            eval_model.iterator, dev_eval_iterator_feed_dict,
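A note on the new keyword defaults: they are None rather than {} because a
mutable default is evaluated once, at definition time, and would be shared
across calls. A minimal self-contained sketch of the pitfall the
"if ... is None" guards above avoid (bad_append/good_append are illustrative
names, not part of this codebase):

# A mutable default is created once and reused, so feed entries written
# during one evaluation would leak into the next.
def bad_append(item, bucket=[]):  # single list shared by every call
  bucket.append(item)
  return bucket

def good_append(item, bucket=None):  # the idiom used in run_internal_eval
  if bucket is None:
    bucket = []  # fresh object on each call
  bucket.append(item)
  return bucket

print(bad_append(1))   # [1]
print(bad_append(2))   # [1, 2]  <- state carried over from the first call
print(good_append(1))  # [1]
print(good_append(2))  # [2]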
@@ -75,30 +104,64 @@ def run_internal_eval(
   if use_test_set and hparams.test_prefix:
     test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
     test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
-    test_eval_iterator_feed_dict = {
-        eval_model.src_file_placeholder: test_src_file,
-        eval_model.tgt_file_placeholder: test_tgt_file
-    }
+    test_eval_iterator_feed_dict[
+        eval_model.src_file_placeholder] = test_src_file
+    test_eval_iterator_feed_dict[
+        eval_model.tgt_file_placeholder] = test_tgt_file
     test_ppl = _internal_eval(loaded_eval_model, global_step, eval_sess,
                               eval_model.iterator, test_eval_iterator_feed_dict,
                               summary_writer, "test")
   return dev_ppl, test_ppl


-def run_external_eval(infer_model, infer_sess, model_dir, hparams,
-                      summary_writer, save_best_dev=True, use_test_set=True,
-                      avg_ckpts=False):
-  """Compute external evaluation (bleu, rouge, etc.) for both dev / test."""
+def run_external_eval(infer_model,
+                      infer_sess,
+                      model_dir,
+                      hparams,
+                      summary_writer,
+                      save_best_dev=True,
+                      use_test_set=True,
+                      avg_ckpts=False,
+                      dev_infer_iterator_feed_dict=None,
+                      test_infer_iterator_feed_dict=None):
+  """Compute external evaluation for both dev / test.
+
+  Computes development and testing external evaluation scores (e.g. BLEU,
+  ROUGE) for the given model.
+
+  Args:
+    infer_model: Inference model for which to compute external scores.
+    infer_sess: Inference TensorFlow session.
+    model_dir: Directory from which to load the inference model.
+    hparams: Model hyper-parameters.
+    summary_writer: Summary writer for logging metrics to TensorBoard.
+    use_test_set: Computes testing external evaluation if true; does not
+      otherwise. Note that the development external evaluation is always
+      computed regardless of the value of this parameter.
+    dev_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      development external evaluation.
+    test_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      testing external evaluation.
+  Returns:
+    Triple containing the development scores, the testing scores and the
+    global step TensorFlow Variable, in that order.
+  """
+  if dev_infer_iterator_feed_dict is None:
+    dev_infer_iterator_feed_dict = {}
+  if test_infer_iterator_feed_dict is None:
+    test_infer_iterator_feed_dict = {}
   with infer_model.graph.as_default():
     loaded_infer_model, global_step = model_helper.create_or_load_model(
         infer_model.model, model_dir, infer_sess, "infer")

   dev_src_file = "%s.%s" % (hparams.dev_prefix, hparams.src)
   dev_tgt_file = "%s.%s" % (hparams.dev_prefix, hparams.tgt)
-  dev_infer_iterator_feed_dict = {
-      infer_model.src_placeholder: inference.load_data(dev_src_file),
-      infer_model.batch_size_placeholder: hparams.infer_batch_size,
-  }
+  dev_infer_iterator_feed_dict[
+      infer_model.src_placeholder] = inference.load_data(dev_src_file)
+  dev_infer_iterator_feed_dict[
+      infer_model.batch_size_placeholder] = hparams.infer_batch_size
   dev_scores = _external_eval(
       loaded_infer_model,
       global_step,
@@ -116,10 +179,10 @@ def run_external_eval(infer_model, infer_sess, model_dir, hparams,
   if use_test_set and hparams.test_prefix:
     test_src_file = "%s.%s" % (hparams.test_prefix, hparams.src)
     test_tgt_file = "%s.%s" % (hparams.test_prefix, hparams.tgt)
-    test_infer_iterator_feed_dict = {
-        infer_model.src_placeholder: inference.load_data(test_src_file),
-        infer_model.batch_size_placeholder: hparams.infer_batch_size,
-    }
+    test_infer_iterator_feed_dict[
+        infer_model.src_placeholder] = inference.load_data(test_src_file)
+    test_infer_iterator_feed_dict[
+        infer_model.batch_size_placeholder] = hparams.infer_batch_size
     test_scores = _external_eval(
         loaded_infer_model,
         global_step,
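The same refactor appears at all four feed-dict sites: instead of building a
fresh dict, each function now writes its keys into whatever dict the caller
passed, so caller-supplied entries survive and are only overwritten on a key
collision. A toy illustration of those merge semantics, with string keys
standing in for the real placeholder tensors:

# Caller-supplied entries survive; the function's own keys are added on top.
feed = {"extra_input": 42}                # hypothetical caller entry
feed["src_placeholder"] = ["a sentence"]  # keys the function sets itself
feed["batch_size_placeholder"] = 32
print(feed)
# {'extra_input': 42, 'src_placeholder': ['a sentence'],
#  'batch_size_placeholder': 32}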
@@ -157,16 +220,63 @@ def run_avg_external_eval(infer_model, infer_sess, model_dir, hparams,
   return avg_dev_scores, avg_test_scores


-def run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
-                  hparams, summary_writer, sample_src_data, sample_tgt_data,
-                  avg_ckpts=False):
-  """Wrapper for running sample_decode, internal_eval and external_eval."""
-  run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
-                    sample_src_data, sample_tgt_data)
+def run_internal_and_external_eval(model_dir,
+                                   infer_model,
+                                   infer_sess,
+                                   eval_model,
+                                   eval_sess,
+                                   hparams,
+                                   summary_writer,
+                                   avg_ckpts=False,
+                                   dev_eval_iterator_feed_dict=None,
+                                   test_eval_iterator_feed_dict=None,
+                                   dev_infer_iterator_feed_dict=None,
+                                   test_infer_iterator_feed_dict=None):
+  """Compute internal and external evaluation for both dev / test.
+
+  Computes development and testing perplexities and external scores.
+
+  Args:
+    model_dir: Directory from which to load models.
+    infer_model: Inference model for which to compute external scores.
+    infer_sess: Inference TensorFlow session.
+    eval_model: Evaluation model for which to compute perplexities.
+    eval_sess: Evaluation TensorFlow session.
+    hparams: Model hyper-parameters.
+    summary_writer: Summary writer for logging metrics to TensorBoard.
+    avg_ckpts: Whether to compute average external evaluation scores.
+    dev_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      internal development evaluation.
+    test_eval_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      internal testing evaluation.
+    dev_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      external development evaluation.
+    test_infer_iterator_feed_dict: Feed dictionary for a TensorFlow session.
+      Can be used to pass in additional inputs necessary for running the
+      external testing evaluation.
+  Returns:
+    Triple containing the results summary, the global step TensorFlow
+    Variable and the metrics, in that order.
+  """
   dev_ppl, test_ppl = run_internal_eval(
-      eval_model, eval_sess, model_dir, hparams, summary_writer)
+      eval_model,
+      eval_sess,
+      model_dir,
+      hparams,
+      summary_writer,
+      dev_eval_iterator_feed_dict=dev_eval_iterator_feed_dict,
+      test_eval_iterator_feed_dict=test_eval_iterator_feed_dict)
   dev_scores, test_scores, global_step = run_external_eval(
-      infer_model, infer_sess, model_dir, hparams, summary_writer)
+      infer_model,
+      infer_sess,
+      model_dir,
+      hparams,
+      summary_writer,
+      dev_infer_iterator_feed_dict=dev_infer_iterator_feed_dict,
+      test_infer_iterator_feed_dict=test_infer_iterator_feed_dict)

   metrics = {
       "dev_ppl": dev_ppl,
@@ -197,6 +307,40 @@ def run_full_eval(model_dir, infer_model, infer_sess, eval_model, eval_sess,
   return result_summary, global_step, metrics


+def run_full_eval(model_dir,
+                  infer_model,
+                  infer_sess,
+                  eval_model,
+                  eval_sess,
+                  hparams,
+                  summary_writer,
+                  sample_src_data,
+                  sample_tgt_data,
+                  avg_ckpts=False):
+  """Wrapper for running sample_decode, internal_eval and external_eval.
+
+  Args:
+    model_dir: Directory from which to load models.
+    infer_model: Inference model for which to compute external scores.
+    infer_sess: Inference TensorFlow session.
+    eval_model: Evaluation model for which to compute perplexities.
+    eval_sess: Evaluation TensorFlow session.
+    hparams: Model hyper-parameters.
+    summary_writer: Summary writer for logging metrics to TensorBoard.
+    sample_src_data: Sample of source data for sample decoding.
+    sample_tgt_data: Sample of target data for sample decoding.
+    avg_ckpts: Whether to compute average external evaluation scores.
+  Returns:
+    Triple containing the results summary, the global step TensorFlow
+    Variable and the metrics, in that order.
+  """
+  run_sample_decode(infer_model, infer_sess, model_dir, hparams, summary_writer,
+                    sample_src_data, sample_tgt_data)
+  return run_internal_and_external_eval(model_dir, infer_model, infer_sess,
+                                        eval_model, eval_sess, hparams,
+                                        summary_writer, avg_ckpts)
+
+
 def init_stats():
   """Initialize statistics that we want to accumulate."""
   return {"step_time": 0.0, "train_loss": 0.0,
0 commit comments