Skip to content

Commit 46ea84e

Browse files
committed
bpo-29842: Introduce a prefetch parameter to Executor.map to handle large iterators
1 parent 28bb296 commit 46ea84e

File tree

5 files changed

+69
-18
lines changed

5 files changed

+69
-18
lines changed

Doc/library/concurrent.futures.rst

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ Executor Objects
3939
future = executor.submit(pow, 323, 1235)
4040
print(future.result())
4141

42-
.. method:: map(fn, *iterables, timeout=None, chunksize=1)
42+
.. method:: map(fn, *iterables, timeout=None, chunksize=1, prefetch=None)
4343

4444
Similar to :func:`map(fn, *iterables) <map>` except:
4545

@@ -65,9 +65,16 @@ Executor Objects
6565
performance compared to the default size of 1. With
6666
:class:`ThreadPoolExecutor`, *chunksize* has no effect.
6767

68+
By default, all tasks are queued. An explicit *prefetch* count may be
69+
provided to specify how many extra tasks, beyond the number of workers,
70+
should be queued.
71+
6872
.. versionchanged:: 3.5
6973
Added the *chunksize* argument.
7074

75+
.. versionchanged:: 3.13
76+
Added the *prefetch* argument.
77+
7178
.. method:: shutdown(wait=True, *, cancel_futures=False)
7279

7380
Signal the executor that it should free any resources that it is using

Lib/concurrent/futures/_base.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import threading
99
import time
1010
import types
11+
import weakref
1112

1213
FIRST_COMPLETED = 'FIRST_COMPLETED'
1314
FIRST_EXCEPTION = 'FIRST_EXCEPTION'
@@ -569,6 +570,15 @@ def set_exception(self, exception):
569570
class Executor(object):
570571
"""This is an abstract base class for concrete asynchronous executors."""
571572

573+
def __init__(self, max_workers=None):
574+
"""Initializes a new Executor instance.
575+
576+
Args:
577+
max_workers: The maximum number of workers that can be used to
578+
execute the given calls.
579+
"""
580+
self._max_workers = max_workers
581+
572582
def submit(self, fn, /, *args, **kwargs):
573583
"""Submits a callable to be executed with the given arguments.
574584
@@ -580,7 +590,7 @@ def submit(self, fn, /, *args, **kwargs):
580590
"""
581591
raise NotImplementedError()
582592

def map(self, fn, *iterables, timeout=None, chunksize=1, prefetch=None):
    """Returns an iterator equivalent to map(fn, iter).

    Args:
        fn: A callable that will take as many arguments as there are
            passed iterables.
        timeout: The maximum number of seconds to wait. If None, then there
            is no limit on the wait time.
        chunksize: The size of the chunks the iterable will be broken into
            before being passed to a child process. This argument is only
            used by ProcessPoolExecutor; it is ignored by
            ThreadPoolExecutor.
        prefetch: The number of chunks to queue beyond the number of
            workers on the executor. If None, all chunks are queued.

    Returns:
        An iterator equivalent to: map(func, *iterables) but the calls may
        be evaluated out-of-order.

    Raises:
        TimeoutError: If the entire result iterator could not be generated
            before the given timeout.
        ValueError: If prefetch is a negative number.
        Exception: If fn(*args) raises for any values.
    """
    if timeout is not None:
        end_time = timeout + time.monotonic()
    if prefetch is not None and prefetch < 0:
        raise ValueError("prefetch count may not be negative")

    all_args = zip(*iterables)
    if prefetch is None:
        # Backward-compatible default: eagerly submit every task.
        fs = collections.deque(self.submit(fn, *args) for args in all_args)
    else:
        # Submit only max_workers + prefetch tasks up front; the rest are
        # submitted lazily as results are consumed.  This bounds memory and
        # makes map() usable with infinite input iterators.
        fs = collections.deque()
        for idx, args in enumerate(all_args):
            if idx >= self._max_workers + prefetch:
                break
            fs.append(self.submit(fn, *args))

    # Yield must be hidden in closure so that the futures are submitted
    # before the first iterator value is required.
    def result_iterator(all_args, executor_ref):
        try:
            while fs:
                # Careful not to keep a reference to the popped future
                if timeout is None:
                    yield _result_or_cancel(fs.popleft())
                else:
                    yield _result_or_cancel(
                        fs.popleft(), end_time - time.monotonic()
                    )

                # Submit the next task, if any, and only if the executor
                # still exists.  Dereference the weakref exactly once:
                # calling executor_ref() twice is racy -- the executor
                # could be garbage-collected between the truth test and a
                # second call, which would return None and raise
                # AttributeError on .submit().
                executor = executor_ref()
                if executor is not None:
                    try:
                        args = next(all_args)
                    except StopIteration:
                        pass
                    else:
                        fs.append(executor.submit(fn, *args))
                    # Drop the strong reference before the next yield so
                    # this suspended generator does not keep the executor
                    # alive on its own.
                    del executor
        finally:
            for future in fs:
                future.cancel()
    return result_iterator(all_args, weakref.ref(self))
626657

627658
def shutdown(self, wait=True, *, cancel_futures=False):
628659
"""Clean-up the resources associated with the Executor.

Lib/concurrent/futures/process.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -656,19 +656,17 @@ def __init__(self, max_workers=None, mp_context=None,
656656
_check_system_limits()
657657

658658
if max_workers is None:
659-
self._max_workers = os.process_cpu_count() or 1
659+
max_workers = os.process_cpu_count() or 1
660660
if sys.platform == 'win32':
661-
self._max_workers = min(_MAX_WINDOWS_WORKERS,
662-
self._max_workers)
661+
max_workers = min(_MAX_WINDOWS_WORKERS, max_workers)
663662
else:
664663
if max_workers <= 0:
665664
raise ValueError("max_workers must be greater than 0")
666665
elif (sys.platform == 'win32' and
667-
max_workers > _MAX_WINDOWS_WORKERS):
666+
max_workers > _MAX_WINDOWS_WORKERS):
668667
raise ValueError(
669668
f"max_workers must be <= {_MAX_WINDOWS_WORKERS}")
670-
671-
self._max_workers = max_workers
669+
super().__init__(max_workers)
672670

673671
if mp_context is None:
674672
if max_tasks_per_child is not None:
@@ -812,7 +810,7 @@ def submit(self, fn, /, *args, **kwargs):
812810
return f
813811
submit.__doc__ = _base.Executor.submit.__doc__
814812

815-
def map(self, fn, *iterables, timeout=None, chunksize=1):
813+
def map(self, fn, *iterables, timeout=None, chunksize=1, prefetch=None):
816814
"""Returns an iterator equivalent to map(fn, iter).
817815
818816
Args:
@@ -823,6 +821,8 @@ def map(self, fn, *iterables, timeout=None, chunksize=1):
823821
chunksize: If greater than one, the iterables will be chopped into
824822
chunks of size chunksize and submitted to the process pool.
825823
If set to one, the items in the list will be sent one at a time.
824+
prefetch: The number of chunks to queue beyond the number of
825+
workers on the executor. If None, all chunks are queued.
826826
827827
Returns:
828828
An iterator equivalent to: map(func, *iterables) but the calls may
@@ -838,7 +838,7 @@ def map(self, fn, *iterables, timeout=None, chunksize=1):
838838

839839
results = super().map(partial(_process_chunk, fn),
840840
itertools.batched(zip(*iterables), chunksize),
841-
timeout=timeout)
841+
timeout=timeout, prefetch=prefetch)
842842
return _chain_from_iterable_of_lists(results)
843843

844844
def shutdown(self, wait=True, *, cancel_futures=False):

Lib/concurrent/futures/thread.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def __init__(self, max_workers=None, thread_name_prefix='',
149149
if initializer is not None and not callable(initializer):
150150
raise TypeError("initializer must be a callable")
151151

152-
self._max_workers = max_workers
152+
super().__init__(max_workers)
153153
self._work_queue = queue.SimpleQueue()
154154
self._idle_semaphore = threading.Semaphore(0)
155155
self._threads = set()

Lib/test/test_concurrent_futures/test_thread_pool.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ def record_finished(n):
2323
self.executor.shutdown(wait=True)
2424
self.assertCountEqual(finished, range(10))
2525

def test_map_on_infinite_iterator(self):
    """map() with prefetch must not exhaust an unbounded input iterator."""
    import itertools

    def identity(x):
        return x

    results = self.executor.map(identity, itertools.count(0), prefetch=1)
    # Obtaining a single result proves work is dispatched lazily:
    # with an infinite input, eager submission would never return.
    first = next(results)
    # Closing the generator runs its finally-clause, cancelling any
    # futures that are still queued.
    results.close()

    self.assertEqual(first, 0)
38+
2639
def test_default_workers(self):
2740
executor = self.executor_type()
2841
expected = min(32, (os.process_cpu_count() or 1) + 4)

0 commit comments

Comments (0)