
Commit df1eaba: Refactor gather_dep (#6388)

1 parent 6272e20

File tree

5 files changed: +263, -95 lines
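
For orientation before the per-file diffs: this commit makes gather_dep stop mutating worker state inline and instead return one of four typed outcome events (GatherDepSuccessEvent, GatherDepBusyEvent, GatherDepNetworkFailureEvent, GatherDepFailureEvent), each handled by a dedicated @_handle_event.register method in the worker state machine. The snippet below is a minimal, self-contained sketch of that dispatch shape, not the distributed implementation: MiniWorkerState, its attributes, the dataclass fields, and the use of functools.singledispatchmethod are assumptions made for illustration; only the event names come from this diff.

from dataclasses import dataclass, field
from functools import singledispatchmethod


@dataclass
class GatherDepDoneEvent:
    """Base event: a gather_dep coroutine finished, one way or another."""

    worker: str
    total_nbytes: int


@dataclass
class GatherDepSuccessEvent(GatherDepDoneEvent):
    """The remote worker sent (some of) the requested data."""

    data: dict = field(default_factory=dict)  # key -> value actually received


@dataclass
class GatherDepBusyEvent(GatherDepDoneEvent):
    """The remote worker answered 'busy' instead of sending data."""


class MiniWorkerState:
    """Toy state machine: tracks in-flight keys per worker and reacts to events."""

    def __init__(self):
        self.comm_nbytes = 0
        self.in_flight = {}  # worker address -> set of keys being fetched

    def _pop_in_flight(self, ev):
        # Bookkeeping shared by every kind of "done" event.
        self.comm_nbytes -= ev.total_nbytes
        return self.in_flight.pop(ev.worker, set())

    @singledispatchmethod
    def handle_event(self, ev):
        raise TypeError(f"no handler registered for {type(ev).__name__}")

    @handle_event.register
    def _(self, ev: GatherDepSuccessEvent):
        # Keys that came back go to memory; keys that did not must be fetched again.
        return {
            key: "memory" if key in ev.data else "fetch"
            for key in self._pop_in_flight(ev)
        }

    @handle_event.register
    def _(self, ev: GatherDepBusyEvent):
        # The remote worker was busy: every key goes back to fetch.
        return {key: "fetch" for key in self._pop_in_flight(ev)}


state = MiniWorkerState()
state.comm_nbytes = 100
state.in_flight["tcp://10.0.0.1:1234"] = {"x", "y"}
ev = GatherDepSuccessEvent(worker="tcp://10.0.0.1:1234", total_nbytes=100, data={"x": 1})
print(state.handle_event(ev))  # {'x': 'memory', 'y': 'fetch'} (order may vary)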

distributed/tests/test_stories.py

Lines changed: 2 additions & 3 deletions
@@ -138,8 +138,7 @@ async def test_worker_story_with_deps(c, s, a, b):
     # Story now includes randomized stimulus_ids and timestamps.
     story = b.story("res")
     stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
-    assert stimulus_ids == {"compute-task", "task-finished"}
-
+    assert stimulus_ids == {"compute-task", "gather-dep-success", "task-finished"}
     # This is a simple transition log
     expected = [
         ("res", "compute-task", "released"),
@@ -153,7 +152,7 @@ async def test_worker_story_with_deps(c, s, a, b):

     story = b.story("dep")
     stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
-    assert stimulus_ids == {"compute-task"}
+    assert stimulus_ids == {"compute-task", "gather-dep-success"}
     expected = [
         ("dep", "ensure-task-exists", "released"),
         ("dep", "released", "fetch", "fetch", {}),

distributed/tests/test_worker.py

Lines changed: 0 additions & 1 deletion
@@ -2928,7 +2928,6 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers):
     coming_from.handle_stimulus(RemoveReplicasEvent(keys=[f1.key], stimulus_id="test"))
     await f2

-    assert_story(a.story(f1.key), [(f1.key, "missing-dep")])
     assert a.tasks[f1.key].suspicious_count == 0
     assert s.tasks[f1.key].suspicious == 0

distributed/tests/test_worker_state_machine.py

Lines changed: 35 additions & 0 deletions
@@ -647,3 +647,38 @@ async def test_fetch_to_missing_on_refresh_who_has(c, s, w1, w2, w3):
     assert w3.tasks["x"].state == "missing"
     assert w3.tasks["y"].state == "flight"
     assert w3.tasks["y"].who_has == {w2.address}
+
+
+@gen_cluster(client=True, nthreads=[("", 1)])
+async def test_fetch_to_missing_on_network_failure(c, s, a):
+    """
+    1. Two tasks, x and y, are respectively in flight and fetch state from the same
+       worker, which holds the only replica of both.
+    2. gather_dep for x returns GatherDepNetworkFailureEvent.
+    3. The event empties has_what, x.who_has, and y.who_has.
+    4. The same event invokes _ensure_communicating, which pops y from data_needed
+       - but y has an empty who_has, which is an exceptional situation.
+       _ensure_communicating recommends a transition to missing for y.
+    5. The fetch->missing transition is executed, but y is no longer in data_needed -
+       another exceptional situation.
+    """
+    block_get_data = asyncio.Event()
+
+    class BlockedBreakingWorker(Worker):
+        async def get_data(self, comm, *args, **kwargs):
+            await block_get_data.wait()
+            raise OSError("fake error")
+
+    async with BlockedBreakingWorker(s.address) as b:
+        x = c.submit(inc, 1, key="x", workers=[b.address])
+        y = c.submit(inc, 2, key="y", workers=[b.address])
+        await wait([x, y])
+        s.request_acquire_replicas(a.address, ["x"], stimulus_id="test_x")
+        await wait_for_state("x", "flight", a)
+        s.request_acquire_replicas(a.address, ["y"], stimulus_id="test_y")
+        await wait_for_state("y", "fetch", a)
+
+        block_get_data.set()
+
+        await wait_for_state("x", "missing", a)
+        await wait_for_state("y", "missing", a)

distributed/worker.py

Lines changed: 148 additions & 90 deletions
@@ -21,6 +21,7 @@
     Collection,
     Container,
     Iterable,
+    Iterator,
     Mapping,
     MutableMapping,
 )
@@ -122,7 +123,11 @@
     FindMissingEvent,
     FreeKeysEvent,
     GatherDep,
+    GatherDepBusyEvent,
     GatherDepDoneEvent,
+    GatherDepFailureEvent,
+    GatherDepNetworkFailureEvent,
+    GatherDepSuccessEvent,
     Instructions,
     InvalidTaskState,
     InvalidTransition,
@@ -2185,13 +2190,7 @@ def transition_fetch_flight(
     def transition_fetch_missing(
         self, ts: TaskState, *, stimulus_id: str
     ) -> RecsInstrs:
-        # There's a use case where ts won't be found in self.data_needed, so
-        # `self.data_needed.remove(ts)` would crash:
-        # 1. An event handler empties who_has and pushes a recommendation to missing
-        # 2. The same event handler calls _ensure_communicating, which pops the task
-        #    from data_needed
-        # 3. The recommendation is enacted
-        # See matching code in _ensure_communicating.
+        # _ensure_communicating could have just popped this task out of data_needed
         self.data_needed.discard(ts)
         return self.transition_generic_missing(ts, stimulus_id=stimulus_id)

@@ -3017,11 +3016,7 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
                 assert self.address not in ts.who_has

             if not ts.who_has:
-                # An event handler just emptied who_has and recommended a fetch->missing
-                # transition. Then, the same handler called _ensure_communicating. The
-                # transition hasn't been enacted yet, so the task is still in fetch
-                # state and in data_needed.
-                # See matching code in transition_fetch_missing.
+                recommendations[ts] = "missing"
                 continue

             workers = [
@@ -3293,13 +3288,6 @@ async def gather_dep(
         if self.status not in WORKER_ANY_RUNNING:
             return None

-        recommendations: Recs = {}
-        instructions: Instructions = []
-        response = {}
-
-        def done_event():
-            return GatherDepDoneEvent(stimulus_id=f"gather-dep-done-{time()}")
-
         try:
             self.log.append(("request-dep", worker, to_gather, stimulus_id, time()))
             logger.debug("Request %d keys from %s", len(to_gather), worker)
@@ -3310,8 +3298,14 @@ def done_event():
             )
             stop = time()
             if response["status"] == "busy":
-                return done_event()
+                self.log.append(("busy-gather", worker, to_gather, stimulus_id, time()))
+                return GatherDepBusyEvent(
+                    worker=worker,
+                    total_nbytes=total_nbytes,
+                    stimulus_id=f"gather-dep-busy-{time()}",
+                )

+            assert response["status"] == "OK"
             cause = self._get_cause(to_gather)
             self._update_metrics_received_data(
                 start=start,
@@ -3323,86 +3317,156 @@ def done_event():
             self.log.append(
                 ("receive-dep", worker, set(response["data"]), stimulus_id, time())
             )
-            return done_event()
+            return GatherDepSuccessEvent(
+                worker=worker,
+                total_nbytes=total_nbytes,
+                data=response["data"],
+                stimulus_id=f"gather-dep-success-{time()}",
+            )

         except OSError:
             logger.exception("Worker stream died during communication: %s", worker)
-            has_what = self.has_what.pop(worker)
-            self.data_needed_per_worker.pop(worker)
             self.log.append(
-                ("receive-dep-failed", worker, has_what, stimulus_id, time())
+                ("receive-dep-failed", worker, to_gather, stimulus_id, time())
+            )
+            return GatherDepNetworkFailureEvent(
+                worker=worker,
+                total_nbytes=total_nbytes,
+                stimulus_id=f"gather-dep-network-failure-{time()}",
             )
-            for d in has_what:
-                ts = self.tasks[d]
-                ts.who_has.remove(worker)
-                if not ts.who_has and ts.state in (
-                    "fetch",
-                    "flight",
-                    "resumed",
-                    "cancelled",
-                ):
-                    recommendations[ts] = "missing"
-                    self.log.append(
-                        ("missing-who-has", worker, ts.key, stimulus_id, time())
-                    )
-            return done_event()

         except Exception as e:
+            # e.g. data failed to deserialize
             logger.exception(e)
             if self.batched_stream and LOG_PDB:
                 import pdb

                 pdb.set_trace()
-            msg = error_message(e)
-            for k in self.in_flight_workers[worker]:
-                ts = self.tasks[k]
-                recommendations[ts] = tuple(msg.values())
-            return done_event()

-        finally:
-            self.comm_nbytes -= total_nbytes
-            busy = response.get("status", "") == "busy"
-            data = response.get("data", {})
+            return GatherDepFailureEvent.from_exception(
+                e,
+                worker=worker,
+                total_nbytes=total_nbytes,
+                stimulus_id=f"gather-dep-failure-{time()}",
+            )

-            if busy:
-                self.log.append(("busy-gather", worker, to_gather, stimulus_id, time()))
-                # Avoid hammering the worker. If there are multiple replicas
-                # available, immediately try fetching from a different worker.
-                self.busy_workers.add(worker)
-                instructions.append(
-                    RetryBusyWorkerLater(worker=worker, stimulus_id=stimulus_id)
-                )
+    def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState]:
+        """Common code for all subclasses of GatherDepDoneEvent.

-            refresh_who_has = []
-
-            for d in self.in_flight_workers.pop(worker):
-                ts = self.tasks[d]
-                ts.done = True
-                if d in data:
-                    recommendations[ts] = ("memory", data[d])
-                elif busy:
-                    recommendations[ts] = "fetch"
-                    if not ts.who_has - self.busy_workers:
-                        refresh_who_has.append(d)
-                elif ts not in recommendations:
-                    ts.who_has.discard(worker)
-                    self.has_what[worker].discard(ts.key)
-                    self.data_needed_per_worker[worker].discard(ts)
-                    self.log.append((d, "missing-dep", stimulus_id, time()))
-                    recommendations[ts] = "fetch"
-
-            if refresh_who_has:
-                # All workers that hold known replicas of our tasks are busy.
-                # Try querying the scheduler for unknown ones.
-                instructions.append(
-                    RequestRefreshWhoHasMsg(
-                        keys=refresh_who_has,
-                        stimulus_id=f"gather-dep-busy-{time()}",
-                    )
+        Yields the tasks that need to transition out of flight.
+        """
+        self.comm_nbytes -= ev.total_nbytes
+        keys = self.in_flight_workers.pop(ev.worker)
+        for key in keys:
+            ts = self.tasks[key]
+            ts.done = True
+            yield ts
+
+    @_handle_event.register
+    def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs:
+        """gather_dep terminated successfully.
+        The response may contain less keys than the request.
+        """
+        recommendations: Recs = {}
+        for ts in self._gather_dep_done_common(ev):
+            if ts.key in ev.data:
+                recommendations[ts] = ("memory", ev.data[ts.key])
+            else:
+                self.log.append((ts.key, "missing-dep", ev.stimulus_id, time()))
+                if self.validate:
+                    assert ts.state != "fetch"
+                    assert ts not in self.data_needed_per_worker[ev.worker]
+                ts.who_has.discard(ev.worker)
+                self.has_what[ev.worker].discard(ts.key)
+                recommendations[ts] = "fetch"
+
+        return merge_recs_instructions(
+            (recommendations, []),
+            self._ensure_communicating(stimulus_id=ev.stimulus_id),
+        )
+
+    @_handle_event.register
+    def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs:
+        """gather_dep terminated: remote worker is busy"""
+        # Avoid hammering the worker. If there are multiple replicas
+        # available, immediately try fetching from a different worker.
+        self.busy_workers.add(ev.worker)
+
+        recommendations: Recs = {}
+        refresh_who_has = []
+        for ts in self._gather_dep_done_common(ev):
+            recommendations[ts] = "fetch"
+            if not ts.who_has - self.busy_workers:
+                refresh_who_has.append(ts.key)
+
+        instructions: Instructions = [
+            RetryBusyWorkerLater(worker=ev.worker, stimulus_id=ev.stimulus_id),
+        ]
+
+        if refresh_who_has:
+            # All workers that hold known replicas of our tasks are busy.
+            # Try querying the scheduler for unknown ones.
+            instructions.append(
+                RequestRefreshWhoHasMsg(
+                    keys=refresh_who_has, stimulus_id=ev.stimulus_id
                 )
+            )

-            self.transitions(recommendations, stimulus_id=stimulus_id)
-            self._handle_instructions(instructions)
+        return merge_recs_instructions(
+            (recommendations, instructions),
+            self._ensure_communicating(stimulus_id=ev.stimulus_id),
+        )
+
+    @_handle_event.register
+    def _handle_gather_dep_network_failure(
+        self, ev: GatherDepNetworkFailureEvent
+    ) -> RecsInstrs:
+        """gather_dep terminated: network failure while trying to
+        communicate with remote worker
+
+        Though the network failure could be transient, we assume it is not, and
+        preemptively act as though the other worker has died (including removing all
+        keys from it, even ones we did not fetch).
+
+        This optimization leads to faster completion of the fetch, since we immediately
+        either retry a different worker, or ask the scheduler to inform us of a new
+        worker if no other worker is available.
+        """
+        self.data_needed_per_worker.pop(ev.worker)
+        for key in self.has_what.pop(ev.worker):
+            ts = self.tasks[key]
+            ts.who_has.discard(ev.worker)
+
+        recommendations: Recs = {}
+        for ts in self._gather_dep_done_common(ev):
+            self.log.append((ts.key, "missing-dep", ev.stimulus_id, time()))
+            recommendations[ts] = "fetch"
+
+        return merge_recs_instructions(
+            (recommendations, []),
+            self._ensure_communicating(stimulus_id=ev.stimulus_id),
+        )
+
+    @_handle_event.register
+    def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
+        """gather_dep terminated: generic error raised (not a network failure);
+        e.g. data failed to deserialize.
+        """
+        recommendations: Recs = {
+            ts: (
+                "error",
+                ev.exception,
+                ev.traceback,
+                ev.exception_text,
+                ev.traceback_text,
+            )
+            for ts in self._gather_dep_done_common(ev)
+        }
+
+        return merge_recs_instructions(
+            (recommendations, []),
+            self._ensure_communicating(stimulus_id=ev.stimulus_id),
+        )

     async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent | None:
         await asyncio.sleep(0.15)
@@ -3841,11 +3905,6 @@ def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs:
             self._ensure_communicating(stimulus_id=ev.stimulus_id),
         )

-    @_handle_event.register
-    def _handle_gather_dep_done(self, ev: GatherDepDoneEvent) -> RecsInstrs:
-        """Temporary hack - to be removed"""
-        return self._ensure_communicating(stimulus_id=ev.stimulus_id)
-
     @_handle_event.register
     def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs:
         self.busy_workers.discard(ev.worker)
@@ -4181,8 +4240,7 @@ def validate_task_fetch(self, ts):
         assert self.address not in ts.who_has
         assert not ts.done
         assert ts in self.data_needed
-        assert ts.who_has
-
+        # Note: ts.who_has may be empty; see GatherDepNetworkFailureEvent
         for w in ts.who_has:
             assert ts.key in self.has_what[w]
             assert ts in self.data_needed_per_worker[w]
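
A design note on the worker.py changes above: every new handler first drains _gather_dep_done_common, which performs the bookkeeping shared by all outcomes (decrement comm_nbytes, pop the in-flight keys for that worker, mark each task done) and yields the tasks one by one, so each handler can attach its own per-task recommendation in a plain loop. The stand-alone sketch below illustrates that generator-sharing pattern with invented stand-in classes (Task, MiniState); it is not distributed's code.

from dataclasses import dataclass


@dataclass
class Task:
    key: str
    done: bool = False


class MiniState:
    def __init__(self, in_flight_keys, comm_nbytes):
        self.tasks = {k: Task(k) for k in in_flight_keys}
        self.in_flight = set(in_flight_keys)
        self.comm_nbytes = comm_nbytes

    def _done_common(self, total_nbytes):
        # Shared cleanup: account for the finished transfer, pop the in-flight
        # keys, mark each task done, and hand the tasks back to the caller.
        self.comm_nbytes -= total_nbytes
        keys, self.in_flight = self.in_flight, set()
        for key in sorted(keys):
            ts = self.tasks[key]
            ts.done = True
            yield ts

    def on_success(self, data, total_nbytes):
        # Per-outcome logic wraps the shared generator in a plain comprehension.
        return {
            ts.key: ("memory", data[ts.key]) if ts.key in data else "fetch"
            for ts in self._done_common(total_nbytes)
        }

    def on_network_failure(self, total_nbytes):
        return {ts.key: "fetch" for ts in self._done_common(total_nbytes)}


state = MiniState({"x", "y"}, comm_nbytes=100)
print(state.on_success({"x": 1}, total_nbytes=100))
# {'x': ('memory', 1), 'y': 'fetch'}   -- y was requested but not returned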
