This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Word Error Rate for Speech Recognition #1242

Merged 1 commit on Nov 28, 2018.
5 changes: 4 additions & 1 deletion tensor2tensor/data_generators/speech_recognition.py
@@ -139,4 +139,7 @@ def preprocess_example(self, example, mode, hparams):

   def eval_metrics(self):
     defaults = super(SpeechRecognitionProblem, self).eval_metrics()
-    return defaults + [metrics.Metrics.EDIT_DISTANCE]
+    return defaults + [
+        metrics.Metrics.EDIT_DISTANCE,
+        metrics.Metrics.WORD_ERROR_RATE
+    ]
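
With this change, any problem built on SpeechRecognitionProblem reports word error rate next to edit distance at evaluation time. A minimal sketch of how that surfaces (the subclass name is made up; it assumes a working tensor2tensor installation):

from tensor2tensor.data_generators import speech_recognition
from tensor2tensor.utils import metrics


class MyAsrProblem(speech_recognition.SpeechRecognitionProblem):
  """Hypothetical subclass; it inherits the extended metric list as-is."""


problem = MyAsrProblem()
# The list returned by eval_metrics() should now include both
# Metrics.EDIT_DISTANCE and Metrics.WORD_ERROR_RATE on top of the defaults.
print(metrics.Metrics.WORD_ERROR_RATE in problem.eval_metrics())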
64 changes: 64 additions & 0 deletions tensor2tensor/utils/metrics.py
@@ -47,6 +47,7 @@ class Metrics(object):
   ROUGE_2_F = "rouge_2_fscore"
   ROUGE_L_F = "rouge_L_fscore"
   EDIT_DISTANCE = "edit_distance"
+  WORD_ERROR_RATE = "word_error_rate"
   SET_PRECISION = "set_precision"
   SET_RECALL = "set_recall"
   SOFTMAX_CROSS_ENTROPY_ONE_HOT = "softmax_cross_entropy_one_hot"
@@ -669,6 +670,68 @@ def metric_means():
   return metric_accum, metric_means


+def word_error_rate(raw_predictions, labels, lookup=None,
+                    weights_fn=common_layers.weights_nonzero):
+  """Calculate the word error rate between raw predictions and labels.
+  :param raw_predictions: The raw model predictions (scores over the vocabulary).
+  :param labels: The reference label ids.
+  :param lookup:
+      A tf.constant mapping indices to output tokens.
+  :param weights_fn: Unused; only common_layers.weights_nonzero is accepted.
+  :return:
+      The word error rate and the total reference length.
+  """
+
+  def from_tokens(raw, lookup_):
+    gathered = tf.gather(lookup_, tf.cast(raw, tf.int32))
+    joined = tf.regex_replace(tf.reduce_join(gathered, axis=1), b'<EOS>.*', b'')
+    cleaned = tf.regex_replace(joined, b'_', b' ')
+    tokens = tf.string_split(cleaned, ' ')
+    return tokens
+
+  def from_characters(raw, lookup_):
+    """
+    Convert ascii+2 encoded codes to string-tokens.
+    """
+    corrected = tf.bitcast(
+        tf.clip_by_value(
+            tf.subtract(raw, 2), 0, 255
+        ), tf.uint8)
+
+    gathered = tf.gather(lookup_, tf.cast(corrected, tf.int32))[:, :, 0]
+    joined = tf.reduce_join(gathered, axis=1)
+    cleaned = tf.regex_replace(joined, b'\0', b'')
+    tokens = tf.string_split(cleaned, ' ')
+    return tokens
+
+  if lookup is None:
+    lookup = tf.constant([chr(i) for i in range(256)])
+    convert_fn = from_characters
+  else:
+    convert_fn = from_tokens
+
+  if weights_fn is not common_layers.weights_nonzero:
+    raise ValueError("Only weights_nonzero can be used for this metric.")
+
+  with tf.variable_scope("word_error_rate", values=[raw_predictions, labels]):
+
+    raw_predictions = tf.squeeze(
+        tf.argmax(raw_predictions, axis=-1), axis=(2, 3))
+    labels = tf.squeeze(labels, axis=(2, 3))
+
+    reference = convert_fn(labels, lookup)
+    predictions = convert_fn(raw_predictions, lookup)
+
+    distance = tf.reduce_sum(
+        tf.edit_distance(predictions, reference, normalize=False)
+    )
+    reference_length = tf.cast(
+        tf.size(reference.values, out_type=tf.int32), dtype=tf.float32
+    )
+
+    return distance / reference_length, reference_length
+
+
 # Metrics are functions that take predictions and labels and return
 # a tensor of metrics and a tensor of weights.
 # If the function has "features" as an argument, it will receive the whole
@@ -688,6 +751,7 @@ def metric_means():
     Metrics.ROUGE_2_F: rouge.rouge_2_fscore,
     Metrics.ROUGE_L_F: rouge.rouge_l_fscore,
     Metrics.EDIT_DISTANCE: sequence_edit_distance,
+    Metrics.WORD_ERROR_RATE: word_error_rate,
     Metrics.SOFTMAX_CROSS_ENTROPY_ONE_HOT: softmax_cross_entropy_one_hot,
     Metrics.SIGMOID_ACCURACY_ONE_HOT: sigmoid_accuracy_one_hot,
     Metrics.SIGMOID_RECALL_ONE_HOT: sigmoid_recall_one_hot,
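For reference, the metric above is the summed word-level edit distance over the batch divided by the total number of reference words. A minimal pure-Python sketch of that same definition, independent of TensorFlow (the helper names are made up and not part of this change):

def _edit_distance(ref, hyp):
  """Word-level Levenshtein distance between two lists of words."""
  d = list(range(len(hyp) + 1))
  for i, r in enumerate(ref, 1):
    prev, d[0] = d[0], i
    for j, h in enumerate(hyp, 1):
      prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (r != h))
  return d[len(hyp)]


def naive_wer(references, hypotheses):
  """Summed word edit distance divided by the total reference word count."""
  errors = sum(_edit_distance(r.split(), h.split())
               for r, h in zip(references, hypotheses))
  ref_words = sum(len(r.split()) for r in references)
  return float(errors) / ref_words


# The same batch as the unit test below: 0 + 1 + 1 + 3 = 5 errors over
# 4 * 3 = 12 reference words, i.e. about 0.417.
print(naive_wer(["a b c"] * 4, ["a b c", "a b", "a b d", ""]))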
33 changes: 33 additions & 0 deletions tensor2tensor/utils/metrics_test.py
@@ -105,6 +105,39 @@ def testSequenceEditDistanceMetric(self):
     self.assertAlmostEqual(actual_scores, 3.0 / 13)
     self.assertEqual(actual_weight, 13)
 
+  def testWordErrorRateMetric(self):
+
+    ref = np.asarray([
+        # a b c
+        [97, 34, 98, 34, 99],
+        [97, 34, 98, 34, 99],
+        [97, 34, 98, 34, 99],
+        [97, 34, 98, 34, 99],
+    ])
+
+    hyp = np.asarray([
+        [97, 34, 98, 34, 99],   # a b c
+        [97, 34, 98, 0, 0],     # a b
+        [97, 34, 98, 34, 100],  # a b d
+        [0, 0, 0, 0, 0]         # empty
+    ])
+
+    labels = np.reshape(ref, ref.shape + (1, 1))
+    predictions = np.zeros((len(ref), np.max([len(s) for s in hyp]), 1, 1, 256))
+
+    for i, sample in enumerate(hyp):
+      for j, idx in enumerate(sample):
+        predictions[i, j, 0, 0, idx] = 1
+
+    with self.test_session() as session:
+      actual_wer, actual_ref_len = session.run(
+          metrics.word_error_rate(predictions, labels)
+      )
+
+    expected_wer = 0.417  # (0 + 1 + 1 + 3) errors over 12 reference words.
+    places = 3
+    self.assertAlmostEqual(round(actual_wer, places), expected_wer, places)
+
   def testNegativeLogPerplexity(self):
     predictions = np.random.randint(4, size=(12, 12, 12, 1))
     targets = np.random.randint(4, size=(12, 12, 12, 1))