
Enable Client.wait_for_workers to optionally also wait when removing workers #6377


Open · wants to merge 14 commits into main
40 changes: 32 additions & 8 deletions distributed/client.py
@@ -22,6 +22,7 @@
 from contextvars import ContextVar
 from functools import partial
 from numbers import Number
+from operator import gt, lt, ne
 from queue import Queue as pyQueue
 from typing import Any, ClassVar, Coroutine, Literal, Sequence, TypedDict

@@ -117,6 +118,9 @@
     "pubsub": PubSubClientExtension,
 }

+# Mode to use when waiting for workers.
+WORKER_WAIT_MODE = Literal["at least", "at most", "exactly"]
+

 def _get_global_client() -> Client | None:
     L = sorted(list(_global_clients), reverse=True)
@@ -1329,34 +1333,50 @@ async def _update_scheduler_info(self):
         except OSError:
             logger.debug("Not able to query scheduler for identity")

-    async def _wait_for_workers(self, n_workers=0, timeout=None):
+    async def _wait_for_workers(
+        self,
+        n_workers: int = 0,
+        timeout: int | None = None,
+        mode: WORKER_WAIT_MODE = "at least",
+    ):
         info = await self.scheduler.identity()
         self._scheduler_identity = SchedulerInfo(info)
         if timeout:
             deadline = time() + parse_timedelta(timeout)
         else:
             deadline = None

-        def running_workers(info):
+        def running_workers(info, status_list=[Status.running]):
             return len(
                 [
                     ws
                     for ws in info["workers"].values()
-                    if ws["status"] == Status.running.name
+                    if ws["status"] in [s.name for s in status_list]
                 ]
             )

-        while n_workers and running_workers(info) < n_workers:
+        try:
+            op, required_status = {
+                "at least": (lt, [Status.running]),
+                "exactly": (ne, [Status.running, Status.paused]),
+                "at most": (gt, [Status.running, Status.paused]),
+            }[mode]
+        except KeyError:
+            raise NotImplementedError(f"{mode} is not handled.")
+
+        while op(running_workers(info, status_list=required_status), n_workers):
             if deadline and time() > deadline:
                 raise TimeoutError(
-                    "Only %d/%d workers arrived after %s"
-                    % (running_workers(info), n_workers, timeout)
+                    "Had %d workers after %s and needed %s %d"
+                    % (running_workers(info), timeout, mode, n_workers)
                 )
             await asyncio.sleep(0.1)
             info = await self.scheduler.identity()
             self._scheduler_identity = SchedulerInfo(info)

-    def wait_for_workers(self, n_workers=0, timeout=None):
+    def wait_for_workers(
+        self, n_workers=0, timeout=None, mode: WORKER_WAIT_MODE = "at least"
+    ):
         """Blocking call to wait for n workers before continuing

         Parameters
@@ -1366,8 +1386,12 @@ def wait_for_workers(self, n_workers=0, timeout=None):
         timeout : number, optional
             Time in seconds after which to raise a
             ``dask.distributed.TimeoutError``
+        mode : "at least" | "at most" | "exactly", optional
+            Mode to use when waiting for workers.
+            Default ``"at least"``, waits for at least ``n_workers``.
+            One can also specify waiting for ``"at most"`` or ``"exactly"`` ``n_workers``.
         """
-        return self.sync(self._wait_for_workers, n_workers, timeout=timeout)
+        return self.sync(self._wait_for_workers, n_workers, timeout=timeout, mode=mode)

     def _heartbeat(self):
         if self.scheduler_comm:
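The heart of the change is the dispatch from mode to a comparison operator plus the set of worker statuses that count toward the target: the client keeps polling while the comparison holds. A minimal standalone sketch of that predicate, distilled from the diff above (the helper name should_keep_waiting is illustrative only, not part of the PR):

    from operator import gt, lt, ne

    # mode -> (comparison applied as op(observed, target), statuses counted)
    MODES = {
        "at least": (lt, ["running"]),           # poll while observed < target
        "exactly": (ne, ["running", "paused"]),  # poll while observed != target
        "at most": (gt, ["running", "paused"]),  # poll while observed > target
    }

    def should_keep_waiting(mode: str, observed: int, target: int) -> bool:
        op, _statuses = MODES[mode]
        return op(observed, target)

    assert should_keep_waiting("at least", 2, 3)     # 2 < 3: not enough workers yet
    assert not should_keep_waiting("exactly", 3, 3)  # 3 == 3: done
    assert should_keep_waiting("at most", 5, 3)      # 5 > 3: still too many

Note that paused workers count toward "exactly" and "at most" but not "at least", mirroring the required_status lists in the diff.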
61 changes: 51 additions & 10 deletions distributed/tests/test_client.py
@@ -6115,21 +6115,62 @@ async def test_instances(c, s, a, b):

 @gen_cluster(client=True)
 async def test_wait_for_workers(c, s, a, b):
-    future = asyncio.ensure_future(c.wait_for_workers(n_workers=3))
+    await c.wait_for_workers(n_workers=1)
+
+    future = asyncio.create_task(c.wait_for_workers(n_workers=3))
     await asyncio.sleep(0.22)  # 2 chances
     assert not future.done()

-    w = await Worker(s.address)
-    start = time()
-    await future
-    assert time() < start + 1
-    await w.close()
+    async with Worker(s.address):
+        start = time()
+        await future
+        assert time() < start + 1

+    with pytest.raises(
+        TimeoutError, match="3 workers after 1 ms and needed at least 10"
+    ) as info:
+        await c.wait_for_workers(n_workers=10, timeout="1 ms")
+
+    future = asyncio.create_task(c.wait_for_workers(n_workers=2))
+    await asyncio.sleep(0.22)  # 2 chances
+    assert future.done()
+
+
+@gen_cluster(client=True)
+async def test_wait_for_workers_max(c, s, a, b):
+    with pytest.raises(TimeoutError, match="2 workers after 1 ms and needed at most 1"):
+        await c.wait_for_workers(n_workers=1, mode="at most", timeout="1 ms")
+
+    t = asyncio.create_task(c.wait_for_workers(n_workers=1, mode="at most"))
+    await asyncio.sleep(0.5)
+    assert not t.done()
+    await b.close()
+    await t

-    with pytest.raises(TimeoutError) as info:
-        await c.wait_for_workers(n_workers=10, timeout="1 ms")
+    # already at target size; should be instant
+    await c.wait_for_workers(n_workers=1, mode="at most", timeout="1s")
+    await c.wait_for_workers(n_workers=2, mode="at most", timeout="1s")
+
+
+@gen_cluster(client=True)
+async def test_wait_for_workers_exactly(c, s, a, b):
+    with pytest.raises(TimeoutError, match="2 workers after 1 ms and needed exactly 1"):
+        await c.wait_for_workers(n_workers=1, mode="exactly", timeout="1 ms")

-    assert "2/10" in str(info.value).replace(" ", "")
-    assert "1 ms" in str(info.value)
+    t = asyncio.create_task(c.wait_for_workers(n_workers=1, mode="exactly"))
+    await asyncio.sleep(0.5)
+    assert not t.done()
+    await b.close()
+    await t
+
+    # already at target size; should be instant
+    await c.wait_for_workers(n_workers=1, mode="exactly", timeout="1s")
+
+
+@gen_cluster(client=True)
+async def test_wait_for_workers_bad_mode(c, s, a, b):
+    with pytest.raises(NotImplementedError):
+        await c.wait_for_workers(n_workers=1, timeout="1 ms", mode="foo")


 @pytest.mark.skipif(WINDOWS, reason="num_fds not supported on windows")
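Taken together, a minimal usage sketch of the extended API exercised by the tests above; the scheduler address below is a placeholder:

    from distributed import Client

    client = Client("tcp://scheduler:8786")  # placeholder address

    # Unchanged default: block until at least 4 workers are running.
    client.wait_for_workers(n_workers=4)

    # New: block until the cluster has shrunk to at most 4 workers,
    # e.g. after asking the cluster manager to scale down.
    client.wait_for_workers(n_workers=4, mode="at most", timeout="60 s")

    # New: block until the worker count is exactly 4, scaling up or down.
    client.wait_for_workers(n_workers=4, mode="exactly")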