Skip to content

Commit 08203ee

Browse files
author
Yibing Liu
committed
final refining on old data provider: enable pruning & add evaluation & code cleanup
1 parent 0fa063e commit 08203ee

File tree

4 files changed

+339
-72
lines changed

4 files changed

+339
-72
lines changed

deep_speech_2/decoder.py

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import os
66
from itertools import groupby
77
import numpy as np
8-
import copy
98
import kenlm
109
import multiprocessing
1110

@@ -73,25 +72,40 @@ def word_count(self, sentence):
7372
return len(words)
7473

7574
# execute evaluation
76-
def __call__(self, sentence):
75+
def __call__(self, sentence, log=False):
76+
"""
77+
Evaluation function
78+
79+
:param sentence: The input sentence for evaluation
80+
:type sentence: basestring
81+
:param log: Whether return the score in log representation.
82+
:type log: bool
83+
:return: Evaluation score, in decimal or log representation.
84+
:rtype: float
85+
"""
7786
lm = self.language_model_score(sentence)
7887
word_cnt = self.word_count(sentence)
79-
score = np.power(lm, self._alpha) \
80-
* np.power(word_cnt, self._beta)
88+
if log == False:
89+
score = np.power(lm, self._alpha) \
90+
* np.power(word_cnt, self._beta)
91+
else:
92+
score = self._alpha * np.log(lm) \
93+
+ self._beta * np.log(word_cnt)
8194
return score
8295

8396

8497
def ctc_beam_search_decoder(probs_seq,
8598
beam_size,
8699
vocabulary,
87100
blank_id=0,
101+
cutoff_prob=1.0,
88102
ext_scoring_func=None,
89103
nproc=False):
90104
'''
91105
Beam search decoder for CTC-trained network, using beam search with width
92106
beam_size to find many paths to one label, return beam_size labels in
93-
the order of probabilities. The implementation is based on Prefix Beam
94-
Search(https://arxiv.org/abs/1408.2873), and the unclear part is
107+
the descending order of probabilities. The implementation is based on Prefix
108+
Beam Search(https://arxiv.org/abs/1408.2873), and the unclear part is
95109
redesigned, need to be verified.
96110
97111
:param probs_seq: 2-D list with length num_time_steps, each element
@@ -102,22 +116,25 @@ def ctc_beam_search_decoder(probs_seq,
102116
:type beam_size: int
103117
:param vocabulary: Vocabulary list.
104118
:type vocabulary: list
119+
:param blank_id: ID of blank, default 0.
120+
:type blank_id: int
121+
:param cutoff_prob: Cutoff probability in pruning,
122+
default 1.0, no pruning.
123+
:type cutoff_prob: float
105124
:param ext_scoring_func: External defined scoring function for
106125
partially decoded sentence, e.g. word count
107126
and language model.
108127
:type external_scoring_function: function
109-
:param blank_id: id of blank, default 0.
110-
:type blank_id: int
111128
:param nproc: Whether the decoder used in multiprocesses.
112129
:type nproc: bool
113-
:return: Decoding log probability and result string.
130+
:return: Decoding log probabilities and result sentences in descending order.
114131
:rtype: list
115132
116133
'''
117134
# dimension check
118135
for prob_list in probs_seq:
119136
if not len(prob_list) == len(vocabulary) + 1:
120-
raise ValueError("probs dimension mismatchedd with vocabulary")
137+
raise ValueError("probs dimension mismatched with vocabulary")
121138
num_time_steps = len(probs_seq)
122139

123140
# blank_id check
@@ -137,19 +154,35 @@ def ctc_beam_search_decoder(probs_seq,
137154
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
138155

139156
## extend prefix in loop
140-
for time_step in range(num_time_steps):
157+
for time_step in xrange(num_time_steps):
141158
# the set containing candidate prefixes
142159
prefix_set_next = {}
143160
probs_b_cur, probs_nb_cur = {}, {}
161+
prob = probs_seq[time_step]
162+
prob_idx = [[i, prob[i]] for i in xrange(len(prob))]
163+
cutoff_len = len(prob_idx)
164+
#If pruning is enabled
165+
if (cutoff_prob < 1.0):
166+
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
167+
cutoff_len = 0
168+
cum_prob = 0.0
169+
for i in xrange(len(prob_idx)):
170+
cum_prob += prob_idx[i][1]
171+
cutoff_len += 1
172+
if cum_prob >= cutoff_prob:
173+
break
174+
prob_idx = prob_idx[0:cutoff_len]
175+
144176
for l in prefix_set_prev:
145-
prob = probs_seq[time_step]
146177
if not prefix_set_next.has_key(l):
147178
probs_b_cur[l], probs_nb_cur[l] = 0.0, 0.0
148179

149-
# extend prefix by travering vocabulary
150-
for c in range(0, probs_dim):
180+
# extend prefix by traversing prob_idx
181+
for index in xrange(cutoff_len):
182+
c, prob_c = prob_idx[index][0], prob_idx[index][1]
183+
151184
if c == blank_id:
152-
probs_b_cur[l] += prob[c] * (
185+
probs_b_cur[l] += prob_c * (
153186
probs_b_prev[l] + probs_nb_prev[l])
154187
else:
155188
last_char = l[-1]
@@ -159,18 +192,18 @@ def ctc_beam_search_decoder(probs_seq,
159192
probs_b_cur[l_plus], probs_nb_cur[l_plus] = 0.0, 0.0
160193

161194
if new_char == last_char:
162-
probs_nb_cur[l_plus] += prob[c] * probs_b_prev[l]
163-
probs_nb_cur[l] += prob[c] * probs_nb_prev[l]
195+
probs_nb_cur[l_plus] += prob_c * probs_b_prev[l]
196+
probs_nb_cur[l] += prob_c * probs_nb_prev[l]
164197
elif new_char == ' ':
165198
if (ext_scoring_func is None) or (len(l) == 1):
166199
score = 1.0
167200
else:
168201
prefix = l[1:]
169202
score = ext_scoring_func(prefix)
170-
probs_nb_cur[l_plus] += score * prob[c] * (
203+
probs_nb_cur[l_plus] += score * prob_c * (
171204
probs_b_prev[l] + probs_nb_prev[l])
172205
else:
173-
probs_nb_cur[l_plus] += prob[c] * (
206+
probs_nb_cur[l_plus] += prob_c * (
174207
probs_b_prev[l] + probs_nb_prev[l])
175208
# add l_plus into prefix_set_next
176209
prefix_set_next[l_plus] = probs_nb_cur[
@@ -203,6 +236,7 @@ def ctc_beam_search_decoder_nproc(probs_split,
203236
beam_size,
204237
vocabulary,
205238
blank_id=0,
239+
cutoff_prob=1.0,
206240
ext_scoring_func=None,
207241
num_processes=None):
208242
'''
@@ -216,16 +250,19 @@ def ctc_beam_search_decoder_nproc(probs_split,
216250
:type beam_size: int
217251
:param vocabulary: Vocabulary list.
218252
:type vocabulary: list
253+
:param blank_id: ID of blank, default 0.
254+
:type blank_id: int
255+
:param cutoff_prob: Cutoff probability in pruning,
256+
default 1.0, no pruning.
257+
:type cutoff_prob: float
219258
:param ext_scoring_func: External defined scoring function for
220259
partially decoded sentence, e.g. word count
221260
and language model.
222261
:type external_scoring_function: function
223-
:param blank_id: id of blank, default 0.
224-
:type blank_id: int
225262
:param num_processes: Number of processes, default None, equal to the
226263
number of CPUs.
227264
:type num_processes: int
228-
:return: Decoding log probability and result string.
265+
:return: Decoding log probabilities and result sentences in descending order.
229266
:rtype: list
230267
231268
'''
@@ -243,7 +280,8 @@ def ctc_beam_search_decoder_nproc(probs_split,
243280
pool = multiprocessing.Pool(processes=num_processes)
244281
results = []
245282
for i, probs_list in enumerate(probs_split):
246-
args = (probs_list, beam_size, vocabulary, blank_id, None, nproc)
283+
args = (probs_list, beam_size, vocabulary, blank_id, cutoff_prob, None,
284+
nproc)
247285
results.append(pool.apply_async(ctc_beam_search_decoder, args))
248286

249287
pool.close()

0 commit comments

Comments
 (0)