@@ -364,18 +364,35 @@ def on_epoch_end(self):
364
364
"""
365
365
pass
366
366
367
# Module-level state shared between the enqueuer and its pool workers.
# NOTE(review): the Manager below is created when this module is imported,
# which starts a manager server process as an import side effect -- confirm
# this is acceptable on platforms that use the "spawn" start method.

# Sequence that `get_index` reads from in the current process.
_SHARED_SEQUENCE = None
# Manager that owns the dict proxy below.
_MANAGER = multiprocessing.Manager()
# Maps the PID of every worker that `_update_sequence` has updated to 0.
_SHARED_DICT = _MANAGER.dict()
367
371
368
def get_index(i):
    """Get the value from the Sequence at index `i`.

    The Sequence is read from the module-level `_SHARED_SEQUENCE` of the
    process executing this function instead of being passed in with
    every call.

    # Arguments
        i: index

    # Returns
        The value at index `i`.
    """
    # Read-only access to a module global needs no `global` declaration.
    return _SHARED_SEQUENCE[i]
384
+
385
+
386
def _update_sequence(seq):
    """Update current process with a new Sequence.

    The process's PID is recorded in `_SHARED_DICT` once it has been
    updated; calls made while the PID is still recorded leave the
    process's Sequence untouched.

    # Arguments
        seq: Sequence object
    """
    # `_SHARED_DICT` is only mutated, never rebound, so only the sequence
    # needs a `global` declaration.
    global _SHARED_SEQUENCE
    pid = multiprocessing.current_process().pid
    if pid not in _SHARED_DICT:
        _SHARED_SEQUENCE = seq
        _SHARED_DICT[pid] = 0
379
396
380
397
381
398
class SequenceEnqueuer (object ):
@@ -477,6 +494,7 @@ def start(self, workers=1, max_queue_size=10):
477
494
self .executor = multiprocessing .Pool (workers )
478
495
else :
479
496
self .executor = ThreadPool (workers )
497
+ self .workers = workers
480
498
self .queue = queue .Queue (max_queue_size )
481
499
self .stop_signal = threading .Event ()
482
500
self .run_thread = threading .Thread (target = self ._run )
@@ -486,17 +504,18 @@ def start(self, workers=1, max_queue_size=10):
486
504
def _run(self):
    """Submit one request per index to the executor and queue the results.

    Loops over epochs until the stop signal is set: every index of the
    Sequence is submitted to the executor and the returned asynchronous
    result object is put on the queue.
    """
    indices = list(range(len(self.sequence)))
    # Share the initial sequence before any request is submitted.
    self._send_sequence()
    while True:
        if self.shuffle:
            random.shuffle(indices)
        for idx in indices:
            if self.stop_signal.is_set():
                return
            future = self.executor.apply_async(get_index, (idx,))
            # Blocks while the queue is full.
            self.queue.put(future, block=True)
        # Call the internal on epoch end.
        # NOTE(review): requests submitted above are asynchronous and may
        # still be pending at this point -- confirm that is acceptable.
        self.sequence.on_epoch_end()
        # Update the pool with the sequence as it is after the epoch end.
        self._send_sequence()
500
519
501
520
def get (self ):
502
521
"""Creates a generator to extract data from the queue.
@@ -516,6 +535,19 @@ def get(self):
516
535
self .stop ()
517
536
raise StopIteration (e )
518
537
538
def _send_sequence(self):
    """Make the current Sequence available to every worker."""
    global _SHARED_SEQUENCE
    # Set it in this process too, for new processes that may spawn.
    _SHARED_SEQUENCE = self.sequence
    if self.use_multiprocessing:
        # Threads need none of this: they are from the same process, so
        # they already share the sequence.
        _SHARED_DICT.clear()
        while True:
            everyone_updated = len(_SHARED_DICT) >= self.workers
            if everyone_updated or self.stop_signal.is_set():
                # We're done with the update.
                break
            # Ask the pool to update till everyone is updated.
            self.executor.apply(_update_sequence, args=(self.sequence,))
550
+
519
551
def stop (self , timeout = None ):
520
552
"""Stops running threads and wait for them to exit, if necessary.
521
553
0 commit comments