Skip to content

Commit 2286896

Browse files
authored
Do not allow closing workers to be awaited again (#5910)
1 parent 7bd6442 commit 2286896

18 files changed

+270
-121
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ repos:
4545
- types-psutil
4646
- types-setuptools
4747
# Typed libraries
48-
- numpy
4948
- dask
49+
- numpy
50+
- pytest
5051
- tornado
5152
- zict
5253
- pyarrow

distributed/actor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,9 @@ async def run_actor_function_on_worker():
205205
if self._future and not self._future.done():
206206
await self._future
207207
return await run_actor_function_on_worker()
208-
else: # pragma: no cover
209-
raise OSError("Unable to contact Actor's worker")
208+
else:
209+
exc = OSError("Unable to contact Actor's worker")
210+
return _Error(exc)
210211
if result["status"] == "OK":
211212
return _OK(result["result"])
212213
return _Error(result["exception"])

distributed/client.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,11 +1315,21 @@ async def _wait_for_workers(self, n_workers=0, timeout=None):
13151315
deadline = time() + parse_timedelta(timeout)
13161316
else:
13171317
deadline = None
1318-
while n_workers and len(info["workers"]) < n_workers:
1318+
1319+
def running_workers(info):
1320+
return len(
1321+
[
1322+
ws
1323+
for ws in info["workers"].values()
1324+
if ws["status"] == Status.running.name
1325+
]
1326+
)
1327+
1328+
while n_workers and running_workers(info) < n_workers:
13191329
if deadline and time() > deadline:
13201330
raise TimeoutError(
13211331
"Only %d/%d workers arrived after %s"
1322-
% (len(info["workers"]), n_workers, timeout)
1332+
% (running_workers(info), n_workers, timeout)
13231333
)
13241334
await asyncio.sleep(0.1)
13251335
info = await self.scheduler.identity()

distributed/core.py

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from distributed.metrics import time
3939
from distributed.system_monitor import SystemMonitor
4040
from distributed.utils import (
41-
TimeoutError,
4241
get_traceback,
4342
has_keyword,
4443
is_coroutine_function,
@@ -71,11 +70,6 @@ class Status(Enum):
7170

7271

7372
Status.lookup = {s.name: s for s in Status} # type: ignore
74-
Status.ANY_RUNNING = { # type: ignore
75-
Status.running,
76-
Status.paused,
77-
Status.closing_gracefully,
78-
}
7973

8074

8175
class RPCClosed(IOError):
@@ -168,6 +162,7 @@ def __init__(
168162
timeout=None,
169163
io_loop=None,
170164
):
165+
self._status = Status.init
171166
self.handlers = {
172167
"identity": self.identity,
173168
"echo": self.echo,
@@ -257,7 +252,8 @@ def set_thread_ident():
257252

258253
self.io_loop.add_callback(set_thread_ident)
259254
self._startup_lock = asyncio.Lock()
260-
self.status = Status.undefined
255+
self.__startup_exc: Exception | None = None
256+
self.__started = asyncio.Event()
261257

262258
self.rpc = ConnectionPool(
263259
limit=connection_limit,
@@ -289,31 +285,48 @@ async def finished(self):
289285
await self._event_finished.wait()
290286

291287
def __await__(self):
292-
async def _():
293-
timeout = getattr(self, "death_timeout", 0)
294-
async with self._startup_lock:
295-
if self.status in Status.ANY_RUNNING:
296-
return self
297-
if timeout:
298-
try:
299-
await asyncio.wait_for(self.start(), timeout=timeout)
300-
self.status = Status.running
301-
except Exception:
302-
await self.close(timeout=1)
303-
raise TimeoutError(
304-
"{} failed to start in {} seconds".format(
305-
type(self).__name__, timeout
306-
)
307-
)
308-
else:
309-
await self.start()
310-
self.status = Status.running
311-
return self
288+
return self.start().__await__()
312289

313-
return _().__await__()
290+
async def start_unsafe(self):
291+
"""Attempt to start the server. This is not idempotent and not protected against concurrent startup attempts.
314292
315-
async def start(self):
293+
This is intended to be overwritten or called by subclasses. For a safe
294+
startup, please use ``Server.start`` instead.
295+
296+
If ``death_timeout`` is configured, we will require this coroutine to
297+
finish before this timeout is reached. If the timeout is reached we will
298+
close the instance and raise an ``asyncio.TimeoutError``
299+
"""
316300
await self.rpc.start()
301+
return self
302+
303+
async def start(self):
304+
async with self._startup_lock:
305+
if self.status == Status.failed:
306+
assert self.__startup_exc is not None
307+
raise self.__startup_exc
308+
elif self.status != Status.init:
309+
return self
310+
timeout = getattr(self, "death_timeout", None)
311+
312+
async def _close_on_failure(exc: Exception):
313+
await self.close()
314+
self.status = Status.failed
315+
self.__startup_exc = exc
316+
317+
try:
318+
await asyncio.wait_for(self.start_unsafe(), timeout=timeout)
319+
except asyncio.TimeoutError as exc:
320+
await _close_on_failure(exc)
321+
raise asyncio.TimeoutError(
322+
f"{type(self).__name__} start timed out after {timeout}s."
323+
) from exc
324+
except Exception as exc:
325+
await _close_on_failure(exc)
326+
raise RuntimeError(f"{type(self).__name__} failed to start.") from exc
327+
self.status = Status.running
328+
self.__started.set()
329+
return self
317330

318331
async def __aenter__(self):
319332
await self
@@ -382,16 +395,28 @@ def _cycle_ticks(self):
382395
self._tick_interval_observed = (time() - last) / (count or 1)
383396

384397
@property
385-
def address(self):
398+
def address(self) -> str:
386399
"""
387400
The address this Server can be contacted on.
401+
If the server is not up, yet, this raises a ValueError.
388402
"""
389403
if not self._address:
390404
if self.listener is None:
391405
raise ValueError("cannot get address of non-running Server")
392406
self._address = self.listener.contact_address
393407
return self._address
394408

409+
@property
410+
def address_safe(self) -> str:
411+
"""
412+
The address this Server can be contacted on.
413+
If the server is not up, yet, this returns a ``"not-running"``.
414+
"""
415+
try:
416+
return self.address
417+
except ValueError:
418+
return "not-running"
419+
395420
@property
396421
def listen_address(self):
397422
"""
@@ -480,6 +505,7 @@ async def handle_comm(self, comm):
480505

481506
logger.debug("Connection from %r to %s", address, type(self).__name__)
482507
self._comms[comm] = op
508+
483509
await self
484510
try:
485511
while True:
@@ -650,11 +676,13 @@ async def handle_stream(self, comm, extra=None, every_cycle=()):
650676
def close(self):
651677
for pc in self.periodic_callbacks.values():
652678
pc.stop()
653-
self.__stopped = True
654-
for listener in self.listeners:
655-
future = listener.stop()
656-
if inspect.isawaitable(future):
657-
yield future
679+
680+
if not self.__stopped:
681+
self.__stopped = True
682+
for listener in self.listeners:
683+
future = listener.stop()
684+
if inspect.isawaitable(future):
685+
yield future
658686
for i in range(20):
659687
# If there are still handlers running at this point, give them a
660688
# second to finish gracefully themselves, otherwise...

distributed/deploy/spec.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -325,10 +325,7 @@ async def _correct_state_internal(self):
325325
for w in to_close
326326
if w in self.workers
327327
]
328-
await asyncio.wait(tasks)
329-
for task in tasks: # for tornado gen.coroutine support
330-
with suppress(RuntimeError):
331-
await task
328+
await asyncio.gather(*tasks)
332329
for name in to_close:
333330
if name in self.workers:
334331
del self.workers[name]
@@ -417,7 +414,7 @@ async def _close(self):
417414

418415
await self.scheduler.close()
419416
for w in self._created:
420-
assert w.status == Status.closed, w.status
417+
assert w.status in {Status.closed, Status.failed}, w.status
421418

422419
if hasattr(self, "_old_logging_level"):
423420
silence_logging(self._old_logging_level)

distributed/nanny.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ class Nanny(ServerNode):
9494

9595
_instances: ClassVar[weakref.WeakSet[Nanny]] = weakref.WeakSet()
9696
process = None
97-
status = Status.undefined
9897
memory_manager: NannyMemoryManager
9998

10099
# Inputs to parse_ports()
@@ -269,7 +268,6 @@ def __init__(
269268

270269
self._listen_address = listen_address
271270
Nanny._instances.add(self)
272-
self.status = Status.init
273271

274272
# Deprecated attributes; use Nanny.memory_manager.<name> instead
275273
memory_limit = DeprecatedMemoryManagerAttribute()
@@ -309,10 +307,10 @@ def local_dir(self):
309307
warnings.warn("The local_dir attribute has moved to local_directory")
310308
return self.local_directory
311309

312-
async def start(self):
310+
async def start_unsafe(self):
313311
"""Start nanny, start local process, start watching"""
314312

315-
await super().start()
313+
await super().start_unsafe()
316314

317315
ports = parse_ports(self._start_port)
318316
for port in ports:
@@ -337,7 +335,7 @@ async def start(self):
337335
break
338336
else:
339337
raise ValueError(
340-
f"Could not start Nanny on host {self._start_host}"
338+
f"Could not start Nanny on host {self._start_host} "
341339
f"with port {self._start_port}"
342340
)
343341

@@ -352,11 +350,12 @@ async def start(self):
352350

353351
logger.info(" Start Nanny at: %r", self.address)
354352
response = await self.instantiate()
355-
if response == Status.running:
356-
assert self.worker_address
357-
self.status = Status.running
358-
else:
353+
354+
if response != Status.running:
359355
await self.close()
356+
return
357+
358+
assert self.worker_address
360359

361360
self.start_periodic_callbacks()
362361

@@ -571,7 +570,9 @@ async def close(self, comm=None, timeout=5, report=None):
571570

572571
self.status = Status.closing
573572
logger.info(
574-
f"Closing Nanny at {self.address!r}. Report closure to scheduler: {report}"
573+
"Closing Nanny at %r. Report closure to scheduler: %s",
574+
self.address_safe,
575+
report,
575576
)
576577

577578
for preload in self.preloads:

distributed/scheduler.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,7 @@ def identity(self) -> dict[str, Any]:
584584
"last_seen": self.last_seen,
585585
"services": self.services,
586586
"metrics": self.metrics,
587+
"status": self.status.name,
587588
"nanny": self.nanny,
588589
**self.extra,
589590
}
@@ -3235,15 +3236,14 @@ def __init__(
32353236
setproctitle("dask-scheduler [not started]")
32363237
Scheduler._instances.add(self)
32373238
self.rpc.allow_offload = False
3238-
self.status = Status.undefined
32393239

32403240
##################
32413241
# Administration #
32423242
##################
32433243

32443244
def __repr__(self):
32453245
return (
3246-
f"<Scheduler {self.address!r}, "
3246+
f"<Scheduler {self.address_safe!r}, "
32473247
f"workers: {len(self.workers)}, "
32483248
f"cores: {self.total_nthreads}, "
32493249
f"tasks: {len(self.tasks)}>"
@@ -3376,10 +3376,9 @@ def get_worker_service_addr(
33763376
else:
33773377
return ws.host, port
33783378

3379-
async def start(self):
3379+
async def start_unsafe(self):
33803380
"""Clear out old state and restart all running coroutines"""
3381-
await super().start()
3382-
assert self.status != Status.running
3381+
await super().start_unsafe()
33833382

33843383
enable_gc_diagnosis()
33853384

distributed/tests/test_actor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,6 @@ async def test_failed_worker(c, s, a, b):
290290

291291
assert "actor" in str(info.value).lower()
292292
assert "worker" in str(info.value).lower()
293-
assert "lost" in str(info.value).lower()
294293

295294

296295
@gen_cluster(client=True)

0 commit comments

Comments
 (0)