|
| 1 | +## This is a prototype of ctc beam search decoder |
| 2 | + |
| 3 | +import copy |
| 4 | +import random |
| 5 | +import numpy as np |
| 6 | + |
# Toy vocabulary for the prototype: four letters plus space and the blank.
# A full English setup would be: ['-', ' '] + [chr(i) for i in range(97, 123)]
vocab = ['a', 'b', 'c', 'd', ' ', '-']
| 11 | + |
def ids_str2list(ids_str):
    '''Decode a space-joined id string (e.g. "0 2 1") into a list of ints.'''
    return [int(token) for token in ids_str.split(' ')]
| 16 | + |
def ids_list2str(ids_list):
    '''Encode a list of ids (e.g. [0, 2, 1]) as a space-joined string.'''
    return ' '.join(str(elem) for elem in ids_list)
| 21 | + |
def ctc_beam_search_decoder(
        input_probs_matrix,
        beam_size,
        lang_model=None,
        name=None,
        alpha=1.0,
        beta=1.0,
        blank_id=0,
        space_id=1,
        num_results_per_sample=None):
    '''
    Beam search decoder for a CTC-trained network, called outside of the
    recurrent group. Adapted from Algorithm 1 in
    https://arxiv.org/abs/1408.2873.

    :param input_probs_matrix: probability matrix of shape
        (time_steps, vocab_dim); each row is a distribution over the vocab.
    :param beam_size: number of prefixes kept alive per time step.
    :param lang_model: optional callable scoring an id-list prefix; its score
        (raised to the power alpha) weighs extensions by space_id.
    :param name: unused; kept for interface compatibility.
    :param alpha: language model weight (exponent on the LM score).
    :param beta: unused; kept for interface compatibility.
    :param blank_id: index of the CTC blank symbol in the vocab.
    :param space_id: index of the space symbol in the vocab.
    :param num_results_per_sample: number of results to return; defaults to
        beam_size and must not exceed it.
    :return: list of [log_prob, ids_list] pairs, sorted by descending
        log probability.
    '''
    if num_results_per_sample is None:
        num_results_per_sample = beam_size
    assert num_results_per_sample <= beam_size

    max_time_steps = len(input_probs_matrix)
    assert max_time_steps > 0

    vocab_dim = len(input_probs_matrix[0])
    assert blank_id < vocab_dim
    assert space_id < vocab_dim

    def _ids_str2list(ids_str):
        # Prefixes are tracked as space-joined id strings (hashable dict
        # keys); decode one locally so this block is self-contained.
        return [int(elem) for elem in ids_str.split(' ')]

    # initialize with a sentinel prefix (id -1, stripped from the results);
    # probs_b / probs_nb hold, per prefix, the probability of ending in a
    # blank / non-blank symbol respectively.
    start_id = -1
    prefix_set_prev = {str(start_id): 1.0}
    probs_b, probs_nb = {str(start_id): 1.0}, {str(start_id): 0.0}

    # extend prefixes over time
    for time_step in range(max_time_steps):
        print("\ntime_step = %d" % (time_step + 1))
        prefix_set_next = {}
        probs_b_cur, probs_nb_cur = {}, {}
        # loop-invariant: the same frame distribution serves every prefix
        prob = input_probs_matrix[time_step]
        for l, prob_ in prefix_set_prev.items():
            print("l = %s\t%f" % (l, prob_))

            ids_list = _ids_str2list(l)
            end_id = ids_list[-1]
            if l not in probs_b_cur:
                probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0

            # extend the prefix by every vocab id
            for c in range(vocab_dim):
                if c == blank_id:
                    probs_b_cur[l] += prob[c] * (probs_b[l] + probs_nb[l])
                else:
                    l_plus = l + ' ' + str(c)
                    if l_plus not in probs_b_cur:
                        probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0

                    if c == end_id:
                        # repeated symbol: extending to l_plus requires an
                        # intervening blank; otherwise it collapses into l
                        probs_nb_cur[l_plus] += prob[c] * probs_b[l]
                        probs_nb_cur[l] += prob[c] * probs_nb[l]
                    elif c == space_id:
                        # word boundary: weigh in the language model, if any
                        lm = 1 if lang_model is None \
                            else np.power(lang_model(ids_list), alpha)
                        probs_nb_cur[l_plus] += lm * prob[c] * (probs_b[l] + probs_nb[l])
                    else:
                        probs_nb_cur[l_plus] += prob[c] * (probs_b[l] + probs_nb[l])
                    prefix_set_next[l_plus] = probs_nb_cur[l_plus] + probs_b_cur[l_plus]

                    print("l_plus: %s\t b=%f\tnb=%f\tP=%f" % (
                        l_plus, probs_b_cur[l_plus], probs_nb_cur[l_plus],
                        prefix_set_next[l_plus]))
            prefix_set_next[l] = probs_b_cur[l] + probs_nb_cur[l]
            print("l: %s\t b=%f\tnb=%f\tP=%f" % (
                l, probs_b_cur[l], probs_nb_cur[l], prefix_set_next[l]))
        # the *_cur dicts are rebuilt each step and never mutated afterwards,
        # so a plain rebinding replaces the original copy.deepcopy calls
        probs_b, probs_nb = probs_b_cur, probs_nb_cur

        # prune: keep only the beam_size most probable prefixes
        prefix_set_prev = sorted(
            prefix_set_next.items(), key=lambda kv: kv[1], reverse=True)
        if beam_size < len(prefix_set_prev):
            prefix_set_prev = prefix_set_prev[:beam_size]
        prefix_set_prev = dict(prefix_set_prev)

    # collect surviving prefixes, dropping the sentinel start id
    beam_result = []
    for seq, prob in prefix_set_prev.items():
        if prob > 0.0:
            ids_list = _ids_str2list(seq)
            beam_result.append([np.log(prob), ids_list[1:]])

    beam_result.sort(key=lambda kv: kv[0], reverse=True)
    if num_results_per_sample < beam_size:
        beam_result = beam_result[:num_results_per_sample]
    return beam_result
| 109 | + |
def language_model(input):
    '''Placeholder language model over an id-sequence prefix.

    TODO: plug in a real model; for now every prefix gets a random
    score drawn uniformly from [0, 1].
    '''
    score = random.uniform(0, 1)
    return score
| 113 | + |
def ctc_net(input, size_vocab):
    '''Stub acoustic network: emits one uniform probability distribution
    over the vocabulary.

    :param input: acoustic input; unused by this stub.
    :param size_vocab: vocabulary size, i.e. the output dimension.
    :return: 1-D numpy array of size_vocab probabilities summing to 1.
    '''
    # Fix: honor the size_vocab argument instead of silently reading the
    # module-level `vocab` (the parameter was previously ignored).
    prob = np.full(size_vocab, 1.0)
    return prob / prob.sum()
| 120 | + |
def main():
    '''Demo entry point: decode a small fixed probability matrix with the
    beam search decoder and print the candidate transcriptions.'''
    # rows = time steps, columns = vocab entries; the blank occupies the
    # last column (index 5), matching blank_id below
    prob_matrix = [[0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
                   [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
                   [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
                   [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
                   [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878]]

    beam_result = ctc_beam_search_decoder(
        input_probs_matrix=prob_matrix,
        beam_size=2,
        blank_id=5,
    )

    def ids2str(ids_list):
        # map ids back to characters via the module-level vocab
        return ''.join(vocab[ids] for ids in ids_list)

    print("\nbeam search output:")
    for result in beam_result:
        print(result[0], ids2str(result[1]))

if __name__ == '__main__':
    main()
0 commit comments