Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster sequence #8039

Merged
merged 3 commits
Oct 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,18 +364,35 @@ def on_epoch_end(self):
"""
pass

# Globals shared with the worker processes spawned by the pool.
# Each worker process holds its own copy of _SHARED_SEQUENCE (installed by
# _update_sequence); only _SHARED_DICT is a true cross-process proxy object.
_SHARED_SEQUENCE = None
# NOTE(review): creating a Manager at import time starts a server process as
# a side effect of importing this module — confirm this is intended.
_MANAGER = multiprocessing.Manager()
# Proxy dict mapping worker pid -> 0, used to track which pool workers have
# already received the current sequence (see _update_sequence below).
_SHARED_DICT = _MANAGER.dict()

def get_index(i):
    """Get the value from the shared Sequence at index `i`.

    Defined at module level (instead of as a method) so that it can be
    pickled and dispatched to `multiprocessing.Pool` workers; each worker
    reads from its process-local `_SHARED_SEQUENCE`.

    # Arguments
        i: index

    # Returns
        The value at index `i`.
    """
    # Reading a module global needs no `global` declaration.
    return _SHARED_SEQUENCE[i]


def _update_sequence(seq):
    """Update the current worker process with a new Sequence.

    Installs `seq` as this process's `_SHARED_SEQUENCE` and records the
    worker's pid in `_SHARED_DICT` so the parent can tell the worker has
    been updated. Idempotent per update round: a worker already present
    in `_SHARED_DICT` is left untouched.

    # Arguments
        seq: Sequence object
    """
    global _SHARED_SEQUENCE
    # Hoist the pid lookup; `_SHARED_DICT` is only mutated in place, so no
    # `global` declaration is needed for it.
    pid = multiprocessing.current_process().pid
    if pid not in _SHARED_DICT:
        _SHARED_SEQUENCE = seq
        _SHARED_DICT[pid] = 0


class SequenceEnqueuer(object):
Expand Down Expand Up @@ -477,6 +494,7 @@ def start(self, workers=1, max_queue_size=10):
self.executor = multiprocessing.Pool(workers)
else:
self.executor = ThreadPool(workers)
self.workers = workers
self.queue = queue.Queue(max_queue_size)
self.stop_signal = threading.Event()
self.run_thread = threading.Thread(target=self._run)
Expand All @@ -486,17 +504,18 @@ def start(self, workers=1, max_queue_size=10):
def _run(self):
    """Submit requests to the executor and queue the `Future` objects."""
    sequence = list(range(len(self.sequence)))
    self._send_sequence()  # Share the initial sequence with the workers.
    while True:
        if self.shuffle:
            random.shuffle(sequence)
        for i in sequence:
            if self.stop_signal.is_set():
                return
            # Only the (cheaply picklable) index is sent to the pool;
            # workers read the data from their process-local
            # _SHARED_SEQUENCE instead of receiving the whole sequence.
            self.queue.put(
                self.executor.apply_async(get_index, (i,)), block=True)
        # Call the internal on epoch end.
        self.sequence.on_epoch_end()
        self._send_sequence()  # Update the pool with the new epoch state.

def get(self):
"""Creates a generator to extract data from the queue.
Expand All @@ -516,6 +535,19 @@ def get(self):
self.stop()
raise StopIteration(e)

def _send_sequence(self):
    """Send the current Sequence to all workers.

    Blocks until every pool worker has picked up the new sequence, or
    until the enqueuer is being stopped.
    """
    global _SHARED_SEQUENCE
    _SHARED_SEQUENCE = self.sequence  # For new processes that may spawn
    if not self.use_multiprocessing:
        # Threads are from the same process so they already share the sequence.
        return
    # Reset the per-round bookkeeping, then keep issuing synchronous
    # apply() calls until each of the `self.workers` pool processes has
    # registered its pid in _SHARED_DICT (see _update_sequence).
    # NOTE(review): this presumes apply() eventually lands on every
    # worker — confirm it cannot spin indefinitely if a worker dies.
    _SHARED_DICT.clear()
    while len(_SHARED_DICT) < self.workers and not self.stop_signal.is_set():
        # Ask the pool to update till everyone is updated.
        self.executor.apply(_update_sequence, args=(self.sequence,))
    # We're done with the update

def stop(self, timeout=None):
"""Stops running threads and wait for them to exit, if necessary.

Expand Down
12 changes: 6 additions & 6 deletions tests/keras/utils/data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def g(*a, **kw):
return g


class TestSequence(Sequence):
class DummySequence(Sequence):
def __init__(self, shape):
self.shape = shape

Expand Down Expand Up @@ -149,7 +149,7 @@ def create_generator_from_sequence_pcs(ds):

def test_generator_enqueuer_threads():
enqueuer = GeneratorEnqueuer(create_generator_from_sequence_threads(
TestSequence([3, 200, 200, 3])), use_multiprocessing=False)
DummySequence([3, 200, 200, 3])), use_multiprocessing=False)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
acc = []
Expand All @@ -166,7 +166,7 @@ def test_generator_enqueuer_threads():

def test_generator_enqueuer_processes():
enqueuer = GeneratorEnqueuer(create_generator_from_sequence_pcs(
TestSequence([3, 200, 200, 3])), use_multiprocessing=True)
DummySequence([3, 200, 200, 3])), use_multiprocessing=True)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
acc = []
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_generator_enqueuer_fail_processes():


def test_ordered_enqueuer_threads():
enqueuer = OrderedEnqueuer(TestSequence([3, 200, 200, 3]), use_multiprocessing=False)
enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]), use_multiprocessing=False)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
acc = []
Expand All @@ -206,7 +206,7 @@ def test_ordered_enqueuer_threads():


def test_ordered_enqueuer_threads_not_ordered():
enqueuer = OrderedEnqueuer(TestSequence([3, 200, 200, 3]),
enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]),
use_multiprocessing=False,
shuffle=True)
enqueuer.start(3, 10)
Expand All @@ -219,7 +219,7 @@ def test_ordered_enqueuer_threads_not_ordered():


def test_ordered_enqueuer_processes():
enqueuer = OrderedEnqueuer(TestSequence([3, 200, 200, 3]), use_multiprocessing=True)
enqueuer = OrderedEnqueuer(DummySequence([3, 200, 200, 3]), use_multiprocessing=True)
enqueuer.start(3, 10)
gen_output = enqueuer.get()
acc = []
Expand Down