@@ -47,6 +47,7 @@ class Metrics(object):
47
47
ROUGE_2_F = "rouge_2_fscore"
48
48
ROUGE_L_F = "rouge_L_fscore"
49
49
EDIT_DISTANCE = "edit_distance"
50
+ WORD_ERROR_RATE = "word_error_rate"
50
51
SET_PRECISION = "set_precision"
51
52
SET_RECALL = "set_recall"
52
53
SOFTMAX_CROSS_ENTROPY_ONE_HOT = "softmax_cross_entropy_one_hot"
@@ -680,6 +681,68 @@ def metric_means():
680
681
return metric_accum , metric_means
681
682
682
683
684
def word_error_rate(raw_predictions, labels, lookup=None,
                    weights_fn=common_layers.weights_nonzero):
    """Compute the word error rate (WER) over a batch.

    Predictions and labels are both turned into strings, split on spaces
    into words, and compared with a word-level edit distance. The summed,
    un-normalized distance of the batch is divided by the total number of
    reference words.

    Args:
      raw_predictions: Prediction scores. The argmax over the last axis is
        used as the predicted id, after which axes 2 and 3 are squeezed away
        (presumably shaped [batch, length, 1, 1, vocab] -- confirm with
        caller).
      labels: Reference ids; axes 2 and 3 are squeezed away (presumably
        shaped [batch, length, 1, 1] -- confirm with caller).
      lookup: Optional tf.constant mapping indices to output tokens. When
        None, ids are interpreted as character codes offset by 2.
      weights_fn: Must be common_layers.weights_nonzero; it is only checked,
        never called.

    Returns:
      A pair (word_error_rate, reference_length), where reference_length is
      the float32 count of reference words in the batch and serves as the
      weight of this batch's score.

    Raises:
      ValueError: If weights_fn is not common_layers.weights_nonzero.
    """
    # Validate first so that the error path does not add ops to the graph.
    if weights_fn is not common_layers.weights_nonzero:
        raise ValueError("Only weights_nonzero can be used for this metric.")

    def from_tokens(raw, lookup_):
        """Map token ids to strings and split each sequence into words."""
        gathered = tf.gather(lookup_, tf.cast(raw, tf.int32))
        # Drop everything from the first end-of-sequence marker onwards.
        joined = tf.regex_replace(
            tf.reduce_join(gathered, axis=1), b'<EOS>.*', b'')
        # Underscores in the joined token text stand for word boundaries.
        cleaned = tf.regex_replace(joined, b'_', b' ')
        tokens = tf.string_split(cleaned, ' ')
        return tokens

    def from_characters(raw, lookup_):
        """Convert ascii+2 encoded codes to string-tokens."""
        # Undo the +2 offset and clamp to a single byte's range; ids below
        # the offset collapse to 0 and are stripped as NUL further down.
        # Bitcasting to uint8 appends a trailing axis holding the individual
        # bytes of every id.
        corrected = tf.bitcast(
            tf.clip_by_value(tf.subtract(raw, 2), 0, 255), tf.uint8)

        # Keep byte 0 only (the low-order byte on little-endian hosts).
        gathered = tf.gather(lookup_, tf.cast(corrected, tf.int32))[:, :, 0]
        joined = tf.reduce_join(gathered, axis=1)
        cleaned = tf.regex_replace(joined, b'\0', b'')
        tokens = tf.string_split(cleaned, ' ')
        return tokens

    if lookup is None:
        # NOTE(review): on Python 3, chr(i) for i >= 128 is a non-ASCII str
        # and is presumably stored UTF-8 encoded (multi-byte) -- confirm
        # whether that matters for the data this metric is used with.
        lookup = tf.constant([chr(i) for i in range(256)])
        convert_fn = from_characters
    else:
        convert_fn = from_tokens

    with tf.variable_scope(
            "word_error_rate", values=[raw_predictions, labels]):
        raw_predictions = tf.squeeze(
            tf.argmax(raw_predictions, axis=-1), axis=(2, 3))
        labels = tf.squeeze(labels, axis=(2, 3))

        reference = convert_fn(labels, lookup)
        predictions = convert_fn(raw_predictions, lookup)

        distance = tf.reduce_sum(
            tf.edit_distance(predictions, reference, normalize=False))
        reference_length = tf.cast(
            tf.size(reference.values, out_type=tf.int32), dtype=tf.float32)

        # A batch without any reference word would otherwise divide by zero
        # and return NaN/inf next to a weight of 0, which can poison a
        # weighted running mean. Clamping the divisor keeps the score finite;
        # the returned weight is still the true (possibly zero) word count.
        safe_length = tf.maximum(reference_length, 1.0)

        return distance / safe_length, reference_length
744
+
745
+
683
746
# Metrics are functions that take predictions and labels and return
684
747
# a tensor of metrics and a tensor of weights.
685
748
# If the function has "features" as an argument, it will receive the whole
@@ -699,6 +762,7 @@ def metric_means():
699
762
Metrics .ROUGE_2_F : rouge .rouge_2_fscore ,
700
763
Metrics .ROUGE_L_F : rouge .rouge_l_fscore ,
701
764
Metrics .EDIT_DISTANCE : sequence_edit_distance ,
765
+ Metrics .WORD_ERROR_RATE : word_error_rate ,
702
766
Metrics .SOFTMAX_CROSS_ENTROPY_ONE_HOT : softmax_cross_entropy_one_hot ,
703
767
Metrics .SIGMOID_ACCURACY_ONE_HOT : sigmoid_accuracy_one_hot ,
704
768
Metrics .SIGMOID_RECALL_ONE_HOT : sigmoid_recall_one_hot ,
0 commit comments