@@ -73,14 +73,15 @@ def create_model(is_training, input_ids, input_mask, segment_ids, labels,
73
73
74
74
logits = tf .matmul (output_layer , output_weights , transpose_b = True )
75
75
logits = tf .nn .bias_add (logits , output_bias )
76
+ probabilities = tf .nn .softmax (logits , axis = - 1 )
76
77
log_probs = tf .nn .log_softmax (logits , axis = - 1 )
77
78
78
79
one_hot_labels = tf .one_hot (labels , depth = num_labels , dtype = tf .float32 )
79
80
80
81
per_example_loss = - tf .reduce_sum (one_hot_labels * log_probs , axis = - 1 )
81
82
loss = tf .reduce_mean (per_example_loss )
82
83
83
- return (loss , per_example_loss , logits )
84
+ return (loss , per_example_loss , logits , probabilities )
84
85
85
86
86
87
def model_fn_builder (num_labels , learning_rate , num_train_steps ,
@@ -101,7 +102,7 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
101
102
102
103
is_training = (mode == tf .estimator .ModeKeys .TRAIN )
103
104
104
- (total_loss , per_example_loss , logits ) = create_model (
105
+ (total_loss , per_example_loss , logits , probabilities ) = create_model (
105
106
is_training , input_ids , input_mask , segment_ids , label_ids , num_labels ,
106
107
bert_hub_module_handle )
107
108
@@ -130,8 +131,12 @@ def metric_fn(per_example_loss, label_ids, logits):
130
131
mode = mode ,
131
132
loss = total_loss ,
132
133
eval_metrics = eval_metrics )
134
+ elif mode == tf .estimator .ModeKeys .PREDICT :
135
+ output_spec = tf .contrib .tpu .TPUEstimatorSpec (
136
+ mode = mode , predictions = {"probabilities" : probabilities })
133
137
else :
134
- raise ValueError ("Only TRAIN and EVAL modes are supported: %s" % (mode ))
138
+ raise ValueError (
139
+ "Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode ))
135
140
136
141
return output_spec
137
142
@@ -215,7 +220,8 @@ def main(_):
215
220
model_fn = model_fn ,
216
221
config = run_config ,
217
222
train_batch_size = FLAGS .train_batch_size ,
218
- eval_batch_size = FLAGS .eval_batch_size )
223
+ eval_batch_size = FLAGS .eval_batch_size ,
224
+ predict_batch_size = FLAGS .predict_batch_size )
219
225
220
226
if FLAGS .do_train :
221
227
train_features = run_classifier .convert_examples_to_features (
@@ -265,6 +271,40 @@ def main(_):
265
271
tf .logging .info (" %s = %s" , key , str (result [key ]))
266
272
writer .write ("%s = %s\n " % (key , str (result [key ])))
267
273
274
+ if FLAGS .do_predict :
275
+ predict_examples = processor .get_test_examples (FLAGS .data_dir )
276
+ if FLAGS .use_tpu :
277
+ # Discard batch remainder if running on TPU
278
+ n = len (predict_examples )
279
+ predict_examples = predict_examples [:(n - n % FLAGS .predict_batch_size )]
280
+
281
+ predict_file = os .path .join (FLAGS .output_dir , "predict.tf_record" )
282
+ run_classifier .file_based_convert_examples_to_features (
283
+ predict_examples , label_list , FLAGS .max_seq_length , tokenizer ,
284
+ predict_file )
285
+
286
+ tf .logging .info ("***** Running prediction*****" )
287
+ tf .logging .info (" Num examples = %d" , len (predict_examples ))
288
+ tf .logging .info (" Batch size = %d" , FLAGS .predict_batch_size )
289
+
290
+ predict_input_fn = run_classifier .file_based_input_fn_builder (
291
+ input_file = predict_file ,
292
+ seq_length = FLAGS .max_seq_length ,
293
+ is_training = False ,
294
+ drop_remainder = FLAGS .use_tpu )
295
+
296
+ result = estimator .predict (input_fn = predict_input_fn )
297
+
298
+ output_predict_file = os .path .join (FLAGS .output_dir , "test_results.tsv" )
299
+ with tf .gfile .GFile (output_predict_file , "w" ) as writer :
300
+ tf .logging .info ("***** Predict results *****" )
301
+ for prediction in result :
302
+ probabilities = prediction ["probabilities" ]
303
+ output_line = "\t " .join (
304
+ str (class_probability )
305
+ for class_probability in probabilities ) + "\n "
306
+ writer .write (output_line )
307
+
268
308
269
309
if __name__ == "__main__" :
270
310
flags .mark_flag_as_required ("data_dir" )
0 commit comments