1515import data_utils .augmentor .trans_mean_variance_norm as trans_mean_variance_norm
1616import data_utils .augmentor .trans_add_delta as trans_add_delta
1717from data_utils .util import suppress_complaints , suppress_signal
18+ from data_utils .util import CriticalException , ForceExitWrapper
1819
1920
2021class SampleInfo (object ):
@@ -166,6 +167,7 @@ def __init__(self,
166167 self ._batch_buffer_size = batch_buffer_size
167168 self ._process_num = process_num
168169 self ._verbose = verbose
170+ self ._force_exit = ForceExitWrapper (self ._manager .Value ('b' , False ))
169171
170172 def generate_bucket_list (self , is_shuffle ):
171173 if self ._block_info_list is None :
@@ -204,15 +206,19 @@ def _sample_generator(self):
204206 sample_queue = self ._manager .Queue (self ._sample_buffer_size )
205207 self ._order_id = 0
206208
207- @suppress_complaints (verbose = self ._verbose )
209+ @suppress_complaints (verbose = self ._verbose , notify = self . _force_exit )
208210 def ordered_feeding_task (sample_info_queue ):
209211 for sample_info_bucket in self ._bucket_list :
210- sample_info_list = sample_info_bucket .generate_sample_info_list (
211- )
212- self ._rng .shuffle (sample_info_list ) # do shuffle here
213- for sample_info in sample_info_list :
214- sample_info_queue .put ((sample_info , self ._order_id ))
215- self ._order_id += 1
212+ try :
213+ sample_info_list = \
214+ sample_info_bucket .generate_sample_info_list ()
215+ except Exception as e :
216+ raise CriticalException (e )
217+ else :
218+ self ._rng .shuffle (sample_info_list ) # do shuffle here
219+ for sample_info in sample_info_list :
220+ sample_info_queue .put ((sample_info , self ._order_id ))
221+ self ._order_id += 1
216222
217223 for i in xrange (self ._process_num ):
218224 sample_info_queue .put (EpochEndSignal ())
@@ -222,18 +228,21 @@ def ordered_feeding_task(sample_info_queue):
222228 feeding_thread .daemon = True
223229 feeding_thread .start ()
224230
225- @suppress_complaints (verbose = self ._verbose )
231+ @suppress_complaints (verbose = self ._verbose , notify = self . _force_exit )
226232 def ordered_processing_task (sample_info_queue , sample_queue , out_order ):
227233 if self ._verbose == 0 :
228234 signal .signal (signal .SIGTERM , suppress_signal )
229235 signal .signal (signal .SIGINT , suppress_signal )
230236
231237 def read_bytes (fpath , start , size ):
232- f = open (fpath , 'r' )
233- f .seek (start , 0 )
234- binary_bytes = f .read (size )
235- f .close ()
236- return binary_bytes
238+ try :
239+ f = open (fpath , 'r' )
240+ f .seek (start , 0 )
241+ binary_bytes = f .read (size )
242+ f .close ()
243+ return binary_bytes
244+ except Exception as e :
245+ raise CriticalException (e )
237246
238247 ins = sample_info_queue .get ()
239248
@@ -295,16 +304,20 @@ def read_bytes(fpath, start, size):
295304
296305 finished_process_num = 0
297306
298- while finished_process_num < self ._process_num :
299- sample = sample_queue .get ()
300- if isinstance (sample , EpochEndSignal ):
301- finished_process_num += 1
302- continue
303- yield sample
307+ while self ._force_exit == False :
308+ try :
309+ sample = sample_queue .get_nowait ()
310+ except Queue .Empty :
311+ time .sleep (0.001 )
312+ else :
313+ if isinstance (sample , EpochEndSignal ):
314+ finished_process_num += 1
315+ if finished_process_num >= self ._process_num :
316+ break
317+ else :
318+ continue
304319
305- feeding_thread .join ()
306- for w in workers :
307- w .join ()
320+ yield sample
308321
309322 def batch_iterator (self , batch_size , minimum_batch_size ):
310323 def batch_to_ndarray (batch_samples , lod ):
@@ -320,7 +333,7 @@ def batch_to_ndarray(batch_samples, lod):
320333 start += frame_num
321334 return (batch_feature , batch_label )
322335
323- @suppress_complaints (verbose = self ._verbose )
336+ @suppress_complaints (verbose = self ._verbose , notify = self . _force_exit )
324337 def batch_assembling_task (sample_generator , batch_queue ):
325338 batch_samples = []
326339 lod = [0 ]
@@ -349,7 +362,7 @@ def batch_assembling_task(sample_generator, batch_queue):
349362 assembling_thread .daemon = True
350363 assembling_thread .start ()
351364
352- while True :
365+ while self . _force_exit == False :
353366 try :
354367 batch_data = batch_queue .get_nowait ()
355368 except Queue .Empty :
@@ -358,5 +371,3 @@ def batch_assembling_task(sample_generator, batch_queue):
358371 if isinstance (batch_data , EpochEndSignal ):
359372 break
360373 yield batch_data
361-
362- assembling_thread .join ()
0 commit comments