Skip to content

Commit 78b5e0b

Browse files
committed
eager evaluate in enqueue
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 71962ce commit 78b5e0b

File tree

2 files changed

+26
-51
lines changed

2 files changed

+26
-51
lines changed

Lib/asyncio/taskgroups.py

Lines changed: 22 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ def __init__(self):
3030
self._errors = []
3131
self._base_error = None
3232
self._on_completed_fut = None
33-
self._enqueues = {}
3433

3534
def __repr__(self):
3635
info = ['']
@@ -64,16 +63,6 @@ async def __aenter__(self):
6463

6564
return self
6665

67-
def eager_eval(self, coro):
68-
try:
69-
fut = coro.send(None)
70-
except StopIteration as e:
71-
return PyCoroEagerResult(e.args[0] if e.args else None)
72-
else:
73-
task = self.create_task(coro)
74-
task._set_fut_awaiter(fut)
75-
return task
76-
7766
async def __aexit__(self, et, exc, tb):
7867
self._exiting = True
7968
propagate_cancellation_error = None
@@ -106,23 +95,6 @@ async def __aexit__(self, et, exc, tb):
10695
#
10796
self._abort()
10897

109-
if self._enqueues:
110-
for coro in self._enqueues:
111-
res = self.eager_eval(coro)
112-
if isinstance(res, PyCoroEagerResult):
113-
fut = self._enqueues[coro]
114-
if fut is not None:
115-
fut.set_result(res.value)
116-
else:
117-
def queue_callback():
118-
fut = self._enqueues[coro]
119-
if fut is not None:
120-
res.add_done_callback(lambda task: fut.set_result(task.result()))
121-
queue_callback()
122-
123-
self._unfinished_tasks -= len(self._enqueues)
124-
self._enqueues.clear()
125-
12698
# We use while-loop here because "self._on_completed_fut"
12799
# can be cancelled multiple times if our parent task
128100
# is being cancelled repeatedly (or even once, when
@@ -187,30 +159,33 @@ def create_task(self, coro, *, name=None, context=None):
187159
self._tasks.add(task)
188160
return task
189161

162+
def _eager_eval(self, coro):
163+
try:
164+
fut = coro.send(None)
165+
task = self.create_task(coro)
166+
task._set_fut_awaiter(fut)
167+
return task
168+
except StopIteration as e:
169+
# The co-routine has completed synchronously and we've got
170+
# our result.
171+
return PyCoroEagerResult(e.args[0] if e.args else None)
172+
except Exception as e:
173+
res = Future(loop=self._loop)
174+
res.set_exception(e)
175+
return res
176+
190177
def enqueue(self, coro, no_future=True):
191178
if not self._entered:
192179
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
193180

194-
if coro in self._enqueues:
195-
return self._enqueues[coro]
196-
197-
if not self._enqueues:
198-
# if the looop starts running because someone awaits, we want
199-
# to run the co-routines which are enqueued as well.
200-
self._loop.call_soon(self._enqueue_enqueus)
201-
202-
self._unfinished_tasks += 1
203-
if no_future:
204-
fut = self._enqueues[coro] = None
181+
res = self._eager_eval(coro)
182+
if isinstance(res, PyCoroEagerResult):
183+
if not no_future:
184+
fut = Future(loop=self._loop)
185+
fut.set_result(res.value)
186+
return fut
205187
else:
206-
fut = self._enqueues[coro] = Future(loop=self._loop)
207-
return fut
208-
209-
def _enqueue_enqueus(self):
210-
for coro in self._enqueues:
211-
self.create_task(coro)
212-
self._unfinished_tasks -= len(self._enqueues)
213-
self._enqueues.clear()
188+
return res
214189

215190
# Since Python 3.8 Tasks propagate all exceptions correctly,
216191
# except for KeyboardInterrupt and SystemExit which are

Lib/test/test_asyncio/test_taskgroups.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -740,15 +740,15 @@ async def eager():
740740

741741
async def test_taskgroup_enqueue_02(self):
742742

743-
async def foo1():
743+
async def eager1():
744744
return 42
745745

746-
async def eager():
746+
async def eager2():
747747
return 11
748748

749749
async with taskgroups.TaskGroup() as g:
750-
t1 = g.enqueue(foo1(), no_future=False)
751-
t2 = g.enqueue(eager(), no_future=False)
750+
t1 = g.enqueue(eager1(), no_future=False)
751+
t2 = g.enqueue(eager2(), no_future=False)
752752

753753
self.assertEqual(t1.result(), 42)
754754
self.assertEqual(t2.result(), 11)

0 commit comments

Comments
 (0)