Extending GeneratorEnqueuer to handle finite generators. #8104

Merged 7 commits on Oct 16, 2017

Changes from all commits

keras/utils/data_utils.py: 13 changes (11 additions, 2 deletions)

@@ -537,10 +537,13 @@ def stop(self, timeout=None):
 class GeneratorEnqueuer(SequenceEnqueuer):
     """Builds a queue out of a data generator.

+    The provided generator can be finite, in which case the class will throw
+    a `StopIteration` exception.
+
     Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

     # Arguments
-        generator: a generator function which endlessly yields data
+        generator: a generator function which yields data
         use_multiprocessing: use multiprocessing if True, otherwise threading
         wait_time: time to sleep in-between calls to `put()`
         random_seed: Initial seed for workers,
@@ -576,6 +579,8 @@ def data_generator_task():
                         self.queue.put(generator_output)
                     else:
                         time.sleep(self.wait_time)
+                except StopIteration:
+                    break
                 except Exception:
                     self._stop_event.set()
                     raise
@@ -648,4 +653,8 @@ def get(self):
             if inputs is not None:
                 yield inputs
             else:
-                time.sleep(self.wait_time)
+                all_finished = all([not thread.is_alive() for thread in self._threads])
+                if all_finished:
+                    raise StopIteration()
+                else:
+                    time.sleep(self.wait_time)
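
A minimal usage sketch of the new behaviour (not part of the diff; `make_finite_generator` is an illustrative helper, and the constructor and `start()` signatures are assumed to match the version of `keras.utils.data_utils` touched here): each worker now exits its loop on `StopIteration`, and `get()` raises `StopIteration` itself once every worker has finished and the queue is drained, so a plain `for` loop over `get()` terminates instead of blocking forever.

```python
# Minimal usage sketch, assuming the GeneratorEnqueuer API from
# keras.utils.data_utils as modified by this PR. Not part of the diff.
from keras.utils.data_utils import GeneratorEnqueuer


def make_finite_generator(n):
    # A finite generator: yields n items and then stops.
    for i in range(n):
        yield i


enqueuer = GeneratorEnqueuer(make_finite_generator(10), use_multiprocessing=False)
enqueuer.start(workers=1, max_queue_size=5)

collected = []
for item in enqueuer.get():
    # With this change, the loop ends once the worker has exhausted the
    # generator and the queue is empty; previously it would block forever.
    collected.append(item)

enqueuer.stop()
assert sorted(collected) == list(range(10))
```

Note that `get()` is itself a generator, so `raise StopIteration()` inside it ends the consumer's loop under the pre-PEP 479 semantics of the Python versions Keras supported at the time.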
tests/keras/utils/data_utils_test.py: 35 changes (35 additions, 0 deletions)

@@ -245,5 +245,40 @@ def test_ordered_enqueuer_fail_processes():
         next(gen_output)


+@threadsafe_generator
+def create_finite_generator_from_sequence_threads(ds):
+    for i in range(len(ds)):
+        yield ds[i]
+
+
+def create_finite_generator_from_sequence_pcs(ds):
+    for i in range(len(ds)):
+        yield ds[i]


+def test_finite_generator_enqueuer_threads():
+    enqueuer = GeneratorEnqueuer(create_finite_generator_from_sequence_threads(
+        TestSequence([3, 200, 200, 3])), use_multiprocessing=False)
+    enqueuer.start(3, 10)
+    gen_output = enqueuer.get()
+    acc = []
+    for output in gen_output:
+        acc.append(int(output[0, 0, 0, 0]))
+    assert len(set(acc) - set(range(100))) == 0, "Output is not the same"
+    enqueuer.stop()


+def test_finite_generator_enqueuer_processes():
+    enqueuer = GeneratorEnqueuer(create_finite_generator_from_sequence_pcs(
+        TestSequence([3, 200, 200, 3])), use_multiprocessing=True)
+    enqueuer.start(3, 10)
+    gen_output = enqueuer.get()
+    acc = []
+    for output in gen_output:
+        acc.append(int(output[0, 0, 0, 0]))
+    assert acc != list(range(100)), "Order was kept in GeneratorEnqueuer with processes"
+    enqueuer.stop()

Review discussion attached to the `enqueuer = GeneratorEnqueuer(create_finite_generator_from_sequence_pcs(` line:

Contributor Author: @Dref360 Isn't this tested directly in the for loop? The StopIteration causes the for loop to exit. Do you mean placing a next() call after the end of the loop to show that it throws an exception?

Contributor: Oh yeah, you're right, sorry.

Contributor Author: No worries, thanks so much for taking the time to read it. You can never be sure enough, especially when you change the default behaviour of a class. :)
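
For reference, the point settled in the discussion above could also be asserted explicitly. A sketch of what such a check might look like (not part of the PR; `finite_gen` and the test name are illustrative, and it relies on the pre-PEP 479 behaviour where `StopIteration` raised inside `get()` simply ends iteration):

```python
# Illustrative sketch, not part of the PR: make the review point explicit by
# asserting that the exhausted enqueuer generator raises StopIteration.
import pytest

from keras.utils.data_utils import GeneratorEnqueuer


def finite_gen(n):
    for i in range(n):
        yield i


def test_finite_generator_raises_after_exhaustion():
    enqueuer = GeneratorEnqueuer(finite_gen(5), use_multiprocessing=False)
    enqueuer.start(1, 10)
    gen_output = enqueuer.get()
    consumed = list(gen_output)  # consumption stops when get() raises StopIteration
    assert sorted(consumed) == list(range(5))
    with pytest.raises(StopIteration):
        next(gen_output)  # an exhausted generator keeps raising StopIteration
    enqueuer.stop()
```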


 if __name__ == '__main__':
     pytest.main([__file__])
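
The threaded test relies on a `@threadsafe_generator` decorator defined elsewhere in `data_utils_test.py`; its definition is not part of this diff. A common shape for such a decorator, shown here only as an assumed sketch (the real one may differ), serializes `next()` calls with a lock so several enqueuer threads can safely pull from one generator:

```python
# Assumed sketch of a thread-safe generator wrapper; the actual
# threadsafe_generator used by data_utils_test.py may differ.
import threading


class ThreadsafeIter(object):
    """Wraps an iterator so that calls to next() are serialized by a lock."""

    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

    next = __next__  # Python 2 compatibility


def threadsafe_generator(f):
    # Decorator: make a generator function return a lock-protected iterator.
    def g(*args, **kwargs):
        return ThreadsafeIter(f(*args, **kwargs))
    return g
```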