Scheduler worker reconnect drops messages #6341

Closed
46 changes: 27 additions & 19 deletions distributed/scheduler.py
@@ -1561,7 +1561,6 @@ def _transitions(
         """
         keys: set = set()
         recommendations = recommendations.copy()
-        msgs: list
         new_msgs: list
         new: tuple
         new_recs: dict
@@ -1576,13 +1575,13 @@ def _transitions(
 
             recommendations.update(new_recs)
             for c, new_msgs in new_cmsgs.items():
-                msgs = client_msgs.get(c)  # type: ignore
+                msgs = client_msgs.get(c)
                 if msgs is not None:
                     msgs.extend(new_msgs)
                 else:
                     client_msgs[c] = new_msgs
             for w, new_msgs in new_wmsgs.items():
-                msgs = worker_msgs.get(w)  # type: ignore
+                msgs = worker_msgs.get(w)
                 if msgs is not None:
                     msgs.extend(new_msgs)
                 else:
@@ -3547,7 +3546,7 @@ def heartbeat_worker(
     @log_errors
     async def add_worker(
         self,
-        comm=None,
+        comm,
         *,
         address: str,
         status: str,
@@ -3585,8 +3584,7 @@ async def add_worker(
                 "message": "name taken, %s" % name,
                 "time": time(),
             }
-            if comm:
-                await comm.write(msg)
+            await comm.write(msg)
             return
 
         self.log_event(address, {"action": "add-worker"})
@@ -3652,31 +3650,40 @@ async def add_worker(
         except Exception as e:
             logger.exception(e)
 
-        recommendations: dict = {}
         client_msgs: dict = {}
         worker_msgs: dict = {}
         if nbytes:
             assert isinstance(nbytes, dict)
             already_released_keys = []
             for key in nbytes:
-                ts: TaskState = self.tasks.get(key)  # type: ignore
+                ts: TaskState | None = self.tasks.get(key)
                 if ts is not None and ts.state != "released":
                     if ts.state == "memory":
                         self.add_keys(worker=address, keys=[key])
                     else:
-                        t: tuple = self._transition(
+                        recommendations, new_cmsgs, new_wmsgs = self._transition(
                             key,
                             "memory",
                             stimulus_id,
                             worker=address,
                             nbytes=nbytes[key],
                             typename=types[key],
                         )
-                        recommendations, client_msgs, worker_msgs = t
Collaborator (author) commented:

This is @fjetter's key fix: notice that we were overwriting client_msgs and worker_msgs on each iteration, instead of merging into them. (A minimal sketch of the bug follows this hunk.)

+                        for c, new_msgs in new_cmsgs.items():
+                            msgs = client_msgs.get(c)
+                            if msgs is not None:
+                                msgs.extend(new_msgs)
+                            else:
+                                client_msgs[c] = new_msgs
+                        for w, new_msgs in new_wmsgs.items():
+                            msgs = worker_msgs.get(w)
+                            if msgs is not None:
+                                msgs.extend(new_msgs)
+                            else:
+                                worker_msgs[w] = new_msgs
                         self._transitions(
                             recommendations, client_msgs, worker_msgs, stimulus_id
                         )
-                        recommendations = {}
                 else:
                     already_released_keys.append(key)
         if already_released_keys:
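
To make the overwrite concrete, here is a minimal, runnable Python sketch of the bug and the fix. Plain dicts stand in for the scheduler's client_msgs/worker_msgs accumulators, and the message payloads are invented for illustration; this is not distributed's actual API.

    # Minimal sketch: rebinding the accumulator keeps only the last batch,
    # so earlier "key is done" messages are silently dropped.

    def broken(transition_results):
        client_msgs = {}
        for new_cmsgs in transition_results:
            client_msgs = new_cmsgs  # bug: discards everything accumulated so far
        return client_msgs

    def fixed(transition_results):
        client_msgs = {}
        for new_cmsgs in transition_results:
            for c, new_msgs in new_cmsgs.items():
                # fix: merge each batch into the accumulator
                client_msgs.setdefault(c, []).extend(new_msgs)
        return client_msgs

    results = [
        {"client-1": [{"op": "key-in-memory", "key": "l-1"}]},
        {"client-1": [{"op": "key-in-memory", "key": "l-2"}]},
    ]
    assert broken(results) == {"client-1": [{"op": "key-in-memory", "key": "l-2"}]}
    assert fixed(results) == {
        "client-1": [
            {"op": "key-in-memory", "key": "l-1"},
            {"op": "key-in-memory", "key": "l-2"},
        ]
    }

The patch itself uses the get/extend pattern from the diff rather than setdefault, which avoids allocating a throwaway list when a client or worker is seen for the first time.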
@@ -3691,10 +3698,12 @@ async def add_worker(
             )
 
         if ws.status == Status.running:
-            recommendations.update(self.bulk_schedule_after_adding_worker(ws))
-
-        if recommendations:
-            self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
+            self._transitions(
+                self.bulk_schedule_after_adding_worker(ws),
+                client_msgs,
+                worker_msgs,
+                stimulus_id,
+            )
 
         self.send_all(client_msgs, worker_msgs)

@@ -3719,10 +3728,9 @@ async def add_worker(
             )
             msg.update(version_warning)
 
-        if comm:
-            await comm.write(msg)
+        await comm.write(msg)
 
-        await self.handle_worker(comm=comm, worker=address, stimulus_id=stimulus_id)
+        await self.handle_worker(comm, address, stimulus_id=stimulus_id)
 
     async def add_nanny(self, comm):
         msg = {
@@ -4803,7 +4811,7 @@ def handle_worker_status_change(
         else:
             self.running.discard(ws)
 
-    async def handle_worker(self, comm=None, worker=None, stimulus_id=None):
+    async def handle_worker(self, comm, worker: str, stimulus_id=None):
         """
         Listen to responses from a single worker
 
47 changes: 30 additions & 17 deletions distributed/tests/test_scheduler.py
@@ -370,25 +370,38 @@ async def test_clear_events_client_removal(c, s, a, b):
     assert time() < start + 2
 
 
-@gen_cluster()
-async def test_add_worker(s, a, b):
-    w = Worker(s.address, nthreads=3)
-    w.data["x-5"] = 6
-    w.data["y"] = 1
Collaborator (author) commented:

This is the test that probably should have caught the regression, but it didn't, because Scheduler.add_worker only cares about nbytes. By just adding data, but not corresponding TaskStates, nbytes was empty in the worker's reconnection message:

    nbytes={
        ts.key: ts.get_nbytes()
        for ts in self.tasks.values()
        # Only if the task is in memory this is a sensible
        # result since otherwise it simply submits the
        # default value
        if ts.state == "memory"
    },

The keys field of this message is unused by the scheduler and should probably be removed; it should be redundant with nbytes anyway. (A sketch of why direct writes to Worker.data leave nbytes empty follows this comment.)
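
A sketch of the testing gap described above, using simplified stand-ins rather than distributed's real Worker and TaskState classes (the class names, method name registration_nbytes, and the 64-byte size are invented): writing straight into Worker.data bypasses the worker's task bookkeeping, so the comprehension quoted above yields an empty dict, while update_data records each key as an in-memory task.

    # FakeTaskState/MiniWorker are illustrative stand-ins, not the real classes.

    class FakeTaskState:
        def __init__(self, key, nbytes):
            self.key = key
            self.state = "memory"
            self._nbytes = nbytes

        def get_nbytes(self):
            return self._nbytes

    class MiniWorker:
        def __init__(self):
            self.data = {}   # raw key -> value storage
            self.tasks = {}  # key -> task state; what registration reads

        def update_data(self, data):
            for key, value in data.items():
                self.data[key] = value
                self.tasks[key] = FakeTaskState(key, nbytes=64)  # invented size

        def registration_nbytes(self):
            # mirrors the comprehension quoted in the comment above
            return {
                ts.key: ts.get_nbytes()
                for ts in self.tasks.values()
                if ts.state == "memory"
            }

    old_style = MiniWorker()
    old_style.data["x-5"] = 6                     # what the old test did
    assert old_style.registration_nbytes() == {}  # scheduler learns of no keys

    new_style = MiniWorker()
    new_style.update_data({"l-1": 2})             # what the new test does
    assert new_style.registration_nbytes() == {"l-1": 64}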

+@gen_cluster(nthreads=[("", 1)], client=True)
+async def test_add_worker(c, s, a):
+    lock = Lock()
 
-    dsk = {("x-%d" % i): (inc, i) for i in range(10)}
-    s.update_graph(
-        tasks=valmap(dumps_task, dsk),
-        keys=list(dsk),
-        client="client",
-        dependencies={k: set() for k in dsk},
-    )
-    s.validate_state()
-    await w
-    s.validate_state()
+    async with lock:
+        anywhere = c.submit(inc, 0, key="l-0")
+        l1 = c.submit(lock.acquire, key="l-1")
+        l2 = c.submit(lock.acquire, key="l-2")
+
+        while not (sum(t.state == "processing" for t in s.tasks.values()) == 3):
+            await asyncio.sleep(0.01)
+
+        # Simulate a worker joining with necessary and unnecessary data.
+        w = Worker(s.address, nthreads=1)
+        w.update_data({"l-1": 2, "l-2": 3, "x": -1, "y": -2})
+        # `update_data` queues messages to send; we want to purely test `add_worker` logic
+        w.batched_stream.buffer.clear()
+
+        s.validate_state()
+        await w
+        s.validate_state()
+
+        while not len(s.workers) == 2:
+            await asyncio.sleep(0.01)
+
+        assert w.ip in s.host_info
+        assert s.host_info[w.ip]["addresses"] == {a.address, w.address}
+
+        assert await c.gather([anywhere, l1, l2]) == [1, 2, 3]
Collaborator (author) commented:

Note that this updated test fails on main: it times out waiting here, because the client message saying l1 and l2 are done is overwritten and never sent.

assert "x" not in w.data
assert "y" not in w.data

assert w.ip in s.host_info
assert s.host_info[w.ip]["addresses"] == {a.address, b.address, w.address}
Collaborator (author) commented:

The previous test wasn't really testing much besides this. Whether x-5 would come from normal compute on worker a, or from memory on the new worker, was a race condition, and it was untested anyway.

     await w.close()