Skip to content

Commit

Permalink
[3.12] pythongh-124309: Modernize the staggered_race implementation…
Browse files Browse the repository at this point in the history
… to support e… (python#124574)

pythongh-124309: Modernize the `staggered_race` implementation to support eager task factories (python#124390)

Co-authored-by: Thomas Grainger <tagrain@gmail.com>
Co-authored-by: Jelle Zijlstra <jelle.zijlstra@gmail.com>
Co-authored-by: Carol Willing <carolcode@willingconsulting.com>
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
(cherry picked from commit de929f3)

Co-authored-by: Peter Bierma <zintensitydev@gmail.com>
  • Loading branch information
kumaraditya303 and ZeroIntensity authored Sep 26, 2024
1 parent 48359c5 commit 2b54a4e
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 73 deletions.
2 changes: 1 addition & 1 deletion Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ async def create_connection(
(functools.partial(self._connect_sock,
exceptions, addrinfo, laddr_infos)
for addrinfo in infos),
happy_eyeballs_delay, loop=self)
happy_eyeballs_delay)

if sock is None:
exceptions = [exc for sub in exceptions for exc in sub]
Expand Down
90 changes: 18 additions & 72 deletions Lib/asyncio/staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,15 @@
__all__ = 'staggered_race',

import contextlib
import typing

from . import events
from . import exceptions as exceptions_mod
from . import locks
from . import tasks
from . import taskgroups

class _Done(Exception):
pass

async def staggered_race(
coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]],
delay: typing.Optional[float],
*,
loop: events.AbstractEventLoop = None,
) -> typing.Tuple[
typing.Any,
typing.Optional[int],
typing.List[typing.Optional[Exception]]
]:
async def staggered_race(coro_fns, delay):
"""Run coroutines with staggered start times and take the first to finish.
This method takes an iterable of coroutine functions. The first one is
Expand Down Expand Up @@ -52,8 +43,6 @@ async def staggered_race(
delay: amount of time, in seconds, between starting coroutines. If
``None``, the coroutines will run sequentially.
loop: the event loop to use.
Returns:
tuple *(winner_result, winner_index, exceptions)* where
Expand All @@ -72,37 +61,11 @@ async def staggered_race(
"""
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
loop = loop or events.get_running_loop()
enum_coro_fns = enumerate(coro_fns)
winner_result = None
winner_index = None
exceptions = []
running_tasks = []

async def run_one_coro(
previous_failed: typing.Optional[locks.Event]) -> None:
# Wait for the previous task to finish, or for delay seconds
if previous_failed is not None:
with contextlib.suppress(exceptions_mod.TimeoutError):
# Use asyncio.wait_for() instead of asyncio.wait() here, so
# that if we get cancelled at this point, Event.wait() is also
# cancelled, otherwise there will be a "Task destroyed but it is
# pending" later.
await tasks.wait_for(previous_failed.wait(), delay)
# Get the next coroutine to run
try:
this_index, coro_fn = next(enum_coro_fns)
except StopIteration:
return
# Start task that will run the next coroutine
this_failed = locks.Event()
next_task = loop.create_task(run_one_coro(this_failed))
running_tasks.append(next_task)
assert len(running_tasks) == this_index + 2
# Prepare place to put this coroutine's exceptions if not won
exceptions.append(None)
assert len(exceptions) == this_index + 1

async def run_one_coro(this_index, coro_fn, this_failed):
try:
result = await coro_fn()
except (SystemExit, KeyboardInterrupt):
Expand All @@ -116,34 +79,17 @@ async def run_one_coro(
assert winner_index is None
winner_index = this_index
winner_result = result
# Cancel all other tasks. We take care to not cancel the current
# task as well. If we do so, then since there is no `await` after
# here and CancelledError are usually thrown at one, we will
# encounter a curious corner case where the current task will end
# up as done() == True, cancelled() == False, exception() ==
# asyncio.CancelledError. This behavior is specified in
# https://bugs.python.org/issue30048
for i, t in enumerate(running_tasks):
if i != this_index:
t.cancel()

first_task = loop.create_task(run_one_coro(None))
running_tasks.append(first_task)
raise _Done

try:
# Wait for a growing list of tasks to all finish: poor man's version of
# curio's TaskGroup or trio's nursery
done_count = 0
while done_count != len(running_tasks):
done, _ = await tasks.wait(running_tasks)
done_count = len(done)
# If run_one_coro raises an unhandled exception, it's probably a
# programming error, and I want to see it.
if __debug__:
for d in done:
if d.done() and not d.cancelled() and d.exception():
raise d.exception()
return winner_result, winner_index, exceptions
finally:
# Make sure no tasks are left running if we leave this function
for t in running_tasks:
t.cancel()
async with taskgroups.TaskGroup() as tg:
for this_index, coro_fn in enumerate(coro_fns):
this_failed = locks.Event()
exceptions.append(None)
tg.create_task(run_one_coro(this_index, coro_fn, this_failed))
with contextlib.suppress(TimeoutError):
await tasks.wait_for(this_failed.wait(), delay)
except* _Done:
pass

return winner_result, winner_index, exceptions
47 changes: 47 additions & 0 deletions Lib/test/test_asyncio/test_eager_task_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,53 @@ async def run():

self.run_coro(run())

def test_staggered_race_with_eager_tasks(self):
# See https://github.com/python/cpython/issues/124309

async def fail():
await asyncio.sleep(0)
raise ValueError("no good")

async def run():
winner, index, excs = await asyncio.staggered.staggered_race(
[
lambda: asyncio.sleep(2, result="sleep2"),
lambda: asyncio.sleep(1, result="sleep1"),
lambda: fail()
],
delay=0.25
)
self.assertEqual(winner, 'sleep1')
self.assertEqual(index, 1)
self.assertIsNone(excs[index])
self.assertIsInstance(excs[0], asyncio.CancelledError)
self.assertIsInstance(excs[2], ValueError)

self.run_coro(run())

def test_staggered_race_with_eager_tasks_no_delay(self):
# See https://github.com/python/cpython/issues/124309
async def fail():
raise ValueError("no good")

async def run():
winner, index, excs = await asyncio.staggered.staggered_race(
[
lambda: fail(),
lambda: asyncio.sleep(1, result="sleep1"),
lambda: asyncio.sleep(0, result="sleep0"),
],
delay=None
)
self.assertEqual(winner, 'sleep1')
self.assertEqual(index, 1)
self.assertIsNone(excs[index])
self.assertIsInstance(excs[0], ValueError)
self.assertEqual(len(excs), 2)

self.run_coro(run())



class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = tasks._PyTask
Expand Down
126 changes: 126 additions & 0 deletions Lib/test/test_asyncio/test_staggered.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import asyncio
import unittest
from asyncio.staggered import staggered_race

from test import support

support.requires_working_socket(module=True)


def tearDownModule():
asyncio.set_event_loop_policy(None)


class StaggeredTests(unittest.IsolatedAsyncioTestCase):
async def test_empty(self):
winner, index, excs = await staggered_race(
[],
delay=None,
)

self.assertIs(winner, None)
self.assertIs(index, None)
self.assertEqual(excs, [])

async def test_one_successful(self):
async def coro(index):
return f'Res: {index}'

winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
delay=None,
)

self.assertEqual(winner, 'Res: 0')
self.assertEqual(index, 0)
self.assertEqual(excs, [None])

async def test_first_error_second_successful(self):
async def coro(index):
if index == 0:
raise ValueError(index)
return f'Res: {index}'

winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
delay=None,
)

self.assertEqual(winner, 'Res: 1')
self.assertEqual(index, 1)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
self.assertIs(excs[1], None)

async def test_first_timeout_second_successful(self):
async def coro(index):
if index == 0:
await asyncio.sleep(10) # much bigger than delay
return f'Res: {index}'

winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
delay=0.1,
)

self.assertEqual(winner, 'Res: 1')
self.assertEqual(index, 1)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], asyncio.CancelledError)
self.assertIs(excs[1], None)

async def test_none_successful(self):
async def coro(index):
raise ValueError(index)

for delay in [None, 0, 0.1, 1]:
with self.subTest(delay=delay):
winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
delay=delay,
)

self.assertIs(winner, None)
self.assertIs(index, None)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
self.assertIsInstance(excs[1], ValueError)

async def test_long_delay_early_failure(self):
async def coro(index):
await asyncio.sleep(0) # Dummy coroutine for the 1 case
if index == 0:
await asyncio.sleep(0.1) # Dummy coroutine
raise ValueError(index)

return f'Res: {index}'

winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
delay=10,
)

self.assertEqual(winner, 'Res: 1')
self.assertEqual(index, 1)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
self.assertIsNone(excs[1])


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.

0 comments on commit 2b54a4e

Please sign in to comment.