Skip to content

Commit

Permalink
GH-6785: asyncio.wait no longer calls ensure_future
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Nov 30, 2022
1 parent 81f0b67 commit 896f770
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
30 changes: 24 additions & 6 deletions distributed/deploy/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import logging
import math
import weakref
from collections.abc import Awaitable, Generator
from contextlib import suppress
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar

from tornado import gen
from tornado.ioloop import IOLoop
Expand Down Expand Up @@ -108,6 +109,16 @@ async def __aexit__(self, exc_type, exc_value, traceback):
await self.close()


_T = TypeVar("_T")


async def _wrap_awaitable(aw: Awaitable[_T]) -> _T:
return await aw


_T_spec_cluster = TypeVar("_T_spec_cluster", bound="SpecCluster")


class SpecCluster(Cluster):
"""Cluster that requires a full specification of workers
Expand Down Expand Up @@ -327,7 +338,7 @@ def _correct_state(self):
self._correct_state_waiting = task
return task

async def _correct_state_internal(self):
async def _correct_state_internal(self) -> None:
async with self._lock:
self._correct_state_waiting = None

Expand Down Expand Up @@ -363,7 +374,9 @@ async def _correct_state_internal(self):
self._created.add(worker)
workers.append(worker)
if workers:
await asyncio.wait(workers)
await asyncio.wait(
[asyncio.create_task(_wrap_awaitable(w)) for w in workers]
)
for w in workers:
w._cluster = weakref.ref(self)
await w # for tornado gen.coroutine support
Expand Down Expand Up @@ -392,14 +405,19 @@ def f():
asyncio.get_running_loop().call_later(delay, f)
super()._update_worker_status(op, msg)

def __await__(self):
async def _():
def __await__(self: _T_spec_cluster) -> Generator[Any, Any, _T_spec_cluster]:
async def _() -> _T_spec_cluster:
if self.status == Status.created:
await self._start()
await self.scheduler
await self._correct_state()
if self.workers:
await asyncio.wait(list(self.workers.values())) # maybe there are more
await asyncio.wait(
[
asyncio.create_task(_wrap_awaitable(w))
for w in self.workers.values()
]
) # maybe there are more
return self

return _().__await__()
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2625,7 +2625,7 @@ async def test_task_unique_groups(c, s, a, b):
x = c.submit(sum, [1, 2])
y = c.submit(len, [1, 2])
z = c.submit(sum, [3, 4])
await asyncio.wait([x, y, z])
await asyncio.gather(x, y, z)

assert s.task_prefixes["len"].states["memory"] == 1
assert s.task_prefixes["sum"].states["memory"] == 2
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2772,7 +2772,7 @@ async def test_forget_dependents_after_release(c, s, a):
fut = c.submit(inc, 1, key="f-1")
fut2 = c.submit(inc, fut, key="f-2")

await asyncio.wait([fut, fut2])
await asyncio.gather(fut, fut2)

assert fut.key in a.state.tasks
assert fut2.key in a.state.tasks
Expand Down

0 comments on commit 896f770

Please sign in to comment.