Automatically restart P2P shuffles when output worker leaves (#7970)
Co-authored-by: Lawrence Mitchell <wence@gmx.li>
hendrikmakait and wence- authored Jul 24, 2023
1 parent 7b0aca7 commit f0303aa
Showing 5 changed files with 493 additions and 247 deletions.
5 changes: 5 additions & 0 deletions distributed/shuffle/_exceptions.py
@@ -0,0 +1,5 @@
from __future__ import annotations


class ShuffleClosedError(RuntimeError):
    pass
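`ShuffleClosedError` gives worker-side shuffle code a way to signal an expected shutdown rather than a genuine failure. A minimal sketch of the intended catch-and-distinguish pattern (the `fetch_output` helper and its `get_output_partition` call are hypothetical; the real wiring into `shuffle_transfer`/`shuffle_unpack` appears in the `_shuffle.py` changes below):

```python
# Hypothetical sketch: a dedicated subclass lets callers treat "the shuffle was
# closed underneath us" differently from any other RuntimeError.
from distributed.shuffle._exceptions import ShuffleClosedError


def fetch_output(shuffle, partition):  # hypothetical helper, not part of the commit
    try:
        return shuffle.get_output_partition(partition)  # hypothetical method
    except ShuffleClosedError:
        return None  # expected during shutdown/restart; caller may retry elsewhere
    except RuntimeError:
        raise  # anything else is a real error
```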
157 changes: 109 additions & 48 deletions distributed/shuffle/_scheduler_plugin.py
@@ -6,7 +6,7 @@
import logging
from collections import defaultdict
from collections.abc import Callable, Iterable, Sequence
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import partial
from itertools import product
from typing import TYPE_CHECKING, Any, ClassVar
@@ -34,14 +34,15 @@
logger = logging.getLogger(__name__)


@dataclass
@dataclass(eq=False)
class ShuffleState(abc.ABC):
    _run_id_iterator: ClassVar[itertools.count] = itertools.count(1)

    id: ShuffleId
    run_id: int
    output_workers: set[str]
    participating_workers: set[str]
    _archived_by: str | None = field(default=None, init=False)

    @abc.abstractmethod
    def to_msg(self) -> dict[str, Any]:
@@ -50,8 +51,11 @@ def to_msg(self) -> dict[str, Any]:
    def __str__(self) -> str:
        return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"

    def __hash__(self) -> int:
        return hash(self.run_id)

@dataclass

@dataclass(eq=False)
class DataFrameShuffleState(ShuffleState):
    type: ClassVar[ShuffleType] = ShuffleType.DATAFRAME
    worker_for: dict[int, str]
@@ -68,7 +72,7 @@ def to_msg(self) -> dict[str, Any]:
        }


@dataclass
@dataclass(eq=False)
class ArrayRechunkState(ShuffleState):
    type: ClassVar[ShuffleType] = ShuffleType.ARRAY_RECHUNK
    worker_for: dict[NDIndex, str]
@@ -90,19 +94,18 @@
class ShuffleSchedulerPlugin(SchedulerPlugin):
"""
Shuffle plugin for the scheduler
This coordinates the individual worker plugins to ensure correctness
and collects heartbeat messages for the dashboard.
See Also
--------
ShuffleWorkerPlugin
"""

scheduler: Scheduler
states: dict[ShuffleId, ShuffleState]
active_shuffles: dict[ShuffleId, ShuffleState]
heartbeats: defaultdict[ShuffleId, dict]
erred_shuffles: dict[ShuffleId, Exception]
_shuffles: defaultdict[ShuffleId, set[ShuffleState]]
_archived_by_stimulus: defaultdict[str, set[ShuffleState]]

    def __init__(self, scheduler: Scheduler):
        self.scheduler = scheduler
@@ -115,9 +118,10 @@ def __init__(self, scheduler: Scheduler):
            }
        )
        self.heartbeats = defaultdict(lambda: defaultdict(dict))
        self.states = {}
        self.erred_shuffles = {}
        self.active_shuffles = {}
        self.scheduler.add_plugin(self, name="shuffle")
        self._shuffles = defaultdict(set)
        self._archived_by_stimulus = defaultdict(set)

    async def start(self, scheduler: Scheduler) -> None:
        worker_plugin = ShuffleWorkerPlugin()
@@ -126,18 +130,19 @@ async def start(self, scheduler: Scheduler) -> None:
        )

    def shuffle_ids(self) -> set[ShuffleId]:
        return set(self.states)
        return set(self.active_shuffles)

    async def barrier(self, id: ShuffleId, run_id: int) -> None:
        shuffle = self.states[id]
        shuffle = self.active_shuffles[id]
        assert shuffle.run_id == run_id, f"{run_id=} does not match {shuffle}"
        msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
        await self.scheduler.broadcast(
            msg=msg, workers=list(shuffle.participating_workers)
            msg=msg,
            workers=list(shuffle.participating_workers),
        )

    def restrict_task(self, id: ShuffleId, run_id: int, key: str, worker: str) -> dict:
        shuffle = self.states[id]
        shuffle = self.active_shuffles[id]
        if shuffle.run_id > run_id:
            return {
                "status": "error",
@@ -158,15 +163,19 @@ def heartbeat(self, ws: WorkerState, data: dict) -> None:
                self.heartbeats[shuffle_id][ws.address].update(d)

    def get(self, id: ShuffleId, worker: str) -> dict[str, Any]:
        if exception := self.erred_shuffles.get(id):
            return {"status": "error", "message": str(exception)}
        state = self.states[id]
        if worker not in self.scheduler.workers:
            # This should never happen
            raise RuntimeError(
                f"Scheduler is unaware of this worker {worker!r}"
            )  # pragma: nocover
        state = self.active_shuffles[id]
        state.participating_workers.add(worker)
        return state.to_msg()

    def get_or_create(
        self,
        id: ShuffleId,
        key: str,
        type: str,
        worker: str,
        spec: dict[str, Any],
@@ -178,6 +187,7 @@ def get_or_create(
        # known by its name. If the name has been mangled, we cannot guarantee
        # that the shuffle works as intended and should fail instead.
        self._raise_if_barrier_unknown(id)
        self._raise_if_task_not_processing(key)

        state: ShuffleState
        if type == ShuffleType.DATAFRAME:
@@ -186,7 +196,8 @@ def get_or_create(
            state = self._create_array_rechunk_state(id, spec)
        else: # pragma: no cover
            raise TypeError(type)
        self.states[id] = state
        self.active_shuffles[id] = state
        self._shuffles[id].add(state)
        state.participating_workers.add(worker)
        return state.to_msg()

@@ -201,6 +212,11 @@ def _raise_if_barrier_unknown(self, id: ShuffleId) -> None:
"into this by leaving a comment at distributed#7816."
)

def _raise_if_task_not_processing(self, key: str) -> None:
task = self.scheduler.tasks[key]
if task.state != "processing":
raise RuntimeError(f"Expected {task} to be processing, is {task.state}.")

    def _create_dataframe_shuffle_state(
        self, id: ShuffleId, spec: dict[str, Any]
    ) -> DataFrameShuffleState:
@@ -309,34 +325,67 @@ def _unset_restriction(self, ts: TaskState) -> None:
        original_restrictions = ts.annotations.pop("shuffle_original_restrictions")
        self.scheduler.set_restrictions({ts.key: original_restrictions})

    def _restart_recommendations(self, id: ShuffleId) -> Recs:
        barrier_task = self.scheduler.tasks[barrier_key(id)]
        recs: Recs = {}

        for dt in barrier_task.dependents:
            if dt.state == "erred":
                return {}
            recs.update({dt.key: "released"})

        if barrier_task.state == "erred":
            # This should never happen, a dependent of the barrier should already
            # be `erred`
            raise RuntimeError(
                f"Expected dependents of {barrier_task=} to be 'erred' if "
                "the barrier is."
            )  # pragma: no cover
        recs.update({barrier_task.key: "released"})

        for dt in barrier_task.dependencies:
            if dt.state == "erred":
                # This should never happen, a dependent of the barrier should already
                # be `erred`
                raise RuntimeError(
                    f"Expected barrier and its dependents to be "
                    f"'erred' if the barrier's dependency {dt} is."
                )  # pragma: no cover
            recs.update({dt.key: "released"})
        return recs

    def _restart_shuffle(
        self, id: ShuffleId, scheduler: Scheduler, *, stimulus_id: str
    ) -> None:
        recs = self._restart_recommendations(id)
        self.scheduler.transitions(recs, stimulus_id=stimulus_id)
        self.scheduler.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

    def remove_worker(
        self, scheduler: Scheduler, worker: str, *, stimulus_id: str, **kwargs: Any
    ) -> None:
        from time import time

        stimulus_id = f"shuffle-failed-worker-left-{time()}"
        """Restart all active shuffles when a participating worker leaves the cluster.
        .. note::
            Due to the order of operations in :meth:`~Scheduler.remove_worker`, the
            shuffle may have already been archived by
            :meth:`~ShuffleSchedulerPlugin.transition`. In this case, the
            ``stimulus_id`` is used as a transaction identifier and all archived shuffles
            with a matching `stimulus_id` are restarted.
        """

        recs: Recs = {}
        for shuffle_id, shuffle in self.states.items():
        # If processing the transactions causes a task to get released, this
        # removes the shuffle from self.active_shuffles. Therefore, we must iterate
        # over a copy.
        for shuffle_id, shuffle in self.active_shuffles.copy().items():
            if worker not in shuffle.participating_workers:
                continue
            exception = RuntimeError(f"Worker {worker} left during active {shuffle}")
            self.erred_shuffles[shuffle_id] = exception
            self._fail_on_workers(shuffle, str(exception))
            self._clean_on_scheduler(shuffle_id, stimulus_id)

            barrier_task = self.scheduler.tasks[barrier_key(shuffle_id)]
            if barrier_task.state == "memory":
                for dt in barrier_task.dependents:
                    if worker not in dt.worker_restrictions:
                        continue
                    self._unset_restriction(dt)
                    recs.update({dt.key: "waiting"})
            # TODO: Do we need to handle other states?

        # If processing the transactions causes a task to get released, this
        # removes the shuffle from self.states. Therefore, we must process them
        # outside of the loop.
        self.scheduler.transitions(recs, stimulus_id=stimulus_id)
        for shuffle in self._archived_by_stimulus.get(stimulus_id, set()):
            self._restart_shuffle(shuffle.id, scheduler, stimulus_id=stimulus_id)

    def transition(
        self,
@@ -347,17 +396,25 @@ def transition(
        stimulus_id: str,
        **kwargs: Any,
    ) -> None:
        """Clean up scheduler and worker state once a shuffle becomes inactive."""
        if finish not in ("released", "forgotten"):
            return
        if not key.startswith("shuffle-barrier-"):
            return
        shuffle_id = id_from_key(key)
        try:
            shuffle = self.states[shuffle_id]
        except KeyError:
            return
        self._fail_on_workers(shuffle, message=f"{shuffle} forgotten")
        self._clean_on_scheduler(shuffle_id)

        if shuffle := self.active_shuffles.get(shuffle_id):
            self._fail_on_workers(shuffle, message=f"{shuffle} forgotten")
            self._clean_on_scheduler(shuffle_id, stimulus_id=stimulus_id)

        if finish == "forgotten":
            shuffles = self._shuffles.pop(shuffle_id, set())
            for shuffle in shuffles:
                if shuffle._archived_by:
                    archived = self._archived_by_stimulus[shuffle._archived_by]
                    archived.remove(shuffle)
                    if not archived:
                        del self._archived_by_stimulus[shuffle._archived_by]

    def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None:
        worker_msgs = {
@@ -373,9 +430,12 @@ def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None:
        }
        self.scheduler.send_all({}, worker_msgs)

    def _clean_on_scheduler(self, id: ShuffleId) -> None:
        del self.states[id]
        self.erred_shuffles.pop(id, None)
    def _clean_on_scheduler(self, id: ShuffleId, stimulus_id: str | None) -> None:
        shuffle = self.active_shuffles.pop(id)
        if not shuffle._archived_by and stimulus_id:
            shuffle._archived_by = stimulus_id
            self._archived_by_stimulus[stimulus_id].add(shuffle)

        with contextlib.suppress(KeyError):
            del self.heartbeats[id]

@@ -384,9 +444,10 @@ def _clean_on_scheduler(self, id: ShuffleId) -> None:
            self._unset_restriction(dt)

    def restart(self, scheduler: Scheduler) -> None:
        self.states.clear()
        self.active_shuffles.clear()
        self.heartbeats.clear()
        self.erred_shuffles.clear()
        self._shuffles.clear()
        self._archived_by_stimulus.clear()


def get_worker_for_range_sharding(
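The heart of the scheduler-side change is that a deactivated shuffle is no longer recorded in `erred_shuffles`; instead it is archived under the `stimulus_id` that deactivated it, and `remove_worker` restarts exactly the shuffles archived under its own stimulus. The following self-contained sketch models only that bookkeeping; `_State` and `_Bookkeeping` are illustrative stand-ins rather than the plugin's real API, and scheduler task transitions are omitted:

```python
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field


@dataclass(eq=False)
class _State:
    id: str
    participating_workers: set[str]
    archived_by: str | None = field(default=None, init=False)


class _Bookkeeping:
    def __init__(self) -> None:
        self.active: dict[str, _State] = {}
        self.archived_by_stimulus: defaultdict[str, set[_State]] = defaultdict(set)

    def archive(self, shuffle_id: str, stimulus_id: str | None) -> None:
        # Mirrors _clean_on_scheduler: drop the run from the active set and
        # remember which stimulus archived it (only the first one counts).
        state = self.active.pop(shuffle_id)
        if state.archived_by is None and stimulus_id:
            state.archived_by = stimulus_id
            self.archived_by_stimulus[stimulus_id].add(state)

    def remove_worker(self, worker: str, stimulus_id: str) -> list[str]:
        # Mirrors remove_worker: archive every shuffle the worker took part in,
        # then restart everything archived under this stimulus exactly once.
        for shuffle_id, state in list(self.active.items()):
            if worker in state.participating_workers:
                self.archive(shuffle_id, stimulus_id)
        return sorted(s.id for s in self.archived_by_stimulus.get(stimulus_id, set()))


book = _Bookkeeping()
book.active["s1"] = _State("s1", {"w1", "w2"})
book.active["s2"] = _State("s2", {"w3"})
assert book.remove_worker("w1", "stim-1") == ["s1"]  # only s1 involved w1
```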
5 changes: 5 additions & 0 deletions distributed/shuffle/_shuffle.py
@@ -11,6 +11,7 @@

from distributed.exceptions import Reschedule
from distributed.shuffle._arrow import check_dtype_support, check_minimal_arrow_version
from distributed.shuffle._exceptions import ShuffleClosedError

logger = logging.getLogger("distributed.shuffle")
if TYPE_CHECKING:
@@ -69,6 +70,8 @@ def shuffle_transfer(
            column=column,
            parts_out=parts_out,
        )
    except ShuffleClosedError:
        raise Reschedule()
    except Exception as e:
        raise RuntimeError(f"shuffle_transfer failed during shuffle {id}") from e

@@ -82,6 +85,8 @@ def shuffle_unpack(
        )
    except Reschedule as e:
        raise e
    except ShuffleClosedError:
        raise Reschedule()
    except Exception as e:
        raise RuntimeError(f"shuffle_unpack failed during shuffle {id}") from e
