Skip to content

Commit

Permalink
Faster sequence (keras-team#8039)
Browse files Browse the repository at this point in the history
* Make Sequence faster

* Don't update if we're done

* Change according to review
  • Loading branch information
Frédéric Branchaud-Charron authored and fchollet committed Oct 2, 2017
1 parent ab57009 commit 19e1be2
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
44 changes: 38 additions & 6 deletions keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,18 +364,35 @@ def on_epoch_end(self):
"""
pass

# Global variables to be shared across processes
_SHARED_SEQUENCE = None
_MANAGER = multiprocessing.Manager()
_SHARED_DICT = _MANAGER.dict()

def get_index(ds, i):
"""Quick fix for Python2, otherwise, it cannot be pickled.

def get_index(i):
"""Get the value from the Sequence at index `i`.
# Arguments
ds: a Sequence object
i: index
# Returns
The value at index `i`.
"""
return ds[i]
global _SHARED_SEQUENCE
return _SHARED_SEQUENCE[i]


def _update_sequence(seq):
    """Update the current worker process with a new Sequence.

    Stores `seq` in this process's module-global `_SHARED_SEQUENCE` and
    registers the worker's pid in the shared `_SHARED_DICT` so the parent
    can tell the worker has been updated. A pid that is already registered
    is skipped ("don't update if we're done").

    # Arguments
        seq: Sequence object to share with this worker.
    """
    # Only `_SHARED_SEQUENCE` is rebound; `_SHARED_DICT` is mutated in
    # place, so it needs no `global` declaration.
    global _SHARED_SEQUENCE
    pid = multiprocessing.current_process().pid  # hoisted: used twice below
    if pid not in _SHARED_DICT:
        _SHARED_SEQUENCE = seq
        _SHARED_DICT[pid] = 0


class SequenceEnqueuer(object):
Expand Down Expand Up @@ -477,6 +494,7 @@ def start(self, workers=1, max_queue_size=10):
self.executor = multiprocessing.Pool(workers)
else:
self.executor = ThreadPool(workers)
self.workers = workers
self.queue = queue.Queue(max_queue_size)
self.stop_signal = threading.Event()
self.run_thread = threading.Thread(target=self._run)
Expand All @@ -486,17 +504,18 @@ def start(self, workers=1, max_queue_size=10):
def _run(self):
"""Function to submit request to the executor and queue the `Future` objects."""
sequence = list(range(len(self.sequence)))
self._send_sequence() # Share the initial sequence
while True:
if self.shuffle:
random.shuffle(sequence)
for i in sequence:
if self.stop_signal.is_set():
return
self.queue.put(
self.executor.apply_async(get_index,
(self.sequence, i)), block=True)
self.executor.apply_async(get_index, (i,)), block=True)
# Call the internal on epoch end.
self.sequence.on_epoch_end()
self._send_sequence() # Update the pool

def get(self):
"""Creates a generator to extract data from the queue.
Expand All @@ -516,6 +535,19 @@ def get(self):
self.stop()
raise StopIteration(e)

def _send_sequence(self):
"""Send current Sequence to all workers."""
global _SHARED_SEQUENCE
_SHARED_SEQUENCE = self.sequence # For new processes that may spawn
if not self.use_multiprocessing:
# Threads are from the same process so they already share the sequence.
return
_SHARED_DICT.clear()
while len(_SHARED_DICT) < self.workers and not self.stop_signal.is_set():
# Ask the pool to update till everyone is updated.
self.executor.apply(_update_sequence, args=(self.sequence,))
# We're done with the update

def stop(self, timeout=None):
"""Stops running threads and wait for them to exit, if necessary.
Expand Down
12 changes: 6 additions & 6 deletions tests/keras/utils/data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def g(*a, **kw):
return g


class TestSequence(Sequence):
class DummySequence(Sequence):
def __init__(self, shape):
self.shape = shape

Expand Down Expand Up @@ -149,7 +149,7 @@ def create_generator_from_sequence_pcs(ds):

def test_generator_enqueuer_threads():
enqueuer = GeneratorEnqueuer(create_generator_from_sequence_threads(
TestSequence([3, 200, 200, 3])), use_multiprocessing=False)
DummySequence([3, 200, 200, 3])), use_multiprocessing=False)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
acc = []
Expand All @@ -166,7 +166,7 @@ def test_generator_enqueuer_threads():

def test_generator_enqueuer_processes():
enqueuer = GeneratorEnqueuer(create_generator_from_sequence_pcs(
TestSequence([3, 200, 200, 3])), use_multiprocessing=True)
DummySequence([3, 200, 200, 3])), use_multiprocessing=True)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
acc = []
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_generator_enqueuer_fail_processes():


def test_ordered_enqueuer_threads():
enqueuer = OrderedEnqueuer(TestSequence([3, 200, 200, 3]), use_multiprocessing=False)
enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]), use_multiprocessing=False)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
acc = []
Expand All @@ -206,7 +206,7 @@ def test_ordered_enqueuer_threads():


def test_ordered_enqueuer_threads_not_ordered():
enqueuer = OrderedEnqueuer(TestSequence([3, 200, 200, 3]),
enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]),
use_multiprocessing=False,
shuffle=True)
enqueuer.start(3, 10)
Expand All @@ -219,7 +219,7 @@ def test_ordered_enqueuer_threads_not_ordered():


def test_ordered_enqueuer_processes():
enqueuer = OrderedEnqueuer(TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]), use_multiprocessing=True)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
acc = []
Expand Down

0 comments on commit 19e1be2

Please sign in to comment.