Skip to content

Commit 2063703

Browse files
author
Xinghai Sun
committed
Support padding removing.
1 parent 1dc445f commit 2063703

File tree

6 files changed

+109
-22
lines changed

6 files changed

+109
-22
lines changed

deep_speech_2/data_utils/data.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ class DataGenerator(object):
5959
be passed forward directly without
6060
converting to index sequence.
6161
:type keep_transcription_text: bool
62+
:param num_conv_layers: The number of convolution layers, used to compute
63+
the sequence length.
64+
:type num_conv_layers: int
6265
"""
6366

6467
def __init__(self,
@@ -74,7 +77,8 @@ def __init__(self,
7477
use_dB_normalization=True,
7578
num_threads=multiprocessing.cpu_count() // 2,
7679
random_seed=0,
77-
keep_transcription_text=False):
80+
keep_transcription_text=False,
81+
num_conv_layers=2):
7882
self._max_duration = max_duration
7983
self._min_duration = min_duration
8084
self._normalizer = FeatureNormalizer(mean_std_filepath)
@@ -95,6 +99,7 @@ def __init__(self,
9599
self._local_data = local()
96100
self._local_data.tar2info = {}
97101
self._local_data.tar2object = {}
102+
self._num_conv_layers = num_conv_layers
98103

99104
def process_utterance(self, filename, transcript):
100105
"""Load, augment, featurize and normalize for speech data.
@@ -213,7 +218,15 @@ def feeding(self):
213218
:return: Data feeding dict.
214219
:rtype: dict
215220
"""
216-
return {"audio_spectrogram": 0, "transcript_text": 1}
221+
feeding_dict = {
222+
"audio_spectrogram": 0,
223+
"transcript_text": 1,
224+
"sequence_offset": 2,
225+
"sequence_length": 3
226+
}
227+
for i in xrange(self._num_conv_layers):
228+
feeding_dict["conv%d_index_range" % i] = len(feeding_dict)
229+
return feeding_dict
217230

218231
@property
219232
def vocab_size(self):
@@ -306,7 +319,25 @@ def _padding_batch(self, batch, padding_to=-1, flatten=False):
306319
padded_audio[:, :audio.shape[1]] = audio
307320
if flatten:
308321
padded_audio = padded_audio.flatten()
309-
new_batch.append((padded_audio, text))
322+
323+
padded_instance = [padded_audio, text]
324+
padded_conv0_h = (padded_audio.shape[0] - 1) // 2 + 1
325+
padded_conv0_w = (padded_audio.shape[1] - 1) // 3 + 1
326+
valid_w = (audio.shape[1] - 1) // 3 + 1
327+
padded_instance += [
328+
[0], # sequence offset, always 0
329+
[valid_w], # valid sequence length
330+
[1, 32, 1, padded_conv0_h, valid_w + 1, padded_conv0_w]
331+
]
332+
pre_padded_h = padded_conv0_h
333+
for i in xrange(self._num_conv_layers - 1):
334+
padded_h = (pre_padded_h - 1) // 2 + 1
335+
pre_padded_h = padded_h
336+
padded_instance += [
337+
[1, 32, 1, padded_h, valid_w + 1, padded_conv0_w]
338+
]
339+
340+
new_batch.append(padded_instance)
310341
return new_batch
311342

312343
def _batch_shuffle(self, manifest, batch_size, clipped=False):

deep_speech_2/infer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def infer():
6969
augmentation_config='{}',
7070
specgram_type=args.specgram_type,
7171
num_threads=1,
72-
keep_transcription_text=True)
72+
keep_transcription_text=True,
73+
num_conv_layers=args.num_conv_layers)
7374
batch_reader = data_generator.batch_reader_creator(
7475
manifest_path=args.infer_manifest,
7576
batch_size=args.num_samples,
@@ -100,10 +101,11 @@ def infer():
100101
cutoff_top_n=args.cutoff_top_n,
101102
vocab_list=vocab_list,
102103
language_model_path=args.lang_model_path,
103-
num_processes=args.num_proc_bsearch)
104+
num_processes=args.num_proc_bsearch,
105+
feeding_dict=data_generator.feeding)
104106

105107
error_rate_func = cer if args.error_rate_type == 'cer' else wer
106-
target_transcripts = [transcript for _, transcript in infer_data]
108+
target_transcripts = [data[1] for data in infer_data]
107109
for target, result in zip(target_transcripts, result_transcripts):
108110
print("\nTarget Transcription: %s\nOutput Transcription: %s" %
109111
(target, result))

deep_speech_2/model_utils/model.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def infer_loss_batch(self, infer_data):
165165

166166
def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
167167
beam_size, cutoff_prob, cutoff_top_n, vocab_list,
168-
language_model_path, num_processes):
168+
language_model_path, num_processes, feeding_dict):
169169
"""Model inference. Infer the transcription for a batch of speech
170170
utterances.
171171
@@ -195,6 +195,9 @@ def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
195195
:type language_model_path: basestring|None
196196
:param num_processes: Number of processes (CPU) for decoder.
197197
:type num_processes: int
198+
:param feeding_dict: Mapping from field name to the tuple index
199+
of the data that reader returns.
200+
:type feeding_dict: dict|list
198201
:return: List of transcription texts.
199202
:rtype: List of basestring
200203
"""
@@ -203,10 +206,13 @@ def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
203206
self._inferer = paddle.inference.Inference(
204207
output_layer=self._log_probs, parameters=self._parameters)
205208
# run inference
206-
infer_results = self._inferer.infer(input=infer_data)
207-
num_steps = len(infer_results) // len(infer_data)
209+
infer_results = self._inferer.infer(
210+
input=infer_data, feeding=feeding_dict)
211+
start_pos = [0] * (len(infer_data) + 1)
212+
for i in xrange(len(infer_data)):
213+
start_pos[i + 1] = start_pos[i] + infer_data[i][3][0]
208214
probs_split = [
209-
infer_results[i * num_steps:(i + 1) * num_steps]
215+
infer_results[start_pos[i]:start_pos[i + 1]]
210216
for i in xrange(0, len(infer_data))
211217
]
212218
# run decoder
@@ -274,9 +280,25 @@ def _create_network(self, vocab_size, num_conv_layers, num_rnn_layers,
274280
text_data = paddle.layer.data(
275281
name="transcript_text",
276282
type=paddle.data_type.integer_value_sequence(vocab_size))
283+
seq_offset_data = paddle.layer.data(
284+
name='sequence_offset',
285+
type=paddle.data_type.integer_value_sequence(1))
286+
seq_len_data = paddle.layer.data(
287+
name='sequence_length',
288+
type=paddle.data_type.integer_value_sequence(1))
289+
index_range_datas = []
290+
for i in xrange(num_rnn_layers):
291+
index_range_datas.append(
292+
paddle.layer.data(
293+
name='conv%d_index_range' % i,
294+
type=paddle.data_type.dense_vector(6)))
295+
277296
self._log_probs, self._loss = deep_speech_v2_network(
278297
audio_data=audio_data,
279298
text_data=text_data,
299+
seq_offset_data=seq_offset_data,
300+
seq_len_data=seq_len_data,
301+
index_range_datas=index_range_datas,
280302
dict_size=vocab_size,
281303
num_conv_layers=num_conv_layers,
282304
num_rnn_layers=num_rnn_layers,

deep_speech_2/model_utils/network.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
10-
padding, act):
10+
padding, act, index_range_data):
1111
"""Convolution layer with batch normalization.
1212
1313
:param input: Input layer.
@@ -24,6 +24,8 @@ def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
2424
:type padding: int|tuple|list
2525
:param act: Activation type.
2626
:type act: BaseActivation
27+
:param index_range_data: Index range to indicate sub region.
28+
:type index_range_data: LayerOutput
2729
:return: Batch norm layer after convolution layer.
2830
:rtype: LayerOutput
2931
"""
@@ -36,7 +38,11 @@ def conv_bn_layer(input, filter_size, num_channels_in, num_channels_out, stride,
3638
padding=padding,
3739
act=paddle.activation.Linear(),
3840
bias_attr=False)
39-
return paddle.layer.batch_norm(input=conv_layer, act=act)
41+
batch_norm = paddle.layer.batch_norm(input=conv_layer, act=act)
42+
# reset padding part to 0
43+
scale_sub_region = paddle.layer.scale_sub_region(
44+
batch_norm, index_range_data, value=0.0)
45+
return scale_sub_region
4046

4147

4248
def bidirectional_simple_rnn_bn_layer(name, input, size, act, share_weights):
@@ -136,13 +142,15 @@ def bidirectional_gru_bn_layer(name, input, size, act):
136142
return paddle.layer.concat(input=[forward_gru, backward_gru])
137143

138144

139-
def conv_group(input, num_stacks):
145+
def conv_group(input, num_stacks, index_range_datas):
140146
"""Convolution group with stacked convolution layers.
141147
142148
:param input: Input layer.
143149
:type input: LayerOutput
144150
:param num_stacks: Number of stacked convolution layers.
145151
:type num_stacks: int
152+
:param index_range_datas: Index ranges for each convolution layer.
153+
:type index_range_datas: tuple|list
146154
:return: Output layer of the convolution group.
147155
:rtype: LayerOutput
148156
"""
@@ -153,7 +161,8 @@ def conv_group(input, num_stacks):
153161
num_channels_out=32,
154162
stride=(3, 2),
155163
padding=(5, 20),
156-
act=paddle.activation.BRelu())
164+
act=paddle.activation.BRelu(),
165+
index_range_data=index_range_datas[0])
157166
for i in xrange(num_stacks - 1):
158167
conv = conv_bn_layer(
159168
input=conv,
@@ -162,7 +171,8 @@ def conv_group(input, num_stacks):
162171
num_channels_out=32,
163172
stride=(1, 2),
164173
padding=(5, 10),
165-
act=paddle.activation.BRelu())
174+
act=paddle.activation.BRelu(),
175+
index_range_data=index_range_datas[i + 1])
166176
output_num_channels = 32
167177
output_height = 160 // pow(2, num_stacks) + 1
168178
return conv, output_num_channels, output_height
@@ -207,6 +217,9 @@ def rnn_group(input, size, num_stacks, use_gru, share_rnn_weights):
207217

208218
def deep_speech_v2_network(audio_data,
209219
text_data,
220+
seq_offset_data,
221+
seq_len_data,
222+
index_range_datas,
210223
dict_size,
211224
num_conv_layers=2,
212225
num_rnn_layers=3,
@@ -219,6 +232,12 @@ def deep_speech_v2_network(audio_data,
219232
:type audio_data: LayerOutput
220233
:param text_data: Transcription text data layer.
221234
:type text_data: LayerOutput
235+
:param seq_offset_data: Sequence offset data layer.
236+
:type seq_offset_data: LayerOutput
237+
:param seq_len_data: Valid sequence length data layer.
238+
:type seq_len_data: LayerOutput
239+
:param index_range_datas: Index ranges data layers.
240+
:type index_range_datas: tuple|list
222241
:param dict_size: Dictionary size for tokenized transcription.
223242
:type dict_size: int
224243
:param num_conv_layers: Number of stacking convolution layers.
@@ -239,7 +258,9 @@ def deep_speech_v2_network(audio_data,
239258
"""
240259
# convolution group
241260
conv_group_output, conv_group_num_channels, conv_group_height = conv_group(
242-
input=audio_data, num_stacks=num_conv_layers)
261+
input=audio_data,
262+
num_stacks=num_conv_layers,
263+
index_range_datas=index_range_datas)
243264
# convert data from convolution feature map to sequence of vectors
244265
conv2seq = paddle.layer.block_expand(
245266
input=conv_group_output,
@@ -248,9 +269,16 @@ def deep_speech_v2_network(audio_data,
248269
stride_y=1,
249270
block_x=1,
250271
block_y=conv_group_height)
272+
# remove padding part
273+
remove_padding = paddle.layer.sub_seq(
274+
input=conv2seq,
275+
offsets=seq_offset_data,
276+
sizes=seq_len_data,
277+
act=paddle.activation.Linear(),
278+
bias_attr=False)
251279
# rnn group
252280
rnn_group_output = rnn_group(
253-
input=conv2seq,
281+
input=remove_padding,
254282
size=rnn_size,
255283
num_stacks=num_rnn_layers,
256284
use_gru=use_gru,

deep_speech_2/test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def evaluate():
7070
augmentation_config='{}',
7171
specgram_type=args.specgram_type,
7272
num_threads=args.num_proc_data,
73-
keep_transcription_text=True)
73+
keep_transcription_text=True,
74+
num_conv_layers=args.num_conv_layers)
7475
batch_reader = data_generator.batch_reader_creator(
7576
manifest_path=args.test_manifest,
7677
batch_size=args.batch_size,
@@ -103,8 +104,9 @@ def evaluate():
103104
cutoff_top_n=args.cutoff_top_n,
104105
vocab_list=vocab_list,
105106
language_model_path=args.lang_model_path,
106-
num_processes=args.num_proc_bsearch)
107-
target_transcripts = [transcript for _, transcript in infer_data]
107+
num_processes=args.num_proc_bsearch,
108+
feeding_dict=data_generator.feeding)
109+
target_transcripts = [data[1] for data in infer_data]
108110
for target, result in zip(target_transcripts, result_transcripts):
109111
error_sum += error_rate_func(target, result)
110112
num_ins += 1

deep_speech_2/train.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,15 @@ def train():
7575
max_duration=args.max_duration,
7676
min_duration=args.min_duration,
7777
specgram_type=args.specgram_type,
78-
num_threads=args.num_proc_data)
78+
num_threads=args.num_proc_data,
79+
num_conv_layers=args.num_conv_layers)
7980
dev_generator = DataGenerator(
8081
vocab_filepath=args.vocab_path,
8182
mean_std_filepath=args.mean_std_path,
8283
augmentation_config="{}",
8384
specgram_type=args.specgram_type,
84-
num_threads=args.num_proc_data)
85+
num_threads=args.num_proc_data,
86+
num_conv_layers=args.num_conv_layers)
8587
train_batch_reader = train_generator.batch_reader_creator(
8688
manifest_path=args.train_manifest,
8789
batch_size=args.batch_size,

0 commit comments

Comments
 (0)