
Commit 504b15c

Yibing Liu committed

add ctc beam search decoder

1 parent 367e123 commit 504b15c

File tree

1 file changed

+152 -0 lines changed
@@ -0,0 +1,152 @@
## This is a prototype of a CTC beam search decoder

import copy
import random

import numpy as np

# vocab = English characters + blank + space
# vocab = ['-', ' '] + [chr(i) for i in range(97, 123)]

# reduced test vocab: 'a'..'d', plus space and the blank symbol '-'
vocab = [chr(97), chr(98), chr(99), chr(100)] + [' ', '-']


def ids_str2list(ids_str):
    # convert a space-separated id string (e.g. "-1 2 0") to a list of ints
    return [int(elem) for elem in ids_str.split(' ')]


def ids_list2str(ids_list):
    # convert a list of ints to a space-separated id string
    return ' '.join(str(elem) for elem in ids_list)


def ctc_beam_search_decoder(input_probs_matrix,
                            beam_size,
                            lang_model=None,
                            name=None,
                            alpha=1.0,
                            beta=1.0,
                            blank_id=0,
                            space_id=1,
                            num_results_per_sample=None):
    '''
    Beam search decoder for a CTC-trained network, called outside of the
    recurrent group; adapted from Algorithm 1 in
    https://arxiv.org/abs/1408.2873.
    (name and beta are accepted but not used yet in this prototype.)
    '''
    if num_results_per_sample is None:
        num_results_per_sample = beam_size
    assert num_results_per_sample <= beam_size

    max_time_steps = len(input_probs_matrix)
    assert max_time_steps > 0

    vocab_dim = len(input_probs_matrix[0])
    assert blank_id < vocab_dim
    assert space_id < vocab_dim

    # initialize: every prefix string starts with the dummy id -1;
    # probs_b / probs_nb hold the probability of each prefix ending
    # in blank / non-blank, respectively
    start_id = -1
    prefix_set_prev = {str(start_id): 1.0}
    probs_b, probs_nb = {str(start_id): 1.0}, {str(start_id): 0.0}

    # extend prefixes step by step over time
    for time_step in range(max_time_steps):
        print("\ntime_step = %d" % (time_step + 1))
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        for (l, prob_) in prefix_set_prev.items():
            print("l = %s\t%f" % (l, prob_))
            prob = input_probs_matrix[time_step]

            # convert the prefix string of ids back to a list
            ids_list = ids_str2list(l)
            end_id = ids_list[-1]
            if l not in probs_b_cur:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend the prefix with every symbol in the vocabulary
            for c in range(0, vocab_dim):
                if c == blank_id:
                    # blank: the prefix is unchanged and now ends in blank
                    probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
                else:
                    l_plus = l + ' ' + str(c)
                    if l_plus not in probs_b_cur:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if c == end_id:
                        # repeated symbol: it only extends the prefix across
                        # a blank, otherwise it merges into the same prefix
                        probs_nb_cur[l_plus] += prob[c] * probs_b[l]
                        probs_nb_cur[l] += prob[c] * probs_nb[l]
                    elif c == space_id:
                        # word boundary: apply the optional language model
                        lm = 1.0 if lang_model is None \
                            else np.power(lang_model(ids_list), alpha)
                        probs_nb_cur[l_plus] += lm * prob[c] * (probs_b[l] + probs_nb[l])
                    else:
                        probs_nb_cur[l_plus] += prob[c] * (probs_b[l] + probs_nb[l])
                    prefix_set_next[l_plus] = probs_nb_cur[l_plus] + probs_b_cur[l_plus]

                    print("l_plus: %s\t b=%f\tnb=%f\tP=%f" %
                          (l_plus, probs_b_cur[l_plus], probs_nb_cur[l_plus],
                           prefix_set_next[l_plus]))
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
            print("l: %s\t b=%f\tnb=%f\tP=%f" %
                  (l, probs_b_cur[l], probs_nb_cur[l], prefix_set_next[l]))
        probs_b, probs_nb = copy.deepcopy(probs_b_cur), copy.deepcopy(probs_nb_cur)

        # prune: keep only the beam_size most probable prefixes
        prefix_set_prev = sorted(prefix_set_next.items(),
                                 key=lambda entry: entry[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    # collect surviving prefixes, drop the -1 start id, report log probability
    beam_result = []
    for (seq, prob) in prefix_set_prev.items():
        if prob > 0.0:
            ids_list = ids_str2list(seq)
            log_prob = np.log(prob)
            beam_result.append([log_prob, ids_list[1:]])

    beam_result = sorted(beam_result, key=lambda entry: entry[0], reverse=True)
    if num_results_per_sample < beam_size:
        beam_result = beam_result[:num_results_per_sample]
    return beam_result


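# --- Added for illustration; not part of the original prototype. ---
# A brute-force cross-check of the decoder above, assuming no language model is
# used: enumerate every frame-wise path through the probability matrix, apply
# the standard CTC collapse (drop repeated adjacent symbols, then drop blanks),
# and sum path probabilities per transcript. With a beam wide enough to avoid
# any pruning, the decoder's top result should agree with the top transcript
# found here. Only feasible for tiny matrices: it costs O(vocab_dim ** num_steps).
import itertools

def brute_force_ctc_decode(probs_matrix, blank_id):
    totals = {}
    num_steps = len(probs_matrix)
    num_symbols = len(probs_matrix[0])
    for path in itertools.product(range(num_symbols), repeat=num_steps):
        # probability of this exact frame-wise path
        path_prob = 1.0
        for t, c in enumerate(path):
            path_prob *= probs_matrix[t][c]
        # CTC collapse: drop repeated adjacent symbols, then drop blanks
        collapsed, prev = [], None
        for c in path:
            if c != prev and c != blank_id:
                collapsed.append(c)
            prev = c
        key = tuple(collapsed)
        totals[key] = totals.get(key, 0.0) + path_prob
    # most probable transcripts first
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)

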
def language_model(ids_list):
    # TODO: replace this stub with a real scorer; for now it returns a random
    # positive score so the lang_model hook can be exercised end to end
    return random.uniform(0, 1)


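# Hypothetical example; not part of the original prototype. The simplest
# possible scorer with the interface ctc_beam_search_decoder expects from
# lang_model: it receives the decoded prefix as a list of ids (still carrying
# the -1 start marker) and must return a positive score, which the decoder
# raises to the power `alpha` at word boundaries (space_id).
def toy_unigram_lm(ids_list):
    num_decoded = max(len(ids_list) - 1, 1)  # ignore the -1 start marker
    # uniform per-symbol probability; a real model would look the prefix up in
    # an n-gram table or query an external language model
    return (1.0 / len(vocab)) ** num_decoded

# For example, with the prob_matrix defined in main() below:
#   ctc_beam_search_decoder(input_probs_matrix=prob_matrix, beam_size=2,
#                           blank_id=5, space_id=4, lang_model=toy_unigram_lm)

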
def ctc_net(features, size_vocab):
    # dummy stand-in for the acoustic model: ignores its input and returns a
    # uniform distribution over the vocabulary (a random one is left commented
    # out for experimentation)
    # prob = np.array([random.uniform(0, 1) for i in range(0, size_vocab)])
    prob = np.array([1.0 for i in range(0, size_vocab)])
    return prob / prob.sum()


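# Hypothetical glue code; not part of the original prototype. It builds a
# T x |vocab| probability matrix by calling the dummy ctc_net once per time
# step and feeds it to the decoder, just to show how the two pieces are meant
# to fit together.
def decode_with_dummy_net(num_time_steps=3, beam_size=2):
    probs_matrix = [ctc_net(None, len(vocab)) for _ in range(num_time_steps)]
    return ctc_beam_search_decoder(
        input_probs_matrix=probs_matrix,
        beam_size=beam_size,
        blank_id=len(vocab) - 1,  # '-' is the last entry of the test vocab
        space_id=len(vocab) - 2)  # ' ' is the second-to-last entry

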
def main():
    # a smaller 3-symbol probability matrix (currently unused by the demo below)
    input_probs_matrix = [[0.1, 0.3, 0.6],
                          [0.2, 0.1, 0.7],
                          [0.5, 0.2, 0.3]]

    # per-time-step probabilities over the 6-symbol test vocab ('a'..'d', ' ', '-')
    prob_matrix = [[0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
                   [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
                   [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
                   [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
                   [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
                   # Random entry added in at time=5
                   #[0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]
                  ]

    beam_result = ctc_beam_search_decoder(
        input_probs_matrix=prob_matrix,
        beam_size=2,
        blank_id=5,
    )

    def ids2str(ids_list):
        # map a list of vocab ids back to a readable string
        ids_str = ''
        for ids in ids_list:
            ids_str += vocab[ids]
        return ids_str

    print("\nbeam search output:")
    for result in beam_result:
        print(result[0], ids2str(result[1]))


if __name__ == '__main__':
    main()
