Skip to content

Commit eb046c0

Browse files
stefan-falkkpe
authored andcommitted
Word Error Rate for Speech Recognition (tensorflow#1242)
1 parent ed55fc5 commit eb046c0

File tree

3 files changed

+101
-1
lines changed

3 files changed

+101
-1
lines changed

tensor2tensor/data_generators/speech_recognition.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,7 @@ def preprocess_example(self, example, mode, hparams):
139139

140140
def eval_metrics(self):
141141
defaults = super(SpeechRecognitionProblem, self).eval_metrics()
142-
return defaults + [metrics.Metrics.EDIT_DISTANCE]
142+
return defaults + [
143+
metrics.Metrics.EDIT_DISTANCE,
144+
metrics.Metrics.WORD_ERROR_RATE
145+
]

tensor2tensor/utils/metrics.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class Metrics(object):
4747
ROUGE_2_F = "rouge_2_fscore"
4848
ROUGE_L_F = "rouge_L_fscore"
4949
EDIT_DISTANCE = "edit_distance"
50+
WORD_ERROR_RATE = "word_error_rate"
5051
SET_PRECISION = "set_precision"
5152
SET_RECALL = "set_recall"
5253
SOFTMAX_CROSS_ENTROPY_ONE_HOT = "softmax_cross_entropy_one_hot"
@@ -680,6 +681,68 @@ def metric_means():
680681
return metric_accum, metric_means
681682

682683

684+
def word_error_rate(raw_predictions, labels, lookup=None,
685+
weights_fn=common_layers.weights_nonzero):
686+
"""
687+
:param raw_predictions:
688+
:param labels:
689+
:param lookup:
690+
A tf.constant mapping indices to output tokens.
691+
:param weights_fn:
692+
:return:
693+
The word error rate.
694+
"""
695+
696+
def from_tokens(raw, lookup_):
697+
gathered = tf.gather(lookup_, tf.cast(raw, tf.int32))
698+
joined = tf.regex_replace(tf.reduce_join(gathered, axis=1), b'<EOS>.*', b'')
699+
cleaned = tf.regex_replace(joined, b'_', b' ')
700+
tokens = tf.string_split(cleaned, ' ')
701+
return tokens
702+
703+
def from_characters(raw, lookup_):
704+
"""
705+
Convert ascii+2 encoded codes to string-tokens.
706+
"""
707+
corrected = tf.bitcast(
708+
tf.clip_by_value(
709+
tf.subtract(raw, 2), 0, 255
710+
), tf.uint8)
711+
712+
gathered = tf.gather(lookup_, tf.cast(corrected, tf.int32))[:, :, 0]
713+
joined = tf.reduce_join(gathered, axis=1)
714+
cleaned = tf.regex_replace(joined, b'\0', b'')
715+
tokens = tf.string_split(cleaned, ' ')
716+
return tokens
717+
718+
if lookup is None:
719+
lookup = tf.constant([chr(i) for i in range(256)])
720+
convert_fn = from_characters
721+
else:
722+
convert_fn = from_tokens
723+
724+
if weights_fn is not common_layers.weights_nonzero:
725+
raise ValueError("Only weights_nonzero can be used for this metric.")
726+
727+
with tf.variable_scope("word_error_rate", values=[raw_predictions, labels]):
728+
729+
raw_predictions = tf.squeeze(
730+
tf.argmax(raw_predictions, axis=-1), axis=(2, 3))
731+
labels = tf.squeeze(labels, axis=(2, 3))
732+
733+
reference = convert_fn(labels, lookup)
734+
predictions = convert_fn(raw_predictions, lookup)
735+
736+
distance = tf.reduce_sum(
737+
tf.edit_distance(predictions, reference, normalize=False)
738+
)
739+
reference_length = tf.cast(
740+
tf.size(reference.values, out_type=tf.int32), dtype=tf.float32
741+
)
742+
743+
return distance / reference_length, reference_length
744+
745+
683746
# Metrics are functions that take predictions and labels and return
684747
# a tensor of metrics and a tensor of weights.
685748
# If the function has "features" as an argument, it will receive the whole
@@ -699,6 +762,7 @@ def metric_means():
699762
Metrics.ROUGE_2_F: rouge.rouge_2_fscore,
700763
Metrics.ROUGE_L_F: rouge.rouge_l_fscore,
701764
Metrics.EDIT_DISTANCE: sequence_edit_distance,
765+
Metrics.WORD_ERROR_RATE: word_error_rate,
702766
Metrics.SOFTMAX_CROSS_ENTROPY_ONE_HOT: softmax_cross_entropy_one_hot,
703767
Metrics.SIGMOID_ACCURACY_ONE_HOT: sigmoid_accuracy_one_hot,
704768
Metrics.SIGMOID_RECALL_ONE_HOT: sigmoid_recall_one_hot,

tensor2tensor/utils/metrics_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,39 @@ def testSequenceEditDistanceMetric(self):
105105
self.assertAlmostEqual(actual_scores, 3.0 / 13)
106106
self.assertEqual(actual_weight, 13)
107107

108+
def testWordErrorRateMetric(self):
109+
110+
ref = np.asarray([
111+
# a b c
112+
[97, 34, 98, 34, 99],
113+
[97, 34, 98, 34, 99],
114+
[97, 34, 98, 34, 99],
115+
[97, 34, 98, 34, 99],
116+
])
117+
118+
hyp = np.asarray([
119+
[97, 34, 98, 34, 99], # a b c
120+
[97, 34, 98, 0, 0], # a b
121+
[97, 34, 98, 34, 100], # a b d
122+
[0, 0, 0, 0, 0] # empty
123+
])
124+
125+
labels = np.reshape(ref, ref.shape + (1, 1))
126+
predictions = np.zeros((len(ref), np.max([len(s) for s in hyp]), 1, 1, 256))
127+
128+
for i, sample in enumerate(hyp):
129+
for j, idx in enumerate(sample):
130+
predictions[i, j, 0, 0, idx] = 1
131+
132+
with self.test_session() as session:
133+
actual_wer, actual_ref_len = session.run(
134+
metrics.word_error_rate(predictions, labels)
135+
)
136+
137+
expected_wer = 0.417
138+
places = 3
139+
self.assertAlmostEqual(round(actual_wer, places), expected_wer, places)
140+
108141
def testNegativeLogPerplexity(self):
109142
predictions = np.random.randint(4, size=(12, 12, 12, 1))
110143
targets = np.random.randint(4, size=(12, 12, 12, 1))

0 commit comments

Comments
 (0)