Skip to content

Commit

Permalink
Extending GeneratorEnqueuer to handle finite generators. (#8104)
Browse files Browse the repository at this point in the history
* Extending GeneratorEnqueuer to handle finite generators.

* Fixing coding styles.

* Fixing coding style.

* Fixing Docstring styles.

* Adding test for the case of finite generator.

* Removing trailing space on docstring

* Trimming extra spaces.
  • Loading branch information
datumbox authored and fchollet committed Oct 16, 2017
1 parent 7610c55 commit f940d22
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
13 changes: 11 additions & 2 deletions keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,13 @@ def stop(self, timeout=None):
class GeneratorEnqueuer(SequenceEnqueuer):
"""Builds a queue out of a data generator.
The provided generator can be finite in which case the class will throw
a `StopIteration` exception.
Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
# Arguments
generator: a generator function which endlessly yields data
generator: a generator function which yields data
use_multiprocessing: use multiprocessing if True, otherwise threading
wait_time: time to sleep in-between calls to `put()`
random_seed: Initial seed for workers,
Expand Down Expand Up @@ -576,6 +579,8 @@ def data_generator_task():
self.queue.put(generator_output)
else:
time.sleep(self.wait_time)
except StopIteration:
break
except Exception:
self._stop_event.set()
raise
Expand Down Expand Up @@ -648,4 +653,8 @@ def get(self):
if inputs is not None:
yield inputs
else:
time.sleep(self.wait_time)
all_finished = all([not thread.is_alive() for thread in self._threads])
if all_finished:
raise StopIteration()
else:
time.sleep(self.wait_time)
35 changes: 35 additions & 0 deletions tests/keras/utils/data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,5 +245,40 @@ def test_ordered_enqueuer_fail_processes():
next(gen_output)


@threadsafe_generator
def create_finite_generator_from_sequence_threads(ds):
    """Yield each element of `ds` exactly once, in index order, then stop.

    The function is decorated with `threadsafe_generator` (defined outside
    this excerpt) so that the resulting generator can be shared between
    worker threads.

    # Arguments
        ds: an indexable object supporting `len()` and integer `__getitem__`.

    # Yields
        `ds[0]`, `ds[1]`, ..., `ds[len(ds) - 1]`.
    """
    # Index explicitly rather than iterating `ds` directly: only `len()`
    # and `__getitem__` are relied upon here.
    position = 0
    total = len(ds)
    while position < total:
        yield ds[position]
        position += 1


def create_finite_generator_from_sequence_pcs(ds):
    """Yield each element of `ds` exactly once, in index order, then stop.

    Undecorated counterpart of the thread-oriented helper, used by the
    multiprocessing test.

    # Arguments
        ds: an indexable object supporting `len()` and integer `__getitem__`.

    # Yields
        `ds[0]`, `ds[1]`, ..., `ds[len(ds) - 1]`.
    """
    # Index explicitly rather than iterating `ds` directly: only `len()`
    # and `__getitem__` are relied upon here.
    position = 0
    total = len(ds)
    while position < total:
        yield ds[position]
        position += 1


def test_finite_generator_enqueuer_threads():
    """A thread-backed `GeneratorEnqueuer` must terminate on a finite generator.

    Iterating `enqueuer.get()` has to end by itself once the wrapped
    generator is exhausted, and every value produced must belong to
    `range(100)`.
    """
    enqueuer = GeneratorEnqueuer(
        create_finite_generator_from_sequence_threads(
            TestSequence([3, 200, 200, 3])),
        use_multiprocessing=False)
    # NOTE(review): positional args are presumably (workers, max_queue_size)
    # -- confirm against `GeneratorEnqueuer.start`.
    enqueuer.start(3, 10)
    try:
        # Each output is indexed as a 4-D array; its first scalar identifies
        # which item of the sequence it came from.
        acc = [int(output[0, 0, 0, 0]) for output in enqueuer.get()]
        # Subset check: no value outside the expected index range may appear.
        assert set(acc) <= set(range(100)), "Output is not the same"
    finally:
        # Always stop the workers, even when iteration or the assertion
        # fails, so they do not leak into subsequent tests.
        enqueuer.stop()


def test_finite_generator_enqueuer_processes():
    """A process-backed `GeneratorEnqueuer` must terminate on a finite generator.

    Iterating `enqueuer.get()` has to end by itself once the wrapped
    generator is exhausted. The collected values are expected to differ
    from the plain ordered list `0..99`.
    """
    enqueuer = GeneratorEnqueuer(
        create_finite_generator_from_sequence_pcs(
            TestSequence([3, 200, 200, 3])),
        use_multiprocessing=True)
    # NOTE(review): positional args are presumably (workers, max_queue_size)
    # -- confirm against `GeneratorEnqueuer.start`.
    enqueuer.start(3, 10)
    try:
        # Each output is indexed as a 4-D array; its first scalar identifies
        # which item of the sequence it came from.
        acc = [int(output[0, 0, 0, 0]) for output in enqueuer.get()]
        assert acc != list(range(100)), "Order was keep in GeneratorEnqueuer with processes"
    finally:
        # Always stop the workers, even when iteration or the assertion
        # fails, so worker processes are not left running.
        enqueuer.stop()


# Allow running this test module directly as a script.
if __name__ == '__main__':
    this_module = [__file__]
    pytest.main(this_module)

0 comments on commit f940d22

Please sign in to comment.