From 2b54a4ebf1ffd2d709570267b3b3bb1dcfec4d53 Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Thu, 26 Sep 2024 11:09:46 +0530 Subject: [PATCH] =?UTF-8?q?[3.12]=20gh-124309:=20Modernize=20the=20`stagge?= =?UTF-8?q?red=5Frace`=20implementation=20to=20support=20e=E2=80=A6=20(#12?= =?UTF-8?q?4574)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit gh-124309: Modernize the `staggered_race` implementation to support eager task factories (#124390) Co-authored-by: Thomas Grainger Co-authored-by: Jelle Zijlstra Co-authored-by: Carol Willing Co-authored-by: Kumar Aditya (cherry picked from commit de929f353c413459834a2a37b2d9b0240673d874) Co-authored-by: Peter Bierma --- Lib/asyncio/base_events.py | 2 +- Lib/asyncio/staggered.py | 90 +++---------- .../test_asyncio/test_eager_task_factory.py | 47 +++++++ Lib/test/test_asyncio/test_staggered.py | 126 ++++++++++++++++++ ...-09-23-18-18-23.gh-issue-124309.iFcarA.rst | 1 + 5 files changed, 193 insertions(+), 73 deletions(-) create mode 100644 Lib/test/test_asyncio/test_staggered.py create mode 100644 Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index cb037fd472c5aa..02b900891949b3 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -1110,7 +1110,7 @@ async def create_connection( (functools.partial(self._connect_sock, exceptions, addrinfo, laddr_infos) for addrinfo in infos), - happy_eyeballs_delay, loop=self) + happy_eyeballs_delay) if sock is None: exceptions = [exc for sub in exceptions for exc in sub] diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py index 451a53a16f3831..4458d01dece0e6 100644 --- a/Lib/asyncio/staggered.py +++ b/Lib/asyncio/staggered.py @@ -3,24 +3,15 @@ __all__ = 'staggered_race', import contextlib -import typing -from . import events -from . import exceptions as exceptions_mod from . import locks from . import tasks +from . import taskgroups +class _Done(Exception): + pass -async def staggered_race( - coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]], - delay: typing.Optional[float], - *, - loop: events.AbstractEventLoop = None, -) -> typing.Tuple[ - typing.Any, - typing.Optional[int], - typing.List[typing.Optional[Exception]] -]: +async def staggered_race(coro_fns, delay): """Run coroutines with staggered start times and take the first to finish. This method takes an iterable of coroutine functions. The first one is @@ -52,8 +43,6 @@ async def staggered_race( delay: amount of time, in seconds, between starting coroutines. If ``None``, the coroutines will run sequentially. - loop: the event loop to use. - Returns: tuple *(winner_result, winner_index, exceptions)* where @@ -72,37 +61,11 @@ async def staggered_race( """ # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. - loop = loop or events.get_running_loop() - enum_coro_fns = enumerate(coro_fns) winner_result = None winner_index = None exceptions = [] - running_tasks = [] - - async def run_one_coro( - previous_failed: typing.Optional[locks.Event]) -> None: - # Wait for the previous task to finish, or for delay seconds - if previous_failed is not None: - with contextlib.suppress(exceptions_mod.TimeoutError): - # Use asyncio.wait_for() instead of asyncio.wait() here, so - # that if we get cancelled at this point, Event.wait() is also - # cancelled, otherwise there will be a "Task destroyed but it is - # pending" later. - await tasks.wait_for(previous_failed.wait(), delay) - # Get the next coroutine to run - try: - this_index, coro_fn = next(enum_coro_fns) - except StopIteration: - return - # Start task that will run the next coroutine - this_failed = locks.Event() - next_task = loop.create_task(run_one_coro(this_failed)) - running_tasks.append(next_task) - assert len(running_tasks) == this_index + 2 - # Prepare place to put this coroutine's exceptions if not won - exceptions.append(None) - assert len(exceptions) == this_index + 1 + async def run_one_coro(this_index, coro_fn, this_failed): try: result = await coro_fn() except (SystemExit, KeyboardInterrupt): @@ -116,34 +79,17 @@ async def run_one_coro( assert winner_index is None winner_index = this_index winner_result = result - # Cancel all other tasks. We take care to not cancel the current - # task as well. If we do so, then since there is no `await` after - # here and CancelledError are usually thrown at one, we will - # encounter a curious corner case where the current task will end - # up as done() == True, cancelled() == False, exception() == - # asyncio.CancelledError. This behavior is specified in - # https://bugs.python.org/issue30048 - for i, t in enumerate(running_tasks): - if i != this_index: - t.cancel() - - first_task = loop.create_task(run_one_coro(None)) - running_tasks.append(first_task) + raise _Done + try: - # Wait for a growing list of tasks to all finish: poor man's version of - # curio's TaskGroup or trio's nursery - done_count = 0 - while done_count != len(running_tasks): - done, _ = await tasks.wait(running_tasks) - done_count = len(done) - # If run_one_coro raises an unhandled exception, it's probably a - # programming error, and I want to see it. - if __debug__: - for d in done: - if d.done() and not d.cancelled() and d.exception(): - raise d.exception() - return winner_result, winner_index, exceptions - finally: - # Make sure no tasks are left running if we leave this function - for t in running_tasks: - t.cancel() + async with taskgroups.TaskGroup() as tg: + for this_index, coro_fn in enumerate(coro_fns): + this_failed = locks.Event() + exceptions.append(None) + tg.create_task(run_one_coro(this_index, coro_fn, this_failed)) + with contextlib.suppress(TimeoutError): + await tasks.wait_for(this_failed.wait(), delay) + except* _Done: + pass + + return winner_result, winner_index, exceptions diff --git a/Lib/test/test_asyncio/test_eager_task_factory.py b/Lib/test/test_asyncio/test_eager_task_factory.py index 58c06287bc3c5d..ed74c6ecbd83f4 100644 --- a/Lib/test/test_asyncio/test_eager_task_factory.py +++ b/Lib/test/test_asyncio/test_eager_task_factory.py @@ -218,6 +218,53 @@ async def run(): self.run_coro(run()) + def test_staggered_race_with_eager_tasks(self): + # See https://github.com/python/cpython/issues/124309 + + async def fail(): + await asyncio.sleep(0) + raise ValueError("no good") + + async def run(): + winner, index, excs = await asyncio.staggered.staggered_race( + [ + lambda: asyncio.sleep(2, result="sleep2"), + lambda: asyncio.sleep(1, result="sleep1"), + lambda: fail() + ], + delay=0.25 + ) + self.assertEqual(winner, 'sleep1') + self.assertEqual(index, 1) + self.assertIsNone(excs[index]) + self.assertIsInstance(excs[0], asyncio.CancelledError) + self.assertIsInstance(excs[2], ValueError) + + self.run_coro(run()) + + def test_staggered_race_with_eager_tasks_no_delay(self): + # See https://github.com/python/cpython/issues/124309 + async def fail(): + raise ValueError("no good") + + async def run(): + winner, index, excs = await asyncio.staggered.staggered_race( + [ + lambda: fail(), + lambda: asyncio.sleep(1, result="sleep1"), + lambda: asyncio.sleep(0, result="sleep0"), + ], + delay=None + ) + self.assertEqual(winner, 'sleep1') + self.assertEqual(index, 1) + self.assertIsNone(excs[index]) + self.assertIsInstance(excs[0], ValueError) + self.assertEqual(len(excs), 2) + + self.run_coro(run()) + + class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase): Task = tasks._PyTask diff --git a/Lib/test/test_asyncio/test_staggered.py b/Lib/test/test_asyncio/test_staggered.py new file mode 100644 index 00000000000000..21a39b3f911747 --- /dev/null +++ b/Lib/test/test_asyncio/test_staggered.py @@ -0,0 +1,126 @@ +import asyncio +import unittest +from asyncio.staggered import staggered_race + +from test import support + +support.requires_working_socket(module=True) + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class StaggeredTests(unittest.IsolatedAsyncioTestCase): + async def test_empty(self): + winner, index, excs = await staggered_race( + [], + delay=None, + ) + + self.assertIs(winner, None) + self.assertIs(index, None) + self.assertEqual(excs, []) + + async def test_one_successful(self): + async def coro(index): + return f'Res: {index}' + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=None, + ) + + self.assertEqual(winner, 'Res: 0') + self.assertEqual(index, 0) + self.assertEqual(excs, [None]) + + async def test_first_error_second_successful(self): + async def coro(index): + if index == 0: + raise ValueError(index) + return f'Res: {index}' + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=None, + ) + + self.assertEqual(winner, 'Res: 1') + self.assertEqual(index, 1) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], ValueError) + self.assertIs(excs[1], None) + + async def test_first_timeout_second_successful(self): + async def coro(index): + if index == 0: + await asyncio.sleep(10) # much bigger than delay + return f'Res: {index}' + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=0.1, + ) + + self.assertEqual(winner, 'Res: 1') + self.assertEqual(index, 1) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], asyncio.CancelledError) + self.assertIs(excs[1], None) + + async def test_none_successful(self): + async def coro(index): + raise ValueError(index) + + for delay in [None, 0, 0.1, 1]: + with self.subTest(delay=delay): + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=delay, + ) + + self.assertIs(winner, None) + self.assertIs(index, None) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], ValueError) + self.assertIsInstance(excs[1], ValueError) + + async def test_long_delay_early_failure(self): + async def coro(index): + await asyncio.sleep(0) # Dummy coroutine for the 1 case + if index == 0: + await asyncio.sleep(0.1) # Dummy coroutine + raise ValueError(index) + + return f'Res: {index}' + + winner, index, excs = await staggered_race( + [ + lambda: coro(0), + lambda: coro(1), + ], + delay=10, + ) + + self.assertEqual(winner, 'Res: 1') + self.assertEqual(index, 1) + self.assertEqual(len(excs), 2) + self.assertIsInstance(excs[0], ValueError) + self.assertIsNone(excs[1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst b/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst new file mode 100644 index 00000000000000..89610fa44bf743 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-09-23-18-18-23.gh-issue-124309.iFcarA.rst @@ -0,0 +1 @@ +Fixed :exc:`AssertionError` when using :func:`!asyncio.staggered.staggered_race` with :attr:`asyncio.eager_task_factory`.