Skip to content

Commit 19e1be2

Browse files
Frédéric Branchaud-Charronfchollet
authored andcommitted
Faster sequence (keras-team#8039)
* Make Sequence faster * Don't update if we're done * Change according to review
1 parent ab57009 commit 19e1be2

File tree

2 files changed

+44
-12
lines changed

2 files changed

+44
-12
lines changed

keras/utils/data_utils.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -364,18 +364,35 @@ def on_epoch_end(self):
364364
"""
365365
pass
366366

367+
# Global variables to be shared across processes
368+
_SHARED_SEQUENCE = None
369+
_MANAGER = multiprocessing.Manager()
370+
_SHARED_DICT = _MANAGER.dict()
367371

368-
def get_index(ds, i):
369-
"""Quick fix for Python2, otherwise, it cannot be pickled.
372+
373+
def get_index(i):
374+
"""Get the value from the Sequence at index `i`.
370375
371376
# Arguments
372-
ds: a Sequence object
373377
i: index
374378
375379
# Returns
376380
The value at index `i`.
377381
"""
378-
return ds[i]
382+
global _SHARED_SEQUENCE
383+
return _SHARED_SEQUENCE[i]
384+
385+
386+
def _update_sequence(seq):
387+
"""Update current process with a new Sequence.
388+
389+
# Arguments
390+
seq: Sequence object
391+
"""
392+
global _SHARED_SEQUENCE, _SHARED_DICT
393+
if not multiprocessing.current_process().pid in _SHARED_DICT:
394+
_SHARED_SEQUENCE = seq
395+
_SHARED_DICT[multiprocessing.current_process().pid] = 0
379396

380397

381398
class SequenceEnqueuer(object):
@@ -477,6 +494,7 @@ def start(self, workers=1, max_queue_size=10):
477494
self.executor = multiprocessing.Pool(workers)
478495
else:
479496
self.executor = ThreadPool(workers)
497+
self.workers = workers
480498
self.queue = queue.Queue(max_queue_size)
481499
self.stop_signal = threading.Event()
482500
self.run_thread = threading.Thread(target=self._run)
@@ -486,17 +504,18 @@ def start(self, workers=1, max_queue_size=10):
486504
def _run(self):
487505
"""Function to submit request to the executor and queue the `Future` objects."""
488506
sequence = list(range(len(self.sequence)))
507+
self._send_sequence() # Share the initial sequence
489508
while True:
490509
if self.shuffle:
491510
random.shuffle(sequence)
492511
for i in sequence:
493512
if self.stop_signal.is_set():
494513
return
495514
self.queue.put(
496-
self.executor.apply_async(get_index,
497-
(self.sequence, i)), block=True)
515+
self.executor.apply_async(get_index, (i,)), block=True)
498516
# Call the internal on epoch end.
499517
self.sequence.on_epoch_end()
518+
self._send_sequence() # Update the pool
500519

501520
def get(self):
502521
"""Creates a generator to extract data from the queue.
@@ -516,6 +535,19 @@ def get(self):
516535
self.stop()
517536
raise StopIteration(e)
518537

538+
def _send_sequence(self):
539+
"""Send current Sequence to all workers."""
540+
global _SHARED_SEQUENCE
541+
_SHARED_SEQUENCE = self.sequence # For new processes that may spawn
542+
if not self.use_multiprocessing:
543+
# Threads are from the same process so they already share the sequence.
544+
return
545+
_SHARED_DICT.clear()
546+
while len(_SHARED_DICT) < self.workers and not self.stop_signal.is_set():
547+
# Ask the pool to update till everyone is updated.
548+
self.executor.apply(_update_sequence, args=(self.sequence,))
549+
# We're done with the update
550+
519551
def stop(self, timeout=None):
520552
"""Stops running threads and wait for them to exit, if necessary.
521553

tests/keras/utils/data_utils_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def g(*a, **kw):
111111
return g
112112

113113

114-
class TestSequence(Sequence):
114+
class DummySequence(Sequence):
115115
def __init__(self, shape):
116116
self.shape = shape
117117

@@ -149,7 +149,7 @@ def create_generator_from_sequence_pcs(ds):
149149

150150
def test_generator_enqueuer_threads():
151151
enqueuer = GeneratorEnqueuer(create_generator_from_sequence_threads(
152-
TestSequence([3, 200, 200, 3])), use_multiprocessing=False)
152+
DummySequence([3, 200, 200, 3])), use_multiprocessing=False)
153153
enqueuer.start(3, 10)
154154
gen_output = enqueuer.get()
155155
acc = []
@@ -166,7 +166,7 @@ def test_generator_enqueuer_threads():
166166

167167
def test_generator_enqueuer_processes():
168168
enqueuer = GeneratorEnqueuer(create_generator_from_sequence_pcs(
169-
TestSequence([3, 200, 200, 3])), use_multiprocessing=True)
169+
DummySequence([3, 200, 200, 3])), use_multiprocessing=True)
170170
enqueuer.start(3, 10)
171171
gen_output = enqueuer.get()
172172
acc = []
@@ -195,7 +195,7 @@ def test_generator_enqueuer_fail_processes():
195195

196196

197197
def test_ordered_enqueuer_threads():
198-
enqueuer = OrderedEnqueuer(TestSequence([3, 200, 200, 3]), use_multiprocessing=False)
198+
enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]), use_multiprocessing=False)
199199
enqueuer.start(3, 10)
200200
gen_output = enqueuer.get()
201201
acc = []
@@ -206,7 +206,7 @@ def test_ordered_enqueuer_threads():
206206

207207

208208
def test_ordered_enqueuer_threads_not_ordered():
209-
enqueuer = OrderedEnqueuer(TestSequence([3, 200, 200, 3]),
209+
enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]),
210210
use_multiprocessing=False,
211211
shuffle=True)
212212
enqueuer.start(3, 10)
@@ -219,7 +219,7 @@ def test_ordered_enqueuer_threads_not_ordered():
219219

220220

221221
def test_ordered_enqueuer_processes():
222-
enqueuer = OrderedEnqueuer(TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
222+
enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]), use_multiprocessing=True)
223223
enqueuer.start(3, 10)
224224
gen_output = enqueuer.get()
225225
acc = []

0 commit comments

Comments
 (0)