Commit d66a146

(1) Updating TF Hub classifier (2) Updating tokenizer to support emojis

Parent: 7c1a4bf
3 files changed: 46 additions, 5 deletions

run_classifier_with_tfhub.py

Lines changed: 44 additions & 4 deletions
@@ -73,14 +73,15 @@ def create_model(is_training, input_ids, input_mask, segment_ids, labels,
 
     logits = tf.matmul(output_layer, output_weights, transpose_b=True)
     logits = tf.nn.bias_add(logits, output_bias)
+    probabilities = tf.nn.softmax(logits, axis=-1)
     log_probs = tf.nn.log_softmax(logits, axis=-1)
 
     one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
 
     per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
     loss = tf.reduce_mean(per_example_loss)
 
-    return (loss, per_example_loss, logits)
+    return (loss, per_example_loss, logits, probabilities)
 
 
 def model_fn_builder(num_labels, learning_rate, num_train_steps,
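
The new probabilities tensor is simply the softmax of the logits, computed next to the log_softmax already used for the loss; it is what the new PREDICT path below hands back to callers. A minimal NumPy sketch (not from the repo) of the same computation for a single example:

import numpy as np

def softmax(logits):
  # Subtract the max before exponentiating for numerical stability;
  # tf.nn.softmax(logits, axis=-1) computes the same result per row.
  exp = np.exp(logits - np.max(logits))
  return exp / exp.sum()

logits = np.array([2.0, 0.5, -1.0])  # hypothetical per-class scores
print(softmax(logits))               # ~[0.786 0.175 0.039], sums to 1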
@@ -101,7 +102,7 @@ def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
 
     is_training = (mode == tf.estimator.ModeKeys.TRAIN)
 
-    (total_loss, per_example_loss, logits) = create_model(
+    (total_loss, per_example_loss, logits, probabilities) = create_model(
         is_training, input_ids, input_mask, segment_ids, label_ids, num_labels,
         bert_hub_module_handle)
 
@@ -130,8 +131,12 @@ def metric_fn(per_example_loss, label_ids, logits):
           mode=mode,
           loss=total_loss,
           eval_metrics=eval_metrics)
+    elif mode == tf.estimator.ModeKeys.PREDICT:
+      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
+          mode=mode, predictions={"probabilities": probabilities})
     else:
-      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
+      raise ValueError(
+          "Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode))
 
     return output_spec
 
@@ -215,7 +220,8 @@ def main(_):
       model_fn=model_fn,
       config=run_config,
       train_batch_size=FLAGS.train_batch_size,
-      eval_batch_size=FLAGS.eval_batch_size)
+      eval_batch_size=FLAGS.eval_batch_size,
+      predict_batch_size=FLAGS.predict_batch_size)
 
   if FLAGS.do_train:
     train_features = run_classifier.convert_examples_to_features(
@@ -265,6 +271,40 @@ def main(_):
         tf.logging.info("  %s = %s", key, str(result[key]))
         writer.write("%s = %s\n" % (key, str(result[key])))
 
+  if FLAGS.do_predict:
+    predict_examples = processor.get_test_examples(FLAGS.data_dir)
+    if FLAGS.use_tpu:
+      # Discard batch remainder if running on TPU
+      n = len(predict_examples)
+      predict_examples = predict_examples[:(n - n % FLAGS.predict_batch_size)]
+
+    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
+    run_classifier.file_based_convert_examples_to_features(
+        predict_examples, label_list, FLAGS.max_seq_length, tokenizer,
+        predict_file)
+
+    tf.logging.info("***** Running prediction*****")
+    tf.logging.info("  Num examples = %d", len(predict_examples))
+    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)
+
+    predict_input_fn = run_classifier.file_based_input_fn_builder(
+        input_file=predict_file,
+        seq_length=FLAGS.max_seq_length,
+        is_training=False,
+        drop_remainder=FLAGS.use_tpu)
+
+    result = estimator.predict(input_fn=predict_input_fn)
+
+    output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
+    with tf.gfile.GFile(output_predict_file, "w") as writer:
+      tf.logging.info("***** Predict results *****")
+      for prediction in result:
+        probabilities = prediction["probabilities"]
+        output_line = "\t".join(
+            str(class_probability)
+            for class_probability in probabilities) + "\n"
+        writer.write(output_line)
+
 
 if __name__ == "__main__":
   flags.mark_flag_as_required("data_dir")
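
Two details of the new do_predict block are worth spelling out. TPUEstimator requires fixed batch shapes, so drop_remainder=FLAGS.use_tpu silently drops any final partial batch on TPU; trimming predict_examples to a multiple of predict_batch_size up front keeps the example count and the rows written to test_results.tsv in sync, and is also why the TPUEstimator constructor now receives predict_batch_size. A quick check of the trimming arithmetic, with hypothetical numbers:

predict_batch_size = 8  # hypothetical flag value
n = 30                  # hypothetical number of test examples
kept = n - n % predict_batch_size
print(kept)             # 24 -- the last 6 examples are discarded on TPU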

tokenization.py

Lines changed: 1 addition & 1 deletion
@@ -378,7 +378,7 @@ def _is_control(char):
   if char == "\t" or char == "\n" or char == "\r":
     return False
   cat = unicodedata.category(char)
-  if cat.startswith("C"):
+  if cat in ("Cc", "Cf"):
     return True
   return False
 
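
This one-line change is the emoji fix. The old check treated every Unicode general category starting with "C" as a control character, which also matched "Cs" (surrogates) and "Co" (private use); on a narrow Python 2 build an astral code point such as an emoji is stored as a surrogate pair, so it was presumably being stripped by _clean_text. Restricting the test to "Cc" (control) and "Cf" (format) leaves emoji, which are category "So", untouched. A small probe of the relevant categories (standard unicodedata output; runnable as-is on Python 3):

import unicodedata

for ch in [u"\x00", u"\u200d", u"\U0001F4A9"]:
  cat = unicodedata.category(ch)
  # Old check: cat.startswith("C"); new check: cat in ("Cc", "Cf").
  print(repr(ch), cat, cat in ("Cc", "Cf"))

# \x00       -> Cc True  (real control character, still stripped)
# \u200d     -> Cf True  (zero-width joiner, still stripped)
# \U0001F4A9 -> So False (emoji now survives _clean_text)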

tokenization_test.py

Lines changed: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ def test_is_control(self):
     self.assertFalse(tokenization._is_control(u" "))
     self.assertFalse(tokenization._is_control(u"\t"))
     self.assertFalse(tokenization._is_control(u"\r"))
+    self.assertFalse(tokenization._is_control(u"\U0001F4A9"))
 
   def test_is_punctuation(self):
     self.assertTrue(tokenization._is_punctuation(u"-"))
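
A hedged end-to-end check (not part of the commit): with the Cc/Cf test in place, BasicTokenizer's _clean_text no longer removes the emoji, so it should come through as its own token, since emoji (category "So") are not punctuation either. Whether it then maps to a real wordpiece or [UNK] depends on the vocabulary.

import tokenization  # assumes the BERT repo is on the path

tokenizer = tokenization.BasicTokenizer(do_lower_case=True)
print(tokenizer.tokenize(u"hello \U0001F4A9"))
# expected: [u'hello', u'\U0001F4A9']; before the fix, a narrow Python 2
# build dropped the emoji and returned just [u'hello'].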
