Fail P2PShuffle gracefully upon worker failure #7326
Changes from 16 commits
@@ -15,6 +15,7 @@
from dask.utils import parse_bytes

from distributed.core import PooledRPCCall
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.protocol import to_serialize
from distributed.shuffle._arrow import (
    deserialize_schema,

@@ -40,6 +41,10 @@
logger = logging.getLogger(__name__)


class ShuffleClosedError(RuntimeError):
    pass


class Shuffle:
    """State for a single active shuffle

@@ -115,6 +120,7 @@ def __init__(
            partitions_of[addr].append(part)
        self.partitions_of = dict(partitions_of)
        self.worker_for = pd.Series(worker_for, name="_workers").astype("category")
        self.closed = False

        def _dump_batch(batch: pa.Buffer, file: BinaryIO) -> None:
            return dump_batch(batch, file, self.schema)

@@ -138,6 +144,7 @@ def _dump_batch(batch: pa.Buffer, file: BinaryIO) -> None:
        self.total_recvd = 0
        self.start_time = time.time()
        self._exception: Exception | None = None
        self._close_lock = asyncio.Lock()

    def __repr__(self) -> str:
        return f"<Shuffle id: {self.id} on {self.local_address}>"

@@ -150,6 +157,7 @@ def time(self, name: str) -> Iterator[None]:
        self.diagnostics[name] += stop - start

    async def barrier(self) -> None:
        self.raise_if_closed()
        # FIXME: This should restrict communication to only workers
        # participating in this specific shuffle. This will not only reduce the
        # number of workers we need to contact but will also simplify error

@@ -173,6 +181,7 @@ async def send(self, address: str, shards: list[bytes]) -> None:
        )

    async def offload(self, func: Callable[..., T], *args: Any) -> T:
        self.raise_if_closed()
        with self.time("cpu"):
            return await asyncio.get_running_loop().run_in_executor(
                self.executor,

@@ -194,8 +203,7 @@ async def receive(self, data: list[bytes]) -> None:
        await self._receive(data)

    async def _receive(self, data: list[bytes]) -> None:
        if self._exception:
            raise self._exception
        self.raise_if_closed()

        try:
            self.total_recvd += sum(map(len, data))

@@ -219,11 +227,20 @@ async def _receive(self, data: list[bytes]) -> None:
                    for k, v in groups.items()
                }
            )
            self.raise_if_closed()
            await self._disk_buffer.write(groups)
        except Exception as e:
            self._exception = e
            raise

    def raise_if_closed(self) -> None:
        if self.closed:
            if self._exception:
                raise self._exception
            raise ShuffleClosedError(
                f"Shuffle {self.id} has been closed on {self.local_address}"
            )

    async def add_partition(self, data: pd.DataFrame) -> None:
        if self.transferred:
            raise RuntimeError(f"Cannot add more partitions to shuffle {self}")

@@ -244,6 +261,7 @@ def _() -> dict[str, list[bytes]]:
        await self._comm_buffer.write(out)

    async def get_output_partition(self, i: int) -> pd.DataFrame:
        self.raise_if_closed()
        assert self.transferred, "`get_output_partition` called before barrier task"

        assert self.worker_for[i] == self.local_address, (

@@ -258,6 +276,7 @@ async def get_output_partition(self, i: int) -> pd.DataFrame:
        ), f"No outputs remaining, but requested output partition {i} on {self.local_address}."
        await self.flush_receive()
        try:
            self.raise_if_closed()
            df = self._disk_buffer.read(i)
            with self.time("cpu"):
                out = df.to_pandas()

@@ -269,6 +288,7 @@ async def get_output_partition(self, i: int) -> pd.DataFrame:
    async def inputs_done(self) -> None:
        assert not self.transferred, "`inputs_done` called multiple times"
        self.transferred = True
        self.raise_if_closed()
        await self._comm_buffer.flush()
        try:
            self._comm_buffer.raise_on_exception()

@@ -280,17 +300,23 @@ def done(self) -> bool:
        return self.transferred and self.output_partitions_left == 0

    async def flush_receive(self) -> None:
        if self._exception:
            raise self._exception
        self.raise_if_closed()
        await self._disk_buffer.flush()

    async def close(self) -> None:
        await self._comm_buffer.close()
        await self._disk_buffer.close()
        try:
            self.executor.shutdown(cancel_futures=True)
        except Exception:
            self.executor.shutdown()
        self.closed = True
        async with self._close_lock:
Review comment: How come this is overlapping? Why do we need a lock? Are we sure that everything we do below is idempotent? If we do not want to guarantee idempotency, the pattern we take in the server classes might be better suited than a lock, i.e.

    async def close(self) -> None:
        if self.closed:
            await self._event_close.wait()
        self.closed = True
        await close_all_stuff()
        self._event_close.set()

This locks + makes it idempotent even without relying on the buffers/executors/whatever else to come to be idempotent. The only important thing is that nobody must reset the `closed` attribute. This is a one-way street; otherwise this pattern breaks.

Reply: Fair point about idempotency. While everything should be idempotent at the moment, I'll adjust this to use the more cautious pattern of waiting on an event. Since shuffles should never reopen once closed, this should be fine.
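For illustration, a minimal, self-contained sketch of that event-based pattern (names such as `ClosableExample`, `_closed_event`, and `_teardown` are placeholders, not the PR's actual API); it adds an explicit `return` after the wait so repeated calls never re-run the teardown:

    import asyncio


    class ClosableExample:
        """Illustrative stand-in for Shuffle; only the close logic is sketched."""

        def __init__(self) -> None:
            self.closed = False
            self._closed_event = asyncio.Event()  # set once teardown has finished

        async def _teardown(self) -> None:
            # stand-in for closing the comm/disk buffers and shutting down the executor
            await asyncio.sleep(0)

        async def close(self) -> None:
            if self.closed:
                # Concurrent or repeated close() calls wait for the first one
                # to finish instead of re-running the teardown.
                await self._closed_event.wait()
                return
            self.closed = True  # one-way street: never reset
            await self._teardown()
            self._closed_event.set()


    async def main() -> None:
        obj = ClosableExample()
        await asyncio.gather(obj.close(), obj.close())  # idempotent and concurrency-safe


    asyncio.run(main())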
            await self._comm_buffer.close()
            await self._disk_buffer.close()
            try:
                self.executor.shutdown(cancel_futures=True)
            except Exception:
                self.executor.shutdown()

    async def fail(self, exception: Exception) -> None:
        if not self.closed:
            self._exception = exception
            await self.close()


class ShuffleWorkerExtension:

@@ -305,17 +331,27 @@ class ShuffleWorkerExtension:
    - collecting instrumentation of ongoing shuffles and route to scheduler/worker
    """

    worker: Worker
    shuffles: dict[ShuffleId, Shuffle]
    erred_shuffles: dict[ShuffleId, Exception]
    memory_limiter_comms: ResourceLimiter
    memory_limiter_disk: ResourceLimiter
    closed: bool

    def __init__(self, worker: Worker) -> None:
        # Attach to worker
        worker.handlers["shuffle_receive"] = self.shuffle_receive
        worker.handlers["shuffle_inputs_done"] = self.shuffle_inputs_done
        worker.handlers["shuffle_fail"] = self.shuffle_fail
        worker.extensions["shuffle"] = self

        # Initialize
        self.worker: Worker = worker
        self.shuffles: dict[ShuffleId, Shuffle] = {}
        self.memory_limiter_disk = ResourceLimiter(parse_bytes("1 GiB"))
        self.worker = worker
        self.shuffles = {}
        self.erred_shuffles = {}
        self.memory_limiter_comms = ResourceLimiter(parse_bytes("100 MiB"))
        self.memory_limiter_disk = ResourceLimiter(parse_bytes("1 GiB"))
        self.closed = False

        # Handlers
        ##########

@@ -333,8 +369,13 @@ async def shuffle_receive(
        Handler: Receive an incoming shard of data from a peer worker.
        Using an unknown ``shuffle_id`` is an error.
        """
        shuffle = await self._get_shuffle(shuffle_id)
        await shuffle.receive(data)
        try:
            shuffle = await self._get_shuffle(shuffle_id)
            await shuffle.receive(data)
        except ShuffleClosedError:
            from distributed.worker import Reschedule

            raise Reschedule()
    async def shuffle_inputs_done(self, shuffle_id: ShuffleId) -> None:
        """

@@ -353,6 +394,12 @@ async def shuffle_inputs_done(self, shuffle_id: ShuffleId) -> None:
            logger.critical(f"Shuffle inputs done {shuffle}")
            await self._register_complete(shuffle)

    async def shuffle_fail(self, shuffle_id: ShuffleId, message: str) -> None:
        shuffle = self.shuffles.pop(shuffle_id)
        exception = RuntimeError(message)
        self.erred_shuffles[shuffle_id] = exception
        await shuffle.fail(exception)

    def add_partition(
        self,
        data: pd.DataFrame,

@@ -378,6 +425,7 @@ async def _barrier(self, shuffle_id: ShuffleId) -> None:
        await shuffle.barrier()

    async def _register_complete(self, shuffle: Shuffle) -> None:
        self.raise_if_closed()
        await shuffle.close()
        await self.worker.scheduler.shuffle_register_complete(
            id=shuffle.id,

@@ -411,6 +459,10 @@ async def _get_shuffle(
        "Get a shuffle by ID; raise ValueError if it's not registered."
        import pyarrow as pa

        self.raise_if_closed()

        if exception := self.erred_shuffles.get(shuffle_id):
            raise exception
        try:
            return self.shuffles[shuffle_id]
        except KeyError:

@@ -423,6 +475,11 @@ async def _get_shuffle(
                    npartitions=npartitions,
                    column=column,
                )
                if result["status"] == "ERROR":
                    raise RuntimeError(
                        f"Worker {result['worker']} left during active shuffle {shuffle_id}"
                    )
                assert result["status"] == "OK"
            except KeyError:
                # Even the scheduler doesn't know about this shuffle
                # Let's hand this back to the scheduler and let it figure

@@ -434,6 +491,7 @@ async def _get_shuffle(

                raise Reschedule()
            else:
                self.raise_if_closed()
Review comment: Feels a bit odd to fail here. I'd expect this to happen somewhere else (and all tests pass if I remove this).

Reply: IIUC, we technically need this here. Otherwise we can run into a very unlikely edge case where the extension has been closed since we started sending the RPC to the scheduler. This would leave the new …
                if shuffle_id not in self.shuffles:
                    shuffle = Shuffle(
Comment on lines 491 to 492: We just discussed that instantiating a new shuffle instance (and also talking to the scheduler itself) is an error case that can only be caught by implementing the raise_if_closed pattern on the extension level; everything else could be handled on the shuffle instance level instead.
                        column=result["column"],

@@ -455,10 +513,17 @@ async def _get_shuffle(
        return self.shuffles[shuffle_id]

    async def close(self) -> None:
        self.closed = True
Review comment: nit: This attribute indicates …

Reply: Yeah, I've been thinking whether it would make sense to replace the boolean with a state …
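If the boolean were replaced with a state, one hypothetical shape (not part of this PR, purely illustrative) could be a small enum that distinguishes "closing" from "closed":

    from enum import Enum


    class ShuffleStatus(Enum):
        # hypothetical states that could replace the `closed` boolean
        RUNNING = "running"
        CLOSING = "closing"  # close() has started; teardown still in progress
        CLOSED = "closed"    # teardown finished; terminal state, never left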
        while self.shuffles:
            _, shuffle = self.shuffles.popitem()
            await shuffle.close()

    def raise_if_closed(self) -> None:
        if self.closed:
            raise ShuffleClosedError(
                f"{self.__class__.__name__} already closed on {self.worker.address}"
            )

    #############################
    # Methods for worker thread #
    #############################

@@ -507,8 +572,11 @@ def get_output_partition(

        Calling this for a ``shuffle_id`` which is unknown or incomplete is an error.
        """
        assert shuffle_id in self.shuffles, "Shuffle worker restrictions misbehaving"
        shuffle = self.shuffles[shuffle_id]
        self.raise_if_closed()
        assert (
            shuffle_id in self.shuffles or shuffle_id in self.erred_shuffles
        ), "Shuffle worker restrictions misbehaving"
        shuffle = self.get_shuffle(shuffle_id)
        output = sync(self.worker.loop, shuffle.get_output_partition, output_partition)
        # key missing if another thread got to it first
        if shuffle.done() and shuffle_id in self.shuffles:

@@ -517,7 +585,7 @@ def get_output_partition(
        return output


class ShuffleSchedulerExtension:
class ShuffleSchedulerExtension(SchedulerPlugin):
    """
    Shuffle extension for the scheduler

@@ -536,6 +604,7 @@ class ShuffleSchedulerExtension:
    columns: dict[ShuffleId, str]
    output_workers: dict[ShuffleId, set[str]]
    completed_workers: dict[ShuffleId, set[str]]
Review comment: We should be able to remove …
    erred_shuffles: dict[ShuffleId, str]

    def __init__(self, scheduler: Scheduler):
        self.scheduler = scheduler

@@ -551,6 +620,11 @@ def __init__(self, scheduler: Scheduler):
        self.columns = {}
        self.output_workers = {}
        self.completed_workers = {}
        self.erred_shuffles = {}
        self.scheduler.add_plugin(self)

    def shuffle_ids(self) -> set[ShuffleId]:
        return set(self.worker_for)

    def heartbeat(self, ws: WorkerState, data: dict) -> None:
        for shuffle_id, d in data.items():

@@ -563,6 +637,9 @@ def get(
        column: str | None,
        npartitions: int | None,
    ) -> dict:
        if id in self.erred_shuffles:
            return {"status": "ERROR", "worker": self.erred_shuffles[id]}

        if id not in self.worker_for:
            assert schema is not None
            assert column is not None

@@ -590,12 +667,40 @@ def get(
            self.completed_workers[id] = set()

        return {
            "status": "OK",
            "worker_for": self.worker_for[id],
            "column": self.columns[id],
            "schema": self.schemas[id],
            "output_workers": self.output_workers[id],
        }

    async def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
        broadcasts = []
        for shuffle_id, output_workers in self.output_workers.items():
            if worker not in output_workers:
                continue
            self.erred_shuffles[shuffle_id] = worker
            contact_workers = output_workers.copy()
            contact_workers.discard(worker)
            message = f"Worker {worker} left during active shuffle {shuffle_id}"
            broadcasts.append(
                scheduler.broadcast(
                    msg={
                        "op": "shuffle_fail",
                        "message": message,
                        "shuffle_id": shuffle_id,
                    },
                    workers=list(contact_workers),
                )
            )
            self.scheduler.handle_task_erred(
                f"shuffle-barrier-{shuffle_id}",
                exception=to_serialize(RuntimeError(message)),
                stimulus_id="shuffle-remove-worker",
            )
        await asyncio.gather(*broadcasts, return_exceptions=True)
        # TODO: Clean up scheduler

    def register_complete(self, id: ShuffleId, worker: str) -> None:
        """Learn from a worker that it has completed all reads of a shuffle"""
        if id not in self.completed_workers:
Review comment: This is an interesting place. Why would we need to raise here but not between any of the other awaits?

Review comment: FWIW I think we should combine the above calls into a single offload anyhow, which would render this comment moot.

Reply: offload itself is protected with raise_if_closed(). I've been thinking whether I should wrap any async functionality that needs to be protected with raise_if_closed() into individual functions. That would probably make reasoning about these easier.

Reply: Good point, done.
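A rough sketch of the wrapping idea mentioned above: funnel each await through a helper that re-checks the closed flag before and after the suspension point. `Guarded` and `_run_guarded` are illustrative names, not the PR's API:

    import asyncio
    from typing import Awaitable, TypeVar

    T = TypeVar("T")


    class ShuffleClosedError(RuntimeError):
        pass


    class Guarded:
        """Illustrative stand-in for an object with a one-way `closed` flag."""

        def __init__(self) -> None:
            self.closed = False

        def raise_if_closed(self) -> None:
            if self.closed:
                raise ShuffleClosedError("already closed")

        async def _run_guarded(self, coro: Awaitable[T]) -> T:
            # Check both before and after the await: the object may be closed
            # while the coroutine is suspended (e.g. during an RPC).
            self.raise_if_closed()
            result = await coro
            self.raise_if_closed()
            return result


    async def main() -> None:
        g = Guarded()
        print(await g._run_guarded(asyncio.sleep(0, result="ok")))


    asyncio.run(main())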