Commit 7f0eb4c

Fix grouped worker handling in SpecCluster for pre-emption scenarios
**Core fixes:**

1. **scale() method now correctly handles grouped workers** (spec.py:603-631)
   - Parameter `n` now represents the number of workers, not the number of specs
   - For grouped workers: converts worker count to spec count internally
   - For memory/cores parameters: correctly handles that these represent specs
   - Fixes adaptive scaling so it works properly with grouped workers
2. **_update_worker_status() handles grouped worker removal** (spec.py:465-510)
   - Added defensive check for workers already removed from scheduler_info
   - When ANY worker from a group dies, the entire spec is removed (HPC assumption)
   - Regular workers keep their spec for recreation
3. **Helper methods for name conversion** (spec.py:343-390)
   - _spec_name_to_worker_names(): converts a spec name to scheduler worker names
   - _worker_name_to_spec_name(): converts a scheduler worker name to a spec name
   - Ensures consistent handling throughout the codebase
4. **Fixed _correct_state_internal() grouped worker retirement** (spec.py:405-425)
   - Uses actual scheduler worker names instead of spec names
   - Properly retires grouped workers using the helper methods
5. **Fixed scale_down() to use helper methods** (spec.py:655-671)
   - Consistent name conversion for both regular and grouped workers

**Tests added:**

- test_unexpected_close_single_grouped_worker: tests recovery when one worker from a group dies (graceful close)
- test_unexpected_close_whole_worker_group: tests recovery when an entire group is pre-empted (simulates a SLURM job kill)
- test_adaptive_grouped_workers: tests that adaptive scaling maintains the correct worker count with grouped workers

**Semantic change:** The `scale(n)` method now consistently means "scale to n workers" rather than "scale to n specs" (see the sketch below). This aligns with how adaptive scaling and the plan/requested/observed properties work. Tests have been updated to match this semantic change.

This fixes grouped worker support comprehensively rather than applying point fixes.

1 parent ad17dda commit 7f0eb4c
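
To illustrate the new `scale(n)` semantics with grouped workers, here is a minimal sketch with made-up numbers (a hypothetical group size of 4; the real conversion lives inside `SpecCluster.scale`): requesting `n` workers rounds up to whole specs, so the cluster may start slightly more workers than requested.

```python
# Hedged sketch of the worker-count -> spec-count conversion described above.
# The group size of 4 is an assumption for illustration only.
import math

workers_per_spec = 4                 # e.g. a MultiWorker spec launching 4 workers per job
requested_workers = 10               # what cluster.scale(10) now means: 10 *workers*

target_specs = math.ceil(requested_workers / workers_per_spec)   # 3 specs (jobs)
resulting_workers = target_specs * workers_per_spec              # 12 workers actually started

print(target_specs, resulting_workers)  # 3 12
```
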

File tree

3 files changed: +341, -53 lines


distributed/deploy/adaptive_core.py

Lines changed: 16 additions & 0 deletions

@@ -2,9 +2,11 @@
 
 import logging
 import math
+import os
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
 from collections.abc import Iterable
+from pathlib import Path
 from typing import TYPE_CHECKING, cast
 
 import tlz as toolz
@@ -18,6 +20,16 @@
 
 logger = logging.getLogger(__name__)
 
+# Set up file logging for adaptive operations
+_slurm_job_id = os.environ.get("SLURM_JOB_ID", "unknown")
+_log_file = Path.home() / f"byelayer-dask-log-{_slurm_job_id}"
+_file_handler = logging.FileHandler(_log_file)
+_file_handler.setLevel(logging.INFO)
+_file_handler.setFormatter(
+    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+)
+logger.addHandler(_file_handler)
+
 
 class AdaptiveCore(ABC):
     """
@@ -197,6 +209,10 @@ async def adapt(self) -> None:
 
         try:
            target = await self.safe_target()
+            print(
+                f"Adaptive target: {target}, plan: {len(self.plan)}, requested: {len(self.requested)}, observed: {len(self.observed)}"
+            )
+
            recommendations = await self.recommendations(target)
 
            if recommendations["status"] != "same":
distributed/deploy/spec.py

Lines changed: 141 additions & 49 deletions

@@ -340,6 +340,55 @@ async def _start(self):
             await self._close()
             raise RuntimeError(f"Cluster failed to start: {e}") from e
 
+    def _spec_name_to_worker_names(self, spec_name):
+        """Convert a spec name to the set of worker names it represents.
+
+        For regular workers, returns {spec_name} as a string.
+        For grouped workers, returns {spec_name + suffix for suffix in group}.
+
+        Parameters
+        ----------
+        spec_name : int or str
+            The spec name (key in worker_spec)
+
+        Returns
+        -------
+        set of str
+            The worker names that the scheduler knows about
+        """
+        spec = self.worker_spec.get(spec_name)
+        if spec and "group" in spec:
+            return {str(spec_name) + suffix for suffix in spec["group"]}
+        return {str(spec_name)}
+
+    def _worker_name_to_spec_name(self, worker_name):
+        """Convert a worker name to its spec name.
+
+        For regular workers, returns the worker name (converted to int if numeric).
+        For grouped workers, extracts the prefix before the first "-".
+
+        Parameters
+        ----------
+        worker_name : str or int
+            The worker name from the scheduler
+
+        Returns
+        -------
+        int or str
+            The spec name (key in worker_spec)
+        """
+        worker_name_str = str(worker_name)
+        if "-" in worker_name_str:
+            spec_name = worker_name_str.split("-")[0]
+            # Convert to int if numeric to match worker_spec keys
+            if spec_name.isdigit():
+                return int(spec_name)
+            return spec_name
+        # Try to convert to int if numeric
+        if worker_name_str.isdigit():
+            return int(worker_name_str)
+        return worker_name
+
     def _correct_state(self):
         if self._correct_state_waiting:
             # If people call this frequently, we only want to run it once
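
The two helpers above define a round trip between spec names and the worker names the scheduler reports. Below is a self-contained sketch of that mapping; the `worker_spec` contents are invented, and the functions mirror the methods in the hunk above.

```python
# Standalone sketch of the name mapping implemented by the two helper methods above.
# The worker_spec contents are made up for illustration.
worker_spec = {0: {"group": ["-0", "-1"]}, 1: {}}   # spec 0 is grouped, spec 1 is regular

def spec_name_to_worker_names(spec_name):
    spec = worker_spec.get(spec_name)
    if spec and "group" in spec:
        return {str(spec_name) + suffix for suffix in spec["group"]}
    return {str(spec_name)}

def worker_name_to_spec_name(worker_name):
    name = str(worker_name)
    if "-" in name:
        prefix = name.split("-")[0]
        return int(prefix) if prefix.isdigit() else prefix
    return int(name) if name.isdigit() else worker_name

assert spec_name_to_worker_names(0) == {"0-0", "0-1"}   # grouped: one spec, two scheduler names
assert spec_name_to_worker_names(1) == {"1"}            # regular: same name, as a string
assert worker_name_to_spec_name("0-1") == 0             # scheduler name back to the spec key
assert worker_name_to_spec_name("1") == 1
```
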
@@ -356,7 +405,29 @@ async def _correct_state_internal(self) -> None:
         to_close = set(self.workers) - set(self.worker_spec)
         if to_close:
             if self.scheduler.status == Status.running:
-                await self.scheduler_comm.retire_workers(workers=list(to_close))
+                # For grouped workers, we need to retire the actual worker names
+                # that the scheduler knows about, not the spec names
+                actual_workers_to_retire = []
+                active_worker_names = {
+                    str(w["name"])
+                    for w in self.scheduler_info.get("workers", {}).values()
+                }
+
+                for spec_name in to_close:
+                    # Get all worker names for this spec (handles both regular and grouped)
+                    expected_worker_names = self._spec_name_to_worker_names(
+                        spec_name
+                    )
+                    # Only retire workers that actually exist in the scheduler
+                    for worker_name in expected_worker_names:
+                        if worker_name in active_worker_names:
+                            actual_workers_to_retire.append(worker_name)
+
+                if actual_workers_to_retire:
+                    await self.scheduler_comm.retire_workers(
+                        workers=actual_workers_to_retire
+                    )
+
             tasks = [
                 asyncio.create_task(self.workers[w].close())
                 for w in to_close
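
A small sketch (with invented cluster state) of what the filtering above achieves: only worker names the scheduler still knows about are passed to `retire_workers`, so group members that already died are skipped.

```python
# Hedged sketch of the retire-set filtering; the cluster state below is made up.
worker_spec = {0: {"group": ["-0", "-1"]}}    # grouped spec slated for removal
to_close = {0}
active_worker_names = {"0-0"}                  # "0-1" was pre-empted and is already gone

expected = {
    str(spec_name) + suffix
    for spec_name in to_close
    for suffix in worker_spec[spec_name]["group"]
}
actual_workers_to_retire = sorted(expected & active_worker_names)
print(actual_workers_to_retire)                # ['0-0'] -- only live scheduler names are retired
```
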
@@ -397,6 +468,11 @@ async def _correct_state_internal(self) -> None:
 
     def _update_worker_status(self, op, msg):
         if op == "remove":
+            # Get worker name - might already be gone from scheduler_info
+            if msg not in self.scheduler_info.get("workers", {}):
+                super()._update_worker_status(op, msg)
+                return
+
             removed_worker_name = self.scheduler_info["workers"][msg]["name"]
 
             # Closure to handle removal of a worker from the cluster
@@ -408,35 +484,26 @@ def f():
                 if removed_worker_name in active_workers:
                     return
 
-                # Build mapping from individual worker names to their worker spec names
-                # - For non-grouped workers: worker name == spec name (1:1)
-                # - For grouped workers: multiple workers map to one spec entry
-                worker_to_spec = {}
-                for worker_spec_name, spec in self.worker_spec.items():
-                    if "group" not in spec:
-                        worker_to_spec[worker_spec_name] = worker_spec_name
-                    else:
-                        grouped_workers = {
-                            str(worker_spec_name) + suffix: worker_spec_name
-                            for suffix in spec["group"]
-                        }
-                        worker_to_spec.update(grouped_workers)
-
-                # Find and remove the worker spec entry
-                # Note: For grouped workers, we remove the entire spec when ANY worker dies.
-                # This assumes that partial failure means the whole group is compromised
+                # Convert worker name to spec name using helper method
+                worker_spec_name = self._worker_name_to_spec_name(removed_worker_name)
+
+                # Check if this is a grouped worker
+                spec = self.worker_spec.get(worker_spec_name)
+                is_grouped = spec and "group" in spec
+
+                # Close and remove the worker object
+                if worker_spec_name in self.workers:
+                    self._futures.add(
+                        asyncio.ensure_future(self.workers[worker_spec_name].close())
+                    )
+                    del self.workers[worker_spec_name]
+
+                # Only remove spec for grouped workers
+                # For grouped workers: when ANY worker dies, the whole group is compromised
                 # (e.g., in HPC systems, if one process in a multi-process job fails, the
                 # entire job allocation is typically lost).
-                worker_spec_name = worker_to_spec.get(removed_worker_name)
-                if worker_spec_name and worker_spec_name in self.worker_spec:
-                    # Close and remove the worker object
-                    if worker_spec_name in self.workers:
-                        self._futures.add(
-                            asyncio.ensure_future(
-                                self.workers[worker_spec_name].close()
-                            )
-                        )
-                        del self.workers[worker_spec_name]
+                # For regular workers: keep spec so cluster can recreate them
+                if is_grouped and worker_spec_name in self.worker_spec:
                     del self.worker_spec[worker_spec_name]
 
             delay = parse_timedelta(
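
To make the removal policy in the comments above concrete, here is a standalone sketch (hypothetical spec dict, simplified name parsing) of the grouped-versus-regular decision: a grouped spec is dropped entirely when any member dies, while a regular spec is kept so `_correct_state()` can recreate the worker.

```python
# Hedged sketch of the removal policy; worker_spec and the parsing are simplified stand-ins.
worker_spec = {0: {"group": ["-0", "-1"]}, 1: {}}   # spec 0 grouped, spec 1 regular

def handle_removed(worker_name):
    prefix = str(worker_name).split("-")[0]
    spec_name = int(prefix) if prefix.isdigit() else prefix
    spec = worker_spec.get(spec_name)
    if spec is not None and "group" in spec:
        # One lost member compromises the whole HPC allocation: drop the spec.
        del worker_spec[spec_name]
    # Regular workers: leave the spec in place so the cluster can recreate them.

handle_removed("0-1")
assert 0 not in worker_spec    # grouped spec removed along with its surviving members
handle_removed("1")
assert 1 in worker_spec        # regular spec retained for recreation
```
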
@@ -541,24 +608,57 @@ def _memory_per_worker(self) -> int:
         )
 
     def scale(self, n=0, memory=None, cores=None):
-        if memory is not None:
-            n = max(n, int(math.ceil(parse_bytes(memory) / self._memory_per_worker())))
+        # For grouped workers, n represents number of workers, but memory/cores
+        # calculations represent number of specs (since _memory_per_worker and
+        # _threads_per_worker return values for the entire MultiWorker spec)
+        if self.new_spec and "group" in self.new_spec:
+            workers_per_spec = len(self.new_spec["group"])
+
+            # Convert n from number of workers to number of specs
+            target_specs_from_n = int(math.ceil(n / workers_per_spec)) if n > 0 else 0
+
+            # memory/cores calculations already give us number of specs
+            if memory is not None:
+                target_specs_from_memory = int(
+                    math.ceil(parse_bytes(memory) / self._memory_per_worker())
+                )
+                target_specs = max(target_specs_from_n, target_specs_from_memory)
+            else:
+                target_specs = target_specs_from_n
+
+            if cores is not None:
+                target_specs_from_cores = int(
+                    math.ceil(cores / self._threads_per_worker())
+                )
+                target_specs = max(target_specs, target_specs_from_cores)
+        else:
+            # For regular workers, everything is in terms of workers (which equals specs)
+            if memory is not None:
+                n = max(
+                    n, int(math.ceil(parse_bytes(memory) / self._memory_per_worker()))
+                )
+
+            if cores is not None:
+                n = max(n, int(math.ceil(cores / self._threads_per_worker())))
+
+            target_specs = n
 
-        if cores is not None:
-            n = max(n, int(math.ceil(cores / self._threads_per_worker())))
+        if len(self.worker_spec) > target_specs:
+            # Build set of spec names that have launched at least one worker
+            launched_spec_names = set()
+            for worker_info in self.scheduler_info.get("workers", {}).values():
+                spec_name = self._worker_name_to_spec_name(worker_info["name"])
+                launched_spec_names.add(spec_name)
 
-        if len(self.worker_spec) > n:
-            not_yet_launched = set(self.worker_spec) - {
-                v["name"] for v in self.scheduler_info["workers"].values()
-            }
-            while len(self.worker_spec) > n and not_yet_launched:
+            not_yet_launched = set(self.worker_spec) - launched_spec_names
+            while len(self.worker_spec) > target_specs and not_yet_launched:
                 del self.worker_spec[not_yet_launched.pop()]
 
-        while len(self.worker_spec) > n:
+        while len(self.worker_spec) > target_specs:
             self.worker_spec.popitem()
 
         if self.status not in (Status.closing, Status.closed):
-            while len(self.worker_spec) < n:
+            while len(self.worker_spec) < target_specs:
                 self.worker_spec.update(self.new_worker_spec())
 
         self.loop.add_callback(self._correct_state)
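
A worked sketch of the grouped branch of `scale()` above, with invented numbers (the per-spec memory and thread counts stand in for `_memory_per_worker()` and `_threads_per_worker()`): the memory/cores requests already count specs, while `n` counts workers and is rounded up.

```python
# Hedged sketch of how target_specs is derived in the grouped-worker branch above.
# All numbers are assumptions for illustration.
import math

workers_per_spec = 4              # len(new_spec["group"])
memory_per_spec = 16e9            # what _memory_per_worker() would report, in bytes
threads_per_spec = 8              # what _threads_per_worker() would report

n, memory, cores = 6, 100e9, 40   # e.g. scale(6, memory="100 GB", cores=40)

target_specs_from_n = math.ceil(n / workers_per_spec) if n > 0 else 0   # 2
target_specs = max(target_specs_from_n,
                   math.ceil(memory / memory_per_spec))                 # max(2, 7) = 7
target_specs = max(target_specs,
                   math.ceil(cores / threads_per_spec))                 # max(7, 5) = 7

print(target_specs, target_specs * workers_per_spec)                    # 7 specs -> 28 workers
```
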
@@ -597,17 +697,9 @@ def _supports_scaling(self):
         return bool(self.new_spec)
 
     async def scale_down(self, workers):
-        # We may have groups, if so, map worker addresses to job names
+        # Convert worker names to spec names (handles both regular and grouped workers)
         if not all(w in self.worker_spec for w in workers):
-            mapping = {}
-            for name, spec in self.worker_spec.items():
-                if "group" in spec:
-                    for suffix in spec["group"]:
-                        mapping[str(name) + suffix] = name
-                else:
-                    mapping[name] = name
-
-            workers = {mapping.get(w, w) for w in workers}
+            workers = {self._worker_name_to_spec_name(w) for w in workers}
 
         for w in workers:
             if w in self.worker_spec: