Skip to content

Commit 71962ce

Browse files
committed
Add basic eager async evaluation to Tasks (Python only) and TaskGroups
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 5f9c0f5 commit 71962ce

File tree

3 files changed

+204
-45
lines changed

3 files changed

+204
-45
lines changed

Lib/asyncio/taskgroups.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
from . import events
99
from . import exceptions
1010
from . import tasks
11+
from .futures import Future
12+
13+
14+
class PyCoroEagerResult:
15+
def __init__(self, value):
16+
self.value = value
1117

1218

1319
class TaskGroup:
@@ -24,6 +30,7 @@ def __init__(self):
2430
self._errors = []
2531
self._base_error = None
2632
self._on_completed_fut = None
33+
self._enqueues = {}
2734

2835
def __repr__(self):
2936
info = ['']
@@ -57,6 +64,16 @@ async def __aenter__(self):
5764

5865
return self
5966

67+
def eager_eval(self, coro):
68+
try:
69+
fut = coro.send(None)
70+
except StopIteration as e:
71+
return PyCoroEagerResult(e.args[0] if e.args else None)
72+
else:
73+
task = self.create_task(coro)
74+
task._set_fut_awaiter(fut)
75+
return task
76+
6077
async def __aexit__(self, et, exc, tb):
6178
self._exiting = True
6279
propagate_cancellation_error = None
@@ -89,6 +106,23 @@ async def __aexit__(self, et, exc, tb):
89106
#
90107
self._abort()
91108

109+
if self._enqueues:
110+
for coro in self._enqueues:
111+
res = self.eager_eval(coro)
112+
if isinstance(res, PyCoroEagerResult):
113+
fut = self._enqueues[coro]
114+
if fut is not None:
115+
fut.set_result(res.value)
116+
else:
117+
def queue_callback():
118+
fut = self._enqueues[coro]
119+
if fut is not None:
120+
res.add_done_callback(lambda task: fut.set_result(task.result()))
121+
queue_callback()
122+
123+
self._unfinished_tasks -= len(self._enqueues)
124+
self._enqueues.clear()
125+
92126
# We use while-loop here because "self._on_completed_fut"
93127
# can be cancelled multiple times if our parent task
94128
# is being cancelled repeatedly (or even once, when
@@ -153,6 +187,31 @@ def create_task(self, coro, *, name=None, context=None):
153187
self._tasks.add(task)
154188
return task
155189

190+
def enqueue(self, coro, no_future=True):
191+
if not self._entered:
192+
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
193+
194+
if coro in self._enqueues:
195+
return self._enqueues[coro]
196+
197+
if not self._enqueues:
198+
# if the looop starts running because someone awaits, we want
199+
# to run the co-routines which are enqueued as well.
200+
self._loop.call_soon(self._enqueue_enqueus)
201+
202+
self._unfinished_tasks += 1
203+
if no_future:
204+
fut = self._enqueues[coro] = None
205+
else:
206+
fut = self._enqueues[coro] = Future(loop=self._loop)
207+
return fut
208+
209+
def _enqueue_enqueus(self):
210+
for coro in self._enqueues:
211+
self.create_task(coro)
212+
self._unfinished_tasks -= len(self._enqueues)
213+
self._enqueues.clear()
214+
156215
# Since Python 3.8 Tasks propagate all exceptions correctly,
157216
# except for KeyboardInterrupt and SystemExit which are
158217
# still considered special.

Lib/asyncio/tasks.py

Lines changed: 50 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def __init__(self, coro, *, loop=None, name=None, context=None):
117117
else:
118118
self._context = context
119119

120-
self._loop.call_soon(self.__step, context=self._context)
120+
if not getattr(coro, "cr_suspended", False):
121+
self._loop.call_soon(self.__step, context=self._context)
122+
121123
_register_task(self)
122124

123125
def __del__(self):
@@ -293,55 +295,58 @@ def __step(self, exc=None):
293295
except BaseException as exc:
294296
super().set_exception(exc)
295297
else:
296-
blocking = getattr(result, '_asyncio_future_blocking', None)
297-
if blocking is not None:
298-
# Yielded Future must come from Future.__iter__().
299-
if not self._check_future(result):
300-
new_exc = RuntimeError(
301-
f'Task {self!r} got Future '
302-
f'{result!r} attached to a different loop')
303-
self._loop.call_soon(
304-
self.__step, new_exc, context=self._context)
305-
elif blocking:
306-
if result is self:
307-
new_exc = RuntimeError(
308-
f'Task cannot await on itself: {self!r}')
309-
self._loop.call_soon(
310-
self.__step, new_exc, context=self._context)
311-
else:
312-
result._asyncio_future_blocking = False
313-
result.add_done_callback(
314-
self.__wakeup, context=self._context)
315-
self._fut_waiter = result
316-
if self._must_cancel:
317-
if self._fut_waiter.cancel(
318-
msg=self._cancel_message):
319-
self._must_cancel = False
320-
else:
321-
new_exc = RuntimeError(
322-
f'yield was used instead of yield from '
323-
f'in task {self!r} with {result!r}')
324-
self._loop.call_soon(
325-
self.__step, new_exc, context=self._context)
298+
self._set_fut_awaiter(result)
299+
finally:
300+
_leave_task(self._loop, self)
301+
self = None # Needed to break cycles when an exception occurs.
326302

327-
elif result is None:
328-
# Bare yield relinquishes control for one event loop iteration.
329-
self._loop.call_soon(self.__step, context=self._context)
330-
elif inspect.isgenerator(result):
331-
# Yielding a generator is just wrong.
303+
def _set_fut_awaiter(self, result):
304+
blocking = getattr(result, '_asyncio_future_blocking', None)
305+
if blocking is not None:
306+
# Yielded Future must come from Future.__iter__().
307+
if not self._check_future(result):
332308
new_exc = RuntimeError(
333-
f'yield was used instead of yield from for '
334-
f'generator in task {self!r} with {result!r}')
309+
f'Task {self!r} got Future '
310+
f'{result!r} attached to a different loop')
335311
self._loop.call_soon(
336312
self.__step, new_exc, context=self._context)
313+
elif blocking:
314+
if result is self:
315+
new_exc = RuntimeError(
316+
f'Task cannot await on itself: {self!r}')
317+
self._loop.call_soon(
318+
self.__step, new_exc, context=self._context)
319+
else:
320+
result._asyncio_future_blocking = False
321+
result.add_done_callback(
322+
self.__wakeup, context=self._context)
323+
self._fut_waiter = result
324+
if self._must_cancel:
325+
if self._fut_waiter.cancel(
326+
msg=self._cancel_message):
327+
self._must_cancel = False
337328
else:
338-
# Yielding something else is an error.
339-
new_exc = RuntimeError(f'Task got bad yield: {result!r}')
329+
new_exc = RuntimeError(
330+
f'yield was used instead of yield from '
331+
f'in task {self!r} with {result!r}')
340332
self._loop.call_soon(
341333
self.__step, new_exc, context=self._context)
342-
finally:
343-
_leave_task(self._loop, self)
344-
self = None # Needed to break cycles when an exception occurs.
334+
335+
elif result is None:
336+
# Bare yield relinquishes control for one event loop iteration.
337+
self._loop.call_soon(self.__step, context=self._context)
338+
elif inspect.isgenerator(result):
339+
# Yielding a generator is just wrong.
340+
new_exc = RuntimeError(
341+
f'yield was used instead of yield from for '
342+
f'generator in task {self!r} with {result!r}')
343+
self._loop.call_soon(
344+
self.__step, new_exc, context=self._context)
345+
else:
346+
# Yielding something else is an error.
347+
new_exc = RuntimeError(f'Task got bad yield: {result!r}')
348+
self._loop.call_soon(
349+
self.__step, new_exc, context=self._context)
345350

346351
def __wakeup(self, future):
347352
try:
@@ -369,8 +374,8 @@ def __wakeup(self, future):
369374
pass
370375
else:
371376
# _CTask is needed for tests.
372-
Task = _CTask = _asyncio.Task
373-
377+
#Task = _CTask = _asyncio.Task
378+
pass
374379

375380
def create_task(coro, *, name=None, context=None):
376381
"""Schedule the execution of a coroutine object in a spawn task.

Lib/test/test_asyncio/test_taskgroups.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,101 @@ async def coro(val):
722722
await t2
723723
self.assertEqual(2, ctx.get(cvar))
724724

725+
async def test_taskgroup_enqueue_01(self):
726+
727+
async def foo1():
728+
await asyncio.sleep(0.1)
729+
return 42
730+
731+
async def eager():
732+
return 11
733+
734+
async with taskgroups.TaskGroup() as g:
735+
t1 = g.enqueue(foo1(), no_future=False)
736+
t2 = g.enqueue(eager(), no_future=False)
737+
738+
self.assertEqual(t1.result(), 42)
739+
self.assertEqual(t2.result(), 11)
740+
741+
async def test_taskgroup_enqueue_02(self):
742+
743+
async def foo1():
744+
return 42
745+
746+
async def eager():
747+
return 11
748+
749+
async with taskgroups.TaskGroup() as g:
750+
t1 = g.enqueue(foo1(), no_future=False)
751+
t2 = g.enqueue(eager(), no_future=False)
752+
753+
self.assertEqual(t1.result(), 42)
754+
self.assertEqual(t2.result(), 11)
755+
756+
async def test_taskgroup_fanout_task(self):
757+
async def step(i):
758+
if i == 0:
759+
return
760+
async with taskgroups.TaskGroup() as g:
761+
for _ in range(6):
762+
g.create_task(step(i - 1))
763+
764+
import time
765+
s = time.perf_counter()
766+
await step(6)
767+
e = time.perf_counter()
768+
print(e-s)
769+
770+
async def test_taskgroup_fanout_enqueue(self):
771+
async def step(i):
772+
if i == 0:
773+
return
774+
async with taskgroups.TaskGroup() as g:
775+
for _ in range(6):
776+
g.enqueue(step(i - 1))
777+
778+
import time
779+
s = time.perf_counter()
780+
await step(6)
781+
e = time.perf_counter()
782+
print(e-s)
783+
784+
async def test_taskgroup_fanout_enqueue_02(self):
785+
async def intermediate2(i):
786+
return await intermediate(i)
787+
788+
async def intermediate(i):
789+
async with taskgroups.TaskGroup() as g:
790+
for _ in range(6):
791+
g.enqueue(step(i - 1))
792+
793+
async def step(i):
794+
if i == 0:
795+
return
796+
797+
return await intermediate2(i)
798+
799+
800+
import time
801+
s = time.perf_counter()
802+
await step(6)
803+
e = time.perf_counter()
804+
print(e-s)
805+
806+
807+
async def test_taskgroup_fanout_enqueue_future(self):
808+
async def step(i):
809+
if i == 0:
810+
return
811+
async with taskgroups.TaskGroup() as g:
812+
for _ in range(6):
813+
g.enqueue(step(i - 1), no_future=False)
814+
815+
import time
816+
s = time.perf_counter()
817+
await step(6)
818+
e = time.perf_counter()
819+
print(e-s)
725820

726821
if __name__ == "__main__":
727822
unittest.main()

0 commit comments

Comments
 (0)