-
Notifications
You must be signed in to change notification settings - Fork 3
/
test_ctc_beam_search_decoder.py
74 lines (62 loc) · 2.98 KB
/
test_ctc_beam_search_decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from ctc_beam_search_decoder import *
import time
# Small test vocabulary: apostrophe, space, then the letters 'a' through 'd'.
# To cover the whole lowercase alphabet, extend the upper bound to ord('z') + 1.
vocab_list = ["'", " "] + [chr(code) for code in range(ord('a'), ord('d') + 1)]
def generate_probs(num_time_steps, probs_dim):
    """Build random per-time-step probability distributions.

    Args:
        num_time_steps: Number of rows (time steps) to generate.
        probs_dim: Number of entries in each row.

    Returns:
        A list of ``num_time_steps`` 1-D numpy arrays, each of length
        ``probs_dim``, non-negative and summing to 1 (up to floating-point
        rounding).
    """
    raw = np.random.random(size=(num_time_steps, probs_dim))
    # Normalise all rows in one vectorised step; keepdims=True keeps the row
    # sums as a column so the division broadcasts across each row.
    normalised = raw / raw.sum(axis=1, keepdims=True)
    # Hand back a list of rows, the same container type callers got before.
    return list(normalised)
def test_beam_search_decoder():
    """Decode one random probability matrix with three CTC beam search decoders.

    Feeds the same randomly generated input to TensorFlow's
    ``tf.nn.ctc_beam_search_decoder`` and to the ``ctc_beam_search_decoder``
    and ``ctc_beam_search_decoder_log`` functions, then prints each decoder's
    score and decoded output side by side. Nothing is asserted; the printed
    rows are meant to be compared by eye.
    """
    max_time_steps = 6
    # One class more than the vocabulary; the other decoders below are told
    # that the blank label is index len(vocab_list), i.e. the last class.
    probs_dim = len(vocab_list) + 1
    beam_size = 20
    num_results_per_sample = 1

    input_prob_matrix_0 = np.asarray(
        generate_probs(max_time_steps, probs_dim), dtype=np.float32)
    print(input_prob_matrix_0)

    # The TensorFlow decoder receives the element-wise log of the
    # probabilities (no constant offset is added).
    input_log_prob_matrix_0 = np.log(input_prob_matrix_0)

    # Insert a batch axis of size 1: (max_time_steps, 1, probs_dim).
    inputs_t = ops.convert_to_tensor(input_log_prob_matrix_0[:, np.newaxis, :])

    # run CTC beam search decoder in tensorflow
    with tf.Session() as sess:
        decoded, log_probabilities = tf.nn.ctc_beam_search_decoder(
            inputs_t,
            [max_time_steps],
            beam_width=beam_size,
            top_paths=num_results_per_sample,
            merge_repeated=False)
        # Fetch both outputs in a single run so the decoder executes once.
        tf_decoded, tf_log_probs = sess.run([decoded, log_probabilities])

    # Both pure-Python decoders take exactly the same arguments.
    decoder_kwargs = {
        'probs_seq': input_prob_matrix_0,
        'beam_size': beam_size,
        'vocabulary': vocab_list,
        'blank_id': len(vocab_list),
        'cutoff_prob': 1.0,
    }
    # run original CTC beam search decoder
    beam_result = ctc_beam_search_decoder(**decoder_kwargs)
    # run log- CTC beam search decoder
    beam_result_log = ctc_beam_search_decoder_log(**decoder_kwargs)

    # compare decoding result
    print("{tf-decoder log probs} \t {org-decoder log probs} \t{log-decoder log probs}: {tf_decoder result} {org_decoder result} {log-decoder result}")
    for index in range(num_results_per_sample):
        # Map the label ids of the index-th TensorFlow path back to characters.
        tf_result = ''.join([vocab_list[i] for i in tf_decoded[index].values])
        print(('%6f\t%f\t%f: ') % (tf_log_probs[0][index], beam_result[index][0], beam_result_log[index][0]),
              tf_result, '\t', beam_result[index][1], '\t', beam_result_log[index][1])
# Run the decoder comparison only when this file is executed as a script.
if __name__ == "__main__":
    test_beam_search_decoder()