diff --git a/distributed/shuffle/_shuffle_extension.py b/distributed/shuffle/_shuffle_extension.py index 6a1e24f9b4..b810c84d9e 100644 --- a/distributed/shuffle/_shuffle_extension.py +++ b/distributed/shuffle/_shuffle_extension.py @@ -341,7 +341,7 @@ def __init__(self, worker: Worker) -> None: worker.handlers["shuffle_receive"] = self.shuffle_receive worker.handlers["shuffle_inputs_done"] = self.shuffle_inputs_done worker.handlers["shuffle_fail"] = self.shuffle_fail - worker.stream_handlers["shuffle-forget"] = self.shuffle_forget + worker.stream_handlers["shuffle-fail"] = self.shuffle_fail worker.extensions["shuffle"] = self # Initialize @@ -397,9 +397,6 @@ async def shuffle_fail(self, shuffle_id: ShuffleId, message: str) -> None: await shuffle.close() del self.shuffles[shuffle_id] - async def shuffle_forget(self, shuffle_id: ShuffleId) -> None: - await self.shuffle_fail(shuffle_id, message="Shuffle {shuffle_id} forgotten") - def add_partition( self, data: pd.DataFrame, @@ -779,7 +776,13 @@ def transition( shuffle_id = ShuffleSchedulerExtension.id_from_key(key) participating_workers = self.participating_workers[shuffle_id] worker_msgs = { - worker: [{"op": "shuffle-forget", "shuffle_id": shuffle_id}] + worker: [ + { + "op": "shuffle-fail", + "shuffle_id": shuffle_id, + "message": f"Shuffle {shuffle_id} forgotten", + } + ] for worker in participating_workers } self._clean_on_scheduler(shuffle_id)