@@ -5,13 +5,14 @@
 from __future__ import print_function

 import random
-import numpy as np
 import struct
+import Queue
+import time
+import numpy as np
+from threading import Thread
+from multiprocessing import Manager, Process
 import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
 import data_utils.augmentor.trans_add_delta as trans_add_delta
-from multiprocessing import Manager, Process
-from threading import Thread
-import time


 class SampleInfo(object):
@@ -127,6 +128,8 @@ class DataReader(object):
                                 cached.
         sample_info_buffer_size (int): Buffer size to indicate the maximum
                                 sample information cached.
+        batch_buffer_size (int): Buffer size to indicate the maximum batch
+                                cached.
         shuffle_block_num (int): Block number indicating the minimum unit to do
                                 shuffle.
         random_seed (int): Random seed.
@@ -141,7 +144,8 @@ def __init__(
                 drop_frame_len=256,
                 process_num=10,
                 sample_buffer_size=1024,
-                sample_info_buffer_size=10000,
+                sample_info_buffer_size=1024,
+                batch_buffer_size=1024,
                 shuffle_block_num=1,
                 random_seed=0):
         self._feature_file_list = feature_file_list
@@ -158,6 +162,7 @@ def __init__(
         self._manager = Manager()
         self._sample_buffer_size = sample_buffer_size
         self._sample_info_buffer_size = sample_info_buffer_size
+        self._batch_buffer_size = batch_buffer_size
         self._process_num = process_num

     def generate_bucket_list(self, is_shuffle):
@@ -197,7 +202,7 @@ def _sample_generator(self):
         sample_queue = self._manager.Queue(self._sample_buffer_size)
         self._order_id = 0

-        def ordered_feeding_worker(sample_info_queue):
+        def ordered_feeding_task(sample_info_queue):
             for sample_info_bucket in self._bucket_list:
                 sample_info_list = sample_info_bucket.generate_sample_info_list(
                 )
@@ -210,12 +215,11 @@ def ordered_feeding_worker(sample_info_queue):
             sample_info_queue.put(EpochEndSignal())

         feeding_thread = Thread(
-            target=ordered_feeding_worker, args=(sample_info_queue, ))
+            target=ordered_feeding_task, args=(sample_info_queue, ))
         feeding_thread.daemon = True
         feeding_thread.start()

-        def ordered_processing_worker(sample_info_queue, sample_queue,
-                                      out_order):
+        def ordered_processing_task(sample_info_queue, sample_queue, out_order):
             def read_bytes(fpath, start, size):
                 f = open(fpath, 'r')
                 f.seek(start, 0)
@@ -273,7 +277,7 @@ def read_bytes(fpath, start, size):
         args = (sample_info_queue, sample_queue, out_order)
         workers = [
             Process(
-                target=ordered_processing_worker, args=args)
+                target=ordered_processing_task, args=args)
             for _ in xrange(self._process_num)
         ]

@@ -295,13 +299,27 @@ def read_bytes(fpath, start, size):
             w.join()

     def batch_iterator(self, batch_size, minimum_batch_size):
-        batch_samples = []
-        lod = [0]
-        # check whether need parallel here
-        for sample in self._sample_generator():
-            batch_samples.append(sample)
-            lod.append(lod[-1] + sample[0].shape[0])
-            if len(batch_samples) == batch_size:
+        def batch_assembling_task(sample_generator, batch_queue):
+            batch_samples = []
+            lod = [0]
+            for sample in sample_generator():
+                batch_samples.append(sample)
+                lod.append(lod[-1] + sample[0].shape[0])
+                if len(batch_samples) == batch_size:
+                    batch_feature = np.zeros(
+                        (lod[-1], self._frame_dim), dtype="float32")
+                    batch_label = np.zeros((lod[-1], 1), dtype="int64")
+                    start = 0
+                    for sample in batch_samples:
+                        frame_num = sample[0].shape[0]
+                        batch_feature[start:start + frame_num, :] = sample[0]
+                        batch_label[start:start + frame_num, :] = sample[1]
+                        start += frame_num
+                    batch_queue.put((batch_feature, batch_label, lod))
+                    batch_samples = []
+                    lod = [0]
+
+            if len(batch_samples) >= minimum_batch_size:
                 batch_feature = np.zeros(
                     (lod[-1], self._frame_dim), dtype="float32")
                 batch_label = np.zeros((lod[-1], 1), dtype="int64")
@@ -311,18 +329,21 @@ def batch_iterator(self, batch_size, minimum_batch_size):
                     batch_feature[start:start + frame_num, :] = sample[0]
                     batch_label[start:start + frame_num, :] = sample[1]
                     start += frame_num
-                yield (batch_feature, batch_label, lod)
-                batch_samples = []
-                lod = [0]
-
-        if len(batch_samples) >= minimum_batch_size:
-            batch_feature = np.zeros(
-                (lod[-1], self._frame_dim), dtype="float32")
-            batch_label = np.zeros((lod[-1], 1), dtype="int64")
-            start = 0
-            for sample in batch_samples:
-                frame_num = sample[0].shape[0]
-                batch_feature[start:start + frame_num, :] = sample[0]
-                batch_label[start:start + frame_num, :] = sample[1]
-                start += frame_num
-            yield (batch_feature, batch_label, lod)
+                batch_queue.put((batch_feature, batch_label, lod))
+
+            batch_queue.put(EpochEndSignal())
+
+        batch_queue = Queue.Queue(self._batch_buffer_size)
+
+        assembling_thread = Thread(
+            target=batch_assembling_task,
+            args=(self._sample_generator, batch_queue))
+        assembling_thread.daemon = True
+        assembling_thread.start()
+
+        batch_data = batch_queue.get()
+        while not isinstance(batch_data, EpochEndSignal):
+            yield batch_data
+            batch_data = batch_queue.get()
+
+        assembling_thread.join()
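The change above turns batch_iterator into a producer-consumer pair: a daemon thread (batch_assembling_task) packs samples into dense float32/int64 batches and pushes them into a bounded Queue.Queue, while the generator drains the queue until an EpochEndSignal arrives and then joins the thread. The sketch below illustrates only that hand-off pattern in isolation and assumes nothing from this repository; EndSignal, produce_batches, and iterate_batches are hypothetical names, and the Python 2 Queue module is used to match the code above.

import Queue
from threading import Thread


class EndSignal(object):
    """Hypothetical sentinel marking the end of an epoch, analogous to EpochEndSignal."""
    pass


def produce_batches(batch_queue):
    # Hypothetical stand-in for batch_assembling_task: emit a few "batches",
    # then the sentinel so the consumer knows the epoch is over.
    for i in range(5):
        batch_queue.put(('batch', i))
    batch_queue.put(EndSignal())


def iterate_batches(producer_task, buffer_size=4):
    # The bounded queue gives back-pressure: the producer blocks in put()
    # once buffer_size batches are waiting, mirroring _batch_buffer_size above.
    batch_queue = Queue.Queue(buffer_size)

    producer = Thread(target=producer_task, args=(batch_queue, ))
    producer.daemon = True
    producer.start()

    batch = batch_queue.get()
    while not isinstance(batch, EndSignal):
        yield batch
        batch = batch_queue.get()

    producer.join()


if __name__ == '__main__':
    for batch in iterate_batches(produce_batches):
        print(batch)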