Skip to content

Commit dd008f1

Browse files
committed
Make batch assembling parallel.
1 parent f12deac commit dd008f1

File tree

1 file changed

+53
-32
lines changed

1 file changed

+53
-32
lines changed

fluid/DeepASR/data_utils/data_reader.py

Lines changed: 53 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
from __future__ import print_function
66

77
import random
8-
import numpy as np
98
import struct
9+
import Queue
10+
import time
11+
import numpy as np
12+
from threading import Thread
13+
from multiprocessing import Manager, Process
1014
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
1115
import data_utils.augmentor.trans_add_delta as trans_add_delta
12-
from multiprocessing import Manager, Process
13-
from threading import Thread
14-
import time
1516

1617

1718
class SampleInfo(object):
@@ -127,6 +128,8 @@ class DataReader(object):
127128
cached.
128129
sample_info_buffer_size (int): Buffer size to indicate the maximum
129130
sample information cached.
131+
batch_buffer_size (int): Buffer size to indicate the maximum batch
132+
cached.
130133
shuffle_block_num (int): Block number indicating the minimum unit to do
131134
shuffle.
132135
random_seed (int): Random seed.
@@ -141,7 +144,8 @@ def __init__(
141144
drop_frame_len=256,
142145
process_num=10,
143146
sample_buffer_size=1024,
144-
sample_info_buffer_size=10000,
147+
sample_info_buffer_size=1024,
148+
batch_buffer_size=1024,
145149
shuffle_block_num=1,
146150
random_seed=0):
147151
self._feature_file_list = feature_file_list
@@ -158,6 +162,7 @@ def __init__(
158162
self._manager = Manager()
159163
self._sample_buffer_size = sample_buffer_size
160164
self._sample_info_buffer_size = sample_info_buffer_size
165+
self._batch_buffer_size = batch_buffer_size
161166
self._process_num = process_num
162167

163168
def generate_bucket_list(self, is_shuffle):
@@ -197,7 +202,7 @@ def _sample_generator(self):
197202
sample_queue = self._manager.Queue(self._sample_buffer_size)
198203
self._order_id = 0
199204

200-
def ordered_feeding_worker(sample_info_queue):
205+
def ordered_feeding_task(sample_info_queue):
201206
for sample_info_bucket in self._bucket_list:
202207
sample_info_list = sample_info_bucket.generate_sample_info_list(
203208
)
@@ -210,12 +215,11 @@ def ordered_feeding_worker(sample_info_queue):
210215
sample_info_queue.put(EpochEndSignal())
211216

212217
feeding_thread = Thread(
213-
target=ordered_feeding_worker, args=(sample_info_queue, ))
218+
target=ordered_feeding_task, args=(sample_info_queue, ))
214219
feeding_thread.daemon = True
215220
feeding_thread.start()
216221

217-
def ordered_processing_worker(sample_info_queue, sample_queue,
218-
out_order):
222+
def ordered_processing_task(sample_info_queue, sample_queue, out_order):
219223
def read_bytes(fpath, start, size):
220224
f = open(fpath, 'r')
221225
f.seek(start, 0)
@@ -273,7 +277,7 @@ def read_bytes(fpath, start, size):
273277
args = (sample_info_queue, sample_queue, out_order)
274278
workers = [
275279
Process(
276-
target=ordered_processing_worker, args=args)
280+
target=ordered_processing_task, args=args)
277281
for _ in xrange(self._process_num)
278282
]
279283

@@ -295,13 +299,27 @@ def read_bytes(fpath, start, size):
295299
w.join()
296300

297301
def batch_iterator(self, batch_size, minimum_batch_size):
298-
batch_samples = []
299-
lod = [0]
300-
# check whether need parallel here
301-
for sample in self._sample_generator():
302-
batch_samples.append(sample)
303-
lod.append(lod[-1] + sample[0].shape[0])
304-
if len(batch_samples) == batch_size:
302+
def batch_assembling_task(sample_generator, batch_queue):
303+
batch_samples = []
304+
lod = [0]
305+
for sample in sample_generator():
306+
batch_samples.append(sample)
307+
lod.append(lod[-1] + sample[0].shape[0])
308+
if len(batch_samples) == batch_size:
309+
batch_feature = np.zeros(
310+
(lod[-1], self._frame_dim), dtype="float32")
311+
batch_label = np.zeros((lod[-1], 1), dtype="int64")
312+
start = 0
313+
for sample in batch_samples:
314+
frame_num = sample[0].shape[0]
315+
batch_feature[start:start + frame_num, :] = sample[0]
316+
batch_label[start:start + frame_num, :] = sample[1]
317+
start += frame_num
318+
batch_queue.put((batch_feature, batch_label, lod))
319+
batch_samples = []
320+
lod = [0]
321+
322+
if len(batch_samples) >= minimum_batch_size:
305323
batch_feature = np.zeros(
306324
(lod[-1], self._frame_dim), dtype="float32")
307325
batch_label = np.zeros((lod[-1], 1), dtype="int64")
@@ -311,18 +329,21 @@ def batch_iterator(self, batch_size, minimum_batch_size):
311329
batch_feature[start:start + frame_num, :] = sample[0]
312330
batch_label[start:start + frame_num, :] = sample[1]
313331
start += frame_num
314-
yield (batch_feature, batch_label, lod)
315-
batch_samples = []
316-
lod = [0]
317-
318-
if len(batch_samples) >= minimum_batch_size:
319-
batch_feature = np.zeros(
320-
(lod[-1], self._frame_dim), dtype="float32")
321-
batch_label = np.zeros((lod[-1], 1), dtype="int64")
322-
start = 0
323-
for sample in batch_samples:
324-
frame_num = sample[0].shape[0]
325-
batch_feature[start:start + frame_num, :] = sample[0]
326-
batch_label[start:start + frame_num, :] = sample[1]
327-
start += frame_num
328-
yield (batch_feature, batch_label, lod)
332+
batch_queue.put((batch_feature, batch_label, lod))
333+
334+
batch_queue.put(EpochEndSignal())
335+
336+
batch_queue = Queue.Queue(self._batch_buffer_size)
337+
338+
assembling_thread = Thread(
339+
target=batch_assembling_task,
340+
args=(self._sample_generator, batch_queue))
341+
assembling_thread.daemon = True
342+
assembling_thread.start()
343+
344+
batch_data = batch_queue.get()
345+
while not isinstance(batch_data, EpochEndSignal):
346+
yield batch_data
347+
batch_data = batch_queue.get()
348+
349+
assembling_thread.join()

0 commit comments

Comments
 (0)