55import os
66from itertools import groupby
77import numpy as np
8- import copy
98import kenlm
109import multiprocessing
1110
@@ -73,25 +72,40 @@ def word_count(self, sentence):
7372 return len (words )
7473
7574 # execute evaluation
76- def __call__ (self , sentence ):
75+ def __call__ (self , sentence , log = False ):
76+ """
77+ Evaluation function
78+
79+ :param sentence: The input sentence for evaluation
80+ :type sentence: basestring
81+ :param log: Whether to return the score in log representation.
82+ :type log: bool
83+ :return: Evaluation score, in the decimal or log.
84+ :rtype: float
85+ """
7786 lm = self .language_model_score (sentence )
7887 word_cnt = self .word_count (sentence )
79- score = np .power (lm , self ._alpha ) \
80- * np .power (word_cnt , self ._beta )
88+ if log == False :
89+ score = np .power (lm , self ._alpha ) \
90+ * np .power (word_cnt , self ._beta )
91+ else :
92+ score = self ._alpha * np .log (lm ) \
93+ + self ._beta * np .log (word_cnt )
8194 return score
8295
8396
8497def ctc_beam_search_decoder (probs_seq ,
8598 beam_size ,
8699 vocabulary ,
87100 blank_id = 0 ,
101+ cutoff_prob = 1.0 ,
88102 ext_scoring_func = None ,
89103 nproc = False ):
90104 '''
91105 Beam search decoder for CTC-trained network, using beam search with width
92106 beam_size to find many paths to one label, return beam_size labels in
93- the order of probabilities. The implementation is based on Prefix Beam
94- Search(https://arxiv.org/abs/1408.2873), and the unclear part is
107+ the descending order of probabilities. The implementation is based on Prefix
108+ Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is
95109 redesigned, need to be verified.
96110
97111 :param probs_seq: 2-D list with length num_time_steps, each element
@@ -102,22 +116,25 @@ def ctc_beam_search_decoder(probs_seq,
102116 :type beam_size: int
103117 :param vocabulary: Vocabulary list.
104118 :type vocabulary: list
119+ :param blank_id: ID of blank, default 0.
120+ :type blank_id: int
121+ :param cutoff_prob: Cutoff probability in pruning,
122+ default 1.0, no pruning.
123+ :type cutoff_prob: float
105124 :param ext_scoring_func: External defined scoring function for
106125 partially decoded sentence, e.g. word count
107126 and language model.
108127 :type external_scoring_function: function
109- :param blank_id: id of blank, default 0.
110- :type blank_id: int
111128 :param nproc: Whether the decoder used in multiprocesses.
112129 :type nproc: bool
113- :return: Decoding log probability and result string .
130+ :return: Decoding log probabilities and result sentences in descending order .
114131 :rtype: list
115132
116133 '''
117134 # dimension check
118135 for prob_list in probs_seq :
119136 if not len (prob_list ) == len (vocabulary ) + 1 :
120- raise ValueError ("probs dimension mismatchedd with vocabulary" )
137+ raise ValueError ("probs dimension mismatched with vocabulary" )
121138 num_time_steps = len (probs_seq )
122139
123140 # blank_id check
@@ -137,19 +154,35 @@ def ctc_beam_search_decoder(probs_seq,
137154 probs_b_prev , probs_nb_prev = {'\t ' : 1.0 }, {'\t ' : 0.0 }
138155
139156 ## extend prefix in loop
140- for time_step in range (num_time_steps ):
157+ for time_step in xrange (num_time_steps ):
141158 # the set containing candidate prefixes
142159 prefix_set_next = {}
143160 probs_b_cur , probs_nb_cur = {}, {}
161+ prob = probs_seq [time_step ]
162+ prob_idx = [[i , prob [i ]] for i in xrange (len (prob ))]
163+ cutoff_len = len (prob_idx )
164+ #If pruning is enabled
165+ if (cutoff_prob < 1.0 ):
166+ prob_idx = sorted (prob_idx , key = lambda asd : asd [1 ], reverse = True )
167+ cutoff_len = 0
168+ cum_prob = 0.0
169+ for i in xrange (len (prob_idx )):
170+ cum_prob += prob_idx [i ][1 ]
171+ cutoff_len += 1
172+ if cum_prob >= cutoff_prob :
173+ break
174+ prob_idx = prob_idx [0 :cutoff_len ]
175+
144176 for l in prefix_set_prev :
145- prob = probs_seq [time_step ]
146177 if not prefix_set_next .has_key (l ):
147178 probs_b_cur [l ], probs_nb_cur [l ] = 0.0 , 0.0
148179
149- # extend prefix by travering vocabulary
150- for c in range (0 , probs_dim ):
180+ # extend prefix by traversing prob_idx
181+ for index in xrange (cutoff_len ):
182+ c , prob_c = prob_idx [index ][0 ], prob_idx [index ][1 ]
183+
151184 if c == blank_id :
152- probs_b_cur [l ] += prob [ c ] * (
185+ probs_b_cur [l ] += prob_c * (
153186 probs_b_prev [l ] + probs_nb_prev [l ])
154187 else :
155188 last_char = l [- 1 ]
@@ -159,18 +192,18 @@ def ctc_beam_search_decoder(probs_seq,
159192 probs_b_cur [l_plus ], probs_nb_cur [l_plus ] = 0.0 , 0.0
160193
161194 if new_char == last_char :
162- probs_nb_cur [l_plus ] += prob [ c ] * probs_b_prev [l ]
163- probs_nb_cur [l ] += prob [ c ] * probs_nb_prev [l ]
195+ probs_nb_cur [l_plus ] += prob_c * probs_b_prev [l ]
196+ probs_nb_cur [l ] += prob_c * probs_nb_prev [l ]
164197 elif new_char == ' ' :
165198 if (ext_scoring_func is None ) or (len (l ) == 1 ):
166199 score = 1.0
167200 else :
168201 prefix = l [1 :]
169202 score = ext_scoring_func (prefix )
170- probs_nb_cur [l_plus ] += score * prob [ c ] * (
203+ probs_nb_cur [l_plus ] += score * prob_c * (
171204 probs_b_prev [l ] + probs_nb_prev [l ])
172205 else :
173- probs_nb_cur [l_plus ] += prob [ c ] * (
206+ probs_nb_cur [l_plus ] += prob_c * (
174207 probs_b_prev [l ] + probs_nb_prev [l ])
175208 # add l_plus into prefix_set_next
176209 prefix_set_next [l_plus ] = probs_nb_cur [
@@ -203,6 +236,7 @@ def ctc_beam_search_decoder_nproc(probs_split,
203236 beam_size ,
204237 vocabulary ,
205238 blank_id = 0 ,
239+ cutoff_prob = 1.0 ,
206240 ext_scoring_func = None ,
207241 num_processes = None ):
208242 '''
@@ -216,16 +250,19 @@ def ctc_beam_search_decoder_nproc(probs_split,
216250 :type beam_size: int
217251 :param vocabulary: Vocabulary list.
218252 :type vocabulary: list
253+ :param blank_id: ID of blank, default 0.
254+ :type blank_id: int
255+ :param cutoff_prob: Cutoff probability in pruning,
256+ default 1.0, no pruning.
257+ :type cutoff_prob: float
219258 :param ext_scoring_func: External defined scoring function for
220259 partially decoded sentence, e.g. word count
221260 and language model.
222261 :type external_scoring_function: function
223- :param blank_id: id of blank, default 0.
224- :type blank_id: int
225262 :param num_processes: Number of processes, default None, equal to the
226263 number of CPUs.
227264 :type num_processes: int
228- :return: Decoding log probability and result string .
265+ :return: Decoding log probabilities and result sentences in descending order .
229266 :rtype: list
230267
231268 '''
@@ -243,7 +280,8 @@ def ctc_beam_search_decoder_nproc(probs_split,
243280 pool = multiprocessing .Pool (processes = num_processes )
244281 results = []
245282 for i , probs_list in enumerate (probs_split ):
246- args = (probs_list , beam_size , vocabulary , blank_id , None , nproc )
283+ args = (probs_list , beam_size , vocabulary , blank_id , cutoff_prob , None ,
284+ nproc )
247285 results .append (pool .apply_async (ctc_beam_search_decoder , args ))
248286
249287 pool .close ()
0 commit comments