15
15
import data_utils .augmentor .trans_mean_variance_norm as trans_mean_variance_norm
16
16
import data_utils .augmentor .trans_add_delta as trans_add_delta
17
17
from data_utils .util import suppress_complaints , suppress_signal
18
+ from data_utils .util import CriticalException , ForceExitWrapper
18
19
19
20
20
21
class SampleInfo (object ):
@@ -166,6 +167,7 @@ def __init__(self,
166
167
self ._batch_buffer_size = batch_buffer_size
167
168
self ._process_num = process_num
168
169
self ._verbose = verbose
170
+ self ._force_exit = ForceExitWrapper (self ._manager .Value ('b' , False ))
169
171
170
172
def generate_bucket_list (self , is_shuffle ):
171
173
if self ._block_info_list is None :
@@ -204,15 +206,19 @@ def _sample_generator(self):
204
206
sample_queue = self ._manager .Queue (self ._sample_buffer_size )
205
207
self ._order_id = 0
206
208
207
- @suppress_complaints (verbose = self ._verbose )
209
+ @suppress_complaints (verbose = self ._verbose , notify = self . _force_exit )
208
210
def ordered_feeding_task (sample_info_queue ):
209
211
for sample_info_bucket in self ._bucket_list :
210
- sample_info_list = sample_info_bucket .generate_sample_info_list (
211
- )
212
- self ._rng .shuffle (sample_info_list ) # do shuffle here
213
- for sample_info in sample_info_list :
214
- sample_info_queue .put ((sample_info , self ._order_id ))
215
- self ._order_id += 1
212
+ try :
213
+ sample_info_list = \
214
+ sample_info_bucket .generate_sample_info_list ()
215
+ except Exception as e :
216
+ raise CriticalException (e )
217
+ else :
218
+ self ._rng .shuffle (sample_info_list ) # do shuffle here
219
+ for sample_info in sample_info_list :
220
+ sample_info_queue .put ((sample_info , self ._order_id ))
221
+ self ._order_id += 1
216
222
217
223
for i in xrange (self ._process_num ):
218
224
sample_info_queue .put (EpochEndSignal ())
@@ -222,18 +228,21 @@ def ordered_feeding_task(sample_info_queue):
222
228
feeding_thread .daemon = True
223
229
feeding_thread .start ()
224
230
225
- @suppress_complaints (verbose = self ._verbose )
231
+ @suppress_complaints (verbose = self ._verbose , notify = self . _force_exit )
226
232
def ordered_processing_task (sample_info_queue , sample_queue , out_order ):
227
233
if self ._verbose == 0 :
228
234
signal .signal (signal .SIGTERM , suppress_signal )
229
235
signal .signal (signal .SIGINT , suppress_signal )
230
236
231
237
def read_bytes (fpath , start , size ):
232
- f = open (fpath , 'r' )
233
- f .seek (start , 0 )
234
- binary_bytes = f .read (size )
235
- f .close ()
236
- return binary_bytes
238
+ try :
239
+ f = open (fpath , 'r' )
240
+ f .seek (start , 0 )
241
+ binary_bytes = f .read (size )
242
+ f .close ()
243
+ return binary_bytes
244
+ except Exception as e :
245
+ raise CriticalException (e )
237
246
238
247
ins = sample_info_queue .get ()
239
248
@@ -295,16 +304,20 @@ def read_bytes(fpath, start, size):
295
304
296
305
finished_process_num = 0
297
306
298
- while finished_process_num < self ._process_num :
299
- sample = sample_queue .get ()
300
- if isinstance (sample , EpochEndSignal ):
301
- finished_process_num += 1
302
- continue
303
- yield sample
307
+ while self ._force_exit == False :
308
+ try :
309
+ sample = sample_queue .get_nowait ()
310
+ except Queue .Empty :
311
+ time .sleep (0.001 )
312
+ else :
313
+ if isinstance (sample , EpochEndSignal ):
314
+ finished_process_num += 1
315
+ if finished_process_num >= self ._process_num :
316
+ break
317
+ else :
318
+ continue
304
319
305
- feeding_thread .join ()
306
- for w in workers :
307
- w .join ()
320
+ yield sample
308
321
309
322
def batch_iterator (self , batch_size , minimum_batch_size ):
310
323
def batch_to_ndarray (batch_samples , lod ):
@@ -320,7 +333,7 @@ def batch_to_ndarray(batch_samples, lod):
320
333
start += frame_num
321
334
return (batch_feature , batch_label )
322
335
323
- @suppress_complaints (verbose = self ._verbose )
336
+ @suppress_complaints (verbose = self ._verbose , notify = self . _force_exit )
324
337
def batch_assembling_task (sample_generator , batch_queue ):
325
338
batch_samples = []
326
339
lod = [0 ]
@@ -349,7 +362,7 @@ def batch_assembling_task(sample_generator, batch_queue):
349
362
assembling_thread .daemon = True
350
363
assembling_thread .start ()
351
364
352
- while True :
365
+ while self . _force_exit == False :
353
366
try :
354
367
batch_data = batch_queue .get_nowait ()
355
368
except Queue .Empty :
@@ -358,5 +371,3 @@ def batch_assembling_task(sample_generator, batch_queue):
358
371
if isinstance (batch_data , EpochEndSignal ):
359
372
break
360
373
yield batch_data
361
-
362
- assembling_thread .join ()
0 commit comments