Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-124309: Modernize the staggered_race implementation to support eager task factories #124390

Merged
Merged
2 changes: 1 addition & 1 deletion Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,7 +1144,7 @@ async def create_connection(
(functools.partial(self._connect_sock,
exceptions, addrinfo, laddr_infos)
for addrinfo in infos),
happy_eyeballs_delay, loop=self)
happy_eyeballs_delay)

if sock is None:
exceptions = [exc for sub in exceptions for exc in sub]
Expand Down
86 changes: 26 additions & 60 deletions Lib/asyncio/staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

import contextlib

from . import events
from . import exceptions as exceptions_mod
from . import locks
from . import tasks
from . import taskgroups

class _Done(Exception):
pass

async def staggered_race(coro_fns, delay, *, loop=None):
"""Run coroutines with staggered start times and take the first to finish.
Expand Down Expand Up @@ -42,8 +43,6 @@ async def staggered_race(coro_fns, delay, *, loop=None):
delay: amount of time, in seconds, between starting coroutines. If
``None``, the coroutines will run sequentially.

loop: the event loop to use.

Returns:
tuple *(winner_result, winner_index, exceptions)* where

Expand All @@ -62,36 +61,20 @@ async def staggered_race(coro_fns, delay, *, loop=None):

"""
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
loop = loop or events.get_running_loop()
enum_coro_fns = enumerate(coro_fns)
winner_result = None
winner_index = None
exceptions = []
running_tasks = []

async def run_one_coro(previous_failed) -> None:
# Wait for the previous task to finish, or for delay seconds
if previous_failed is not None:
with contextlib.suppress(exceptions_mod.TimeoutError):
# Use asyncio.wait_for() instead of asyncio.wait() here, so
# that if we get cancelled at this point, Event.wait() is also
# cancelled, otherwise there will be a "Task destroyed but it is
# pending" later.
await tasks.wait_for(previous_failed.wait(), delay)
# Get the next coroutine to run
try:
this_index, coro_fn = next(enum_coro_fns)
except StopIteration:
return
# Start task that will run the next coroutine
this_failed = locks.Event()
next_task = loop.create_task(run_one_coro(this_failed))
running_tasks.append(next_task)
assert len(running_tasks) == this_index + 2
# Prepare place to put this coroutine's exceptions if not won
exceptions.append(None)
assert len(exceptions) == this_index + 1

if loop is not None:
import warnings
warnings._deprecated(
'loop',
'the {name!r} parameter is deprecated and slated for removal in '
'Python {remove}; it does nothing since 3.14',
ZeroIntensity marked this conversation as resolved.
Show resolved Hide resolved
remove=(3, 16),
)

async def run_one_coro(this_index, coro_fn, this_failed):
try:
result = await coro_fn()
except (SystemExit, KeyboardInterrupt):
Expand All @@ -105,34 +88,17 @@ async def run_one_coro(previous_failed) -> None:
assert winner_index is None
winner_index = this_index
winner_result = result
# Cancel all other tasks. We take care to not cancel the current
# task as well. If we do so, then since there is no `await` after
# here and CancelledError are usually thrown at one, we will
# encounter a curious corner case where the current task will end
# up as done() == True, cancelled() == False, exception() ==
# asyncio.CancelledError. This behavior is specified in
# https://bugs.python.org/issue30048
for i, t in enumerate(running_tasks):
if i != this_index:
t.cancel()

first_task = loop.create_task(run_one_coro(None))
running_tasks.append(first_task)
raise _Done

try:
# Wait for a growing list of tasks to all finish: poor man's version of
# curio's TaskGroup or trio's nursery
done_count = 0
while done_count != len(running_tasks):
done, _ = await tasks.wait(running_tasks)
done_count = len(done)
# If run_one_coro raises an unhandled exception, it's probably a
# programming error, and I want to see it.
if __debug__:
for d in done:
if d.done() and not d.cancelled() and d.exception():
raise d.exception()
return winner_result, winner_index, exceptions
finally:
# Make sure no tasks are left running if we leave this function
for t in running_tasks:
t.cancel()
async with taskgroups.TaskGroup() as tg:
for this_index, coro_fn in enumerate(coro_fns):
this_failed = locks.Event()
exceptions.append(None)
tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
with contextlib.suppress(TimeoutError):
await tasks.wait_for(this_failed.wait(), delay)
except* _Done:
pass

return winner_result, winner_index, exceptions
45 changes: 45 additions & 0 deletions Lib/test/test_asyncio/test_eager_task_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,51 @@ async def run():

self.run_coro(run())

# See GH-124309 for both of these
def test_staggered_race_with_eager_tasks(self):
async def fail():
await asyncio.sleep(0) # Dummy coroutine
raise ValueError("no good")

async def run():
winner, index, excs = await asyncio.staggered.staggered_race(
[
lambda: asyncio.sleep(2),
lambda: asyncio.sleep(1),
lambda: fail()
],
delay=0.25
)
self.assertIsNone(winner)
self.assertEqual(index, 1)
self.assertIsNone(excs[index])
self.assertIsInstance(excs[0], asyncio.CancelledError)
self.assertIsInstance(excs[2], ValueError)

self.run_coro(run())

def test_staggered_race_with_eager_tasks_no_delay(self):
async def fail():
raise ValueError("no good")

async def run():
winner, index, excs = await asyncio.staggered.staggered_race(
[
lambda: fail(),
lambda: asyncio.sleep(1),
lambda: asyncio.sleep(0),
],
delay=None
)
self.assertIsNone(winner)
self.assertEqual(index, 1)
self.assertIsNone(excs[index])
self.assertIsInstance(excs[0], ValueError)
self.assertEqual(len(excs), 2)

self.run_coro(run())



class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = tasks._PyTask
Expand Down
37 changes: 33 additions & 4 deletions Lib/test/test_asyncio/test_staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,45 @@ async def test_none_successful(self):
async def coro(index):
raise ValueError(index)

for delay in [None, 0, 0.1, 1]:
with self.subTest(delay=delay):
winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
delay=delay,
)

self.assertIs(winner, None)
self.assertIs(index, None)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
self.assertIsInstance(excs[1], ValueError)

async def test_long_delay_early_failure(self):
async def coro(index):
await asyncio.sleep(0) # Dummy coroutine for the 1 case
if index == 0:
await asyncio.sleep(0.1) # Dummy coroutine
raise ValueError(index)

return f'Res: {index}'

winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
delay=None,
delay=10,
)

self.assertIs(winner, None)
self.assertIs(index, None)
self.assertEqual(winner, 'Res: 1')
self.assertEqual(index, 1)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
self.assertIsInstance(excs[1], ValueError)
self.assertIsNone(excs[1])


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_task` with :attr:`asyncio.eager_task_factory`.
Loading