16 changes: 16 additions & 0 deletions distributed/deploy/adaptive_core.py
@@ -2,9 +2,11 @@

import logging
import math
import os
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING, cast

import tlz as toolz
@@ -18,6 +20,16 @@

logger = logging.getLogger(__name__)

# Set up file logging for adaptive operations
_slurm_job_id = os.environ.get("SLURM_JOB_ID", "unknown")
_log_file = Path.home() / f"byelayer-dask-log-{_slurm_job_id}"
_file_handler = logging.FileHandler(_log_file)
_file_handler.setLevel(logging.INFO)
_file_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(_file_handler)


class AdaptiveCore(ABC):
"""
Expand Down Expand Up @@ -197,6 +209,10 @@ async def adapt(self) -> None:

        try:
            target = await self.safe_target()
            print(
                f"Adaptive target: {target}, plan: {len(self.plan)}, requested: {len(self.requested)}, observed: {len(self.observed)}"
            )

            recommendations = await self.recommendations(target)

            if recommendations["status"] != "same":
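The handler added above writes adaptive events to one file per SLURM job in the user's home directory. A minimal standalone sketch of the same wiring, using only the standard library and mirroring the path format from the diff (the logger name here is illustrative, not from the source):

import logging
import os
from pathlib import Path

# One log file per SLURM job, falling back to "unknown" when the job id is unset,
# matching the module-level setup in the diff above.
job_id = os.environ.get("SLURM_JOB_ID", "unknown")
log_file = Path.home() / f"byelayer-dask-log-{job_id}"

demo_logger = logging.getLogger("adaptive_demo")  # illustrative name
handler = logging.FileHandler(log_file)
handler.setLevel(logging.INFO)
handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
demo_logger.addHandler(handler)
demo_logger.setLevel(logging.INFO)

demo_logger.info("adaptive target computed")  # lands in ~/byelayer-dask-log-<job id>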
186 changes: 153 additions & 33 deletions distributed/deploy/spec.py
@@ -340,6 +340,55 @@ async def _start(self):
            await self._close()
            raise RuntimeError(f"Cluster failed to start: {e}") from e

    def _spec_name_to_worker_names(self, spec_name):
        """Convert a spec name to the set of worker names it represents.

        For regular workers, returns {spec_name} as a string.
        For grouped workers, returns {spec_name + suffix for suffix in group}.

        Parameters
        ----------
        spec_name : int or str
            The spec name (key in worker_spec)

        Returns
        -------
        set of str
            The worker names that the scheduler knows about
        """
        spec = self.worker_spec.get(spec_name)
        if spec and "group" in spec:
            return {str(spec_name) + suffix for suffix in spec["group"]}
        return {str(spec_name)}

    def _worker_name_to_spec_name(self, worker_name):
        """Convert a worker name to its spec name.

        For regular workers, returns the worker name (converted to int if numeric).
        For grouped workers, extracts the prefix before the first "-".

        Parameters
        ----------
        worker_name : str or int
            The worker name from the scheduler

        Returns
        -------
        int or str
            The spec name (key in worker_spec)
        """
        worker_name_str = str(worker_name)
        if "-" in worker_name_str:
            spec_name = worker_name_str.split("-")[0]
            # Convert to int if numeric to match worker_spec keys
            if spec_name.isdigit():
                return int(spec_name)
            return spec_name
        # Try to convert to int if numeric
        if worker_name_str.isdigit():
            return int(worker_name_str)
        return worker_name
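A short standalone illustration of how these two helpers round-trip names, using a hypothetical worker_spec with one regular and one grouped entry (the group suffixes follow the "-<index>" convention that the split on "-" relies on); the functions below are plain-dict transcriptions of the methods, not the methods themselves:

worker_spec = {
    0: {"cls": "Worker"},                               # regular worker
    1: {"cls": "MultiWorker", "group": ["-0", "-1"]},   # grouped worker
}

def spec_name_to_worker_names(spec_name):
    spec = worker_spec.get(spec_name)
    if spec and "group" in spec:
        return {str(spec_name) + suffix for suffix in spec["group"]}
    return {str(spec_name)}

def worker_name_to_spec_name(worker_name):
    name = str(worker_name)
    prefix = name.split("-")[0] if "-" in name else name
    return int(prefix) if prefix.isdigit() else prefix

assert spec_name_to_worker_names(0) == {"0"}
assert spec_name_to_worker_names(1) == {"1-0", "1-1"}
assert worker_name_to_spec_name("1-0") == 1
assert worker_name_to_spec_name("0") == 0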

    def _correct_state(self):
        if self._correct_state_waiting:
            # If people call this frequently, we only want to run it once
@@ -356,7 +405,29 @@ async def _correct_state_internal(self) -> None:
        to_close = set(self.workers) - set(self.worker_spec)
        if to_close:
            if self.scheduler.status == Status.running:
                await self.scheduler_comm.retire_workers(workers=list(to_close))
                # For grouped workers, we need to retire the actual worker names
                # that the scheduler knows about, not the spec names
                actual_workers_to_retire = []
                active_worker_names = {
                    str(w["name"])
                    for w in self.scheduler_info.get("workers", {}).values()
                }

                for spec_name in to_close:
                    # Get all worker names for this spec (handles both regular and grouped)
                    expected_worker_names = self._spec_name_to_worker_names(
                        spec_name
                    )
                    # Only retire workers that actually exist in the scheduler
                    for worker_name in expected_worker_names:
                        if worker_name in active_worker_names:
                            actual_workers_to_retire.append(worker_name)

                if actual_workers_to_retire:
                    await self.scheduler_comm.retire_workers(
                        workers=actual_workers_to_retire
                    )

            tasks = [
                asyncio.create_task(self.workers[w].close())
                for w in to_close
@@ -397,25 +468,49 @@ async def _correct_state_internal(self) -> None:

    def _update_worker_status(self, op, msg):
        if op == "remove":
            name = self.scheduler_info["workers"][msg]["name"]
            # Get worker name - might already be gone from scheduler_info
            if msg not in self.scheduler_info.get("workers", {}):
                super()._update_worker_status(op, msg)
                return

            removed_worker_name = self.scheduler_info["workers"][msg]["name"]

            # Closure to handle removal of a worker from the cluster
            def f():
                if (
                    name in self.workers
                    and msg not in self.scheduler_info["workers"]
                    and not any(
                        d["name"] == name
                        for d in self.scheduler_info["workers"].values()
                # Check if worker is truly gone from scheduler
                active_workers = {
                    d["name"] for d in self.scheduler_info.get("workers", {}).values()
                }
                if removed_worker_name in active_workers:
                    return

                # Convert worker name to spec name using helper method
                worker_spec_name = self._worker_name_to_spec_name(removed_worker_name)

                # Check if this is a grouped worker
                spec = self.worker_spec.get(worker_spec_name)
                is_grouped = spec and "group" in spec

                # Close and remove the worker object
                if worker_spec_name in self.workers:
                    self._futures.add(
                        asyncio.ensure_future(self.workers[worker_spec_name].close())
                    )
                ):
                    self._futures.add(asyncio.ensure_future(self.workers[name].close()))
                    del self.workers[name]
                    del self.workers[worker_spec_name]

                # Only remove spec for grouped workers
                # For grouped workers: when ANY worker dies, the whole group is compromised
                # (e.g., in HPC systems, if one process in a multi-process job fails, the
                # entire job allocation is typically lost).
                # For regular workers: keep spec so cluster can recreate them
                if is_grouped and worker_spec_name in self.worker_spec:
                    del self.worker_spec[worker_spec_name]

            delay = parse_timedelta(
                dask.config.get("distributed.deploy.lost-worker-timeout")
            )

            asyncio.get_running_loop().call_later(delay, f)

        super()._update_worker_status(op, msg)

    def __await__(self: Self) -> Generator[Any, Any, Self]:
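The closure above encodes an asymmetric policy: if the removed worker belongs to a grouped spec, the whole spec is dropped (losing one member of a group typically means the whole allocation is gone), while a regular worker keeps its spec so _correct_state can recreate it. A small sketch of that decision, with hypothetical names and a plain dict standing in for worker_spec:

def should_drop_spec(worker_spec, spec_name):
    # Drop the spec only when it describes a grouped (multi-process) worker.
    spec = worker_spec.get(spec_name)
    return bool(spec and "group" in spec)

worker_spec = {
    0: {"cls": "Worker"},
    1: {"cls": "MultiWorker", "group": ["-0", "-1"]},
}

assert should_drop_spec(worker_spec, 1) is True   # one member of the group died
assert should_drop_spec(worker_spec, 0) is False  # regular worker: spec is kept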
@@ -513,24 +608,57 @@ def _memory_per_worker(self) -> int:
        )

    def scale(self, n=0, memory=None, cores=None):
        if memory is not None:
            n = max(n, int(math.ceil(parse_bytes(memory) / self._memory_per_worker())))
        # For grouped workers, n represents number of workers, but memory/cores
        # calculations represent number of specs (since _memory_per_worker and
        # _threads_per_worker return values for the entire MultiWorker spec)
        if self.new_spec and "group" in self.new_spec:
            workers_per_spec = len(self.new_spec["group"])

            # Convert n from number of workers to number of specs
            target_specs_from_n = int(math.ceil(n / workers_per_spec)) if n > 0 else 0

            # memory/cores calculations already give us number of specs
            if memory is not None:
                target_specs_from_memory = int(
                    math.ceil(parse_bytes(memory) / self._memory_per_worker())
                )
                target_specs = max(target_specs_from_n, target_specs_from_memory)
            else:
                target_specs = target_specs_from_n

            if cores is not None:
                target_specs_from_cores = int(
                    math.ceil(cores / self._threads_per_worker())
                )
                target_specs = max(target_specs, target_specs_from_cores)
        else:
            # For regular workers, everything is in terms of workers (which equals specs)
            if memory is not None:
                n = max(
                    n, int(math.ceil(parse_bytes(memory) / self._memory_per_worker()))
                )

            if cores is not None:
                n = max(n, int(math.ceil(cores / self._threads_per_worker())))

            target_specs = n

        if cores is not None:
            n = max(n, int(math.ceil(cores / self._threads_per_worker())))
        if len(self.worker_spec) > target_specs:
            # Build set of spec names that have launched at least one worker
            launched_spec_names = set()
            for worker_info in self.scheduler_info.get("workers", {}).values():
                spec_name = self._worker_name_to_spec_name(worker_info["name"])
                launched_spec_names.add(spec_name)

        if len(self.worker_spec) > n:
            not_yet_launched = set(self.worker_spec) - {
                v["name"] for v in self.scheduler_info["workers"].values()
            }
            while len(self.worker_spec) > n and not_yet_launched:
            not_yet_launched = set(self.worker_spec) - launched_spec_names
            while len(self.worker_spec) > target_specs and not_yet_launched:
                del self.worker_spec[not_yet_launched.pop()]

        while len(self.worker_spec) > n:
        while len(self.worker_spec) > target_specs:
            self.worker_spec.popitem()

        if self.status not in (Status.closing, Status.closed):
            while len(self.worker_spec) < n:
            while len(self.worker_spec) < target_specs:
                self.worker_spec.update(self.new_worker_spec())

        self.loop.add_callback(self._correct_state)
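A worked example of the grouped-worker arithmetic above, under the illustrative assumptions of 4 workers per spec, 16 GiB of memory per spec, and a request for 10 workers or 72 GiB (none of these numbers come from the diff; they only show why n is divided by the group size while memory already counts whole specs):

import math

workers_per_spec = 4           # len(new_spec["group"])
memory_per_spec = 16 * 2**30   # what _memory_per_worker() would report for the whole spec

n = 10                         # scale(n=10): ten workers requested
memory = 72 * 2**30            # scale(memory="72 GiB")

target_from_n = math.ceil(n / workers_per_spec)            # 3 specs cover 12 workers
target_from_memory = math.ceil(memory / memory_per_spec)   # 5 specs cover 80 GiB
target_specs = max(target_from_n, target_from_memory)      # the larger demand wins

assert (target_from_n, target_from_memory, target_specs) == (3, 5, 5)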
@@ -569,17 +697,9 @@ def _supports_scaling(self):
        return bool(self.new_spec)

    async def scale_down(self, workers):
        # We may have groups, if so, map worker addresses to job names
        # Convert worker names to spec names (handles both regular and grouped workers)
        if not all(w in self.worker_spec for w in workers):
            mapping = {}
            for name, spec in self.worker_spec.items():
                if "group" in spec:
                    for suffix in spec["group"]:
                        mapping[str(name) + suffix] = name
                else:
                    mapping[name] = name

            workers = {mapping.get(w, w) for w in workers}
            workers = {self._worker_name_to_spec_name(w) for w in workers}

        for w in workers:
            if w in self.worker_spec:
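With the helper in place, scale_down maps whatever names the scheduler hands back (grouped member names or plain spec names) onto worker_spec keys before removing them. A brief standalone sketch using the hypothetical two-entry worker_spec from the earlier example:

# Hypothetical: the scheduler asks to retire "1-0" (a group member) and "0".
worker_spec = {
    0: {"cls": "Worker"},
    1: {"cls": "MultiWorker", "group": ["-0", "-1"]},
}

def to_spec_name(worker_name):
    name = str(worker_name)
    prefix = name.split("-")[0] if "-" in name else name
    return int(prefix) if prefix.isdigit() else prefix

requested = ["1-0", "0"]
if not all(w in worker_spec for w in requested):
    requested = {to_spec_name(w) for w in requested}

for w in requested:
    worker_spec.pop(w, None)

assert worker_spec == {}  # both specs removed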