Skip to content

Commit 2d16fbc

Browse files
author
Yibing Liu
authored
Merge pull request #447 from kuke/adapt_tuning
Adapt tuning script to padding removal #444
2 parents 493e8e8 + 514f4ef commit 2d16fbc

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

deep_speech_2/tools/tune.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,18 +88,34 @@ def tune():
8888
augmentation_config='{}',
8989
specgram_type=args.specgram_type,
9090
num_threads=args.num_proc_data,
91-
keep_transcription_text=True)
91+
keep_transcription_text=True,
92+
num_conv_layers=args.num_conv_layers)
9293

9394
audio_data = paddle.layer.data(
9495
name="audio_spectrogram",
9596
type=paddle.data_type.dense_array(161 * 161))
9697
text_data = paddle.layer.data(
9798
name="transcript_text",
9899
type=paddle.data_type.integer_value_sequence(data_generator.vocab_size))
100+
seq_offset_data = paddle.layer.data(
101+
name='sequence_offset',
102+
type=paddle.data_type.integer_value_sequence(1))
103+
seq_len_data = paddle.layer.data(
104+
name='sequence_length',
105+
type=paddle.data_type.integer_value_sequence(1))
106+
index_range_datas = []
107+
for i in xrange(args.num_rnn_layers):
108+
index_range_datas.append(
109+
paddle.layer.data(
110+
name='conv%d_index_range' % i,
111+
type=paddle.data_type.dense_vector(6)))
99112

100113
output_probs, _ = deep_speech_v2_network(
101114
audio_data=audio_data,
102115
text_data=text_data,
116+
seq_offset_data=seq_offset_data,
117+
seq_len_data=seq_len_data,
118+
index_range_datas=index_range_datas,
103119
dict_size=data_generator.vocab_size,
104120
num_conv_layers=args.num_conv_layers,
105121
num_rnn_layers=args.num_rnn_layers,
@@ -156,15 +172,17 @@ def tune():
156172
for infer_data in batch_reader():
157173
if (args.num_batches >= 0) and (cur_batch >= args.num_batches):
158174
break
159-
infer_results = inferer.infer(input=infer_data)
160-
161-
num_steps = len(infer_results) // len(infer_data)
175+
infer_results = inferer.infer(input=infer_data,
176+
feeding=data_generator.feeding)
177+
start_pos = [0] * (len(infer_data) + 1)
178+
for i in xrange(len(infer_data)):
179+
start_pos[i + 1] = start_pos[i] + infer_data[i][3][0]
162180
probs_split = [
163-
infer_results[i * num_steps:(i + 1) * num_steps]
164-
for i in xrange(len(infer_data))
181+
infer_results[start_pos[i]:start_pos[i + 1]]
182+
for i in xrange(0, len(infer_data))
165183
]
166184

167-
target_transcripts = [transcript for _, transcript in infer_data]
185+
target_transcripts = [ data[1] for data in infer_data ]
168186

169187
num_ins += len(target_transcripts)
170188
# grid search

0 commit comments

Comments
 (0)