Commit 71a31a7

Merge branch 'main' into WSMR/forgotten_data_needed

2 parents a032754 + 715d7be

4 files changed: +292 -90 lines changed

distributed/tests/test_worker.py

Lines changed: 2 additions & 24 deletions
@@ -46,6 +46,8 @@
 from distributed.protocol import pickle
 from distributed.scheduler import Scheduler
 from distributed.utils_test import (
+    BlockedGatherDep,
+    BlockedGetData,
     TaskStateMetadataPlugin,
     _LockedCommPool,
     assert_story,
@@ -3062,30 +3064,6 @@ async def test_task_flight_compute_oserror(c, s, a, b):
     assert_story(sum_story, expected_sum_story, strict=True)
 
 
-class BlockedGatherDep(Worker):
-    def __init__(self, *args, **kwargs):
-        self.in_gather_dep = asyncio.Event()
-        self.block_gather_dep = asyncio.Event()
-        super().__init__(*args, **kwargs)
-
-    async def gather_dep(self, *args, **kwargs):
-        self.in_gather_dep.set()
-        await self.block_gather_dep.wait()
-        return await super().gather_dep(*args, **kwargs)
-
-
-class BlockedGetData(Worker):
-    def __init__(self, *args, **kwargs):
-        self.in_get_data = asyncio.Event()
-        self.block_get_data = asyncio.Event()
-        super().__init__(*args, **kwargs)
-
-    async def get_data(self, comm, *args, **kwargs):
-        self.in_get_data.set()
-        await self.block_get_data.wait()
-        return await super().get_data(comm, *args, **kwargs)
-
-
 @gen_cluster(client=True, nthreads=[])
 async def test_gather_dep_cancelled_rescheduled(c, s):
     """At time of writing, the gather_dep implementation filtered tasks again

distributed/tests/test_worker_state_machine.py

Lines changed: 179 additions & 28 deletions
@@ -3,9 +3,11 @@
 
 import pytest
 
+from distributed import Worker, wait
 from distributed.protocol.serialize import Serialize
 from distributed.utils import recursive_to_dict
 from distributed.utils_test import (
+    BlockedGetData,
     _LockedCommPool,
     assert_story,
     freeze_data_fetching,
@@ -21,11 +23,12 @@
     RescheduleMsg,
     StateMachineEvent,
     TaskState,
+    TaskStateState,
     merge_recs_instructions,
 )
 
 
-async def wait_for_state(key, state, dask_worker):
+async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None:
     while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
         await asyncio.sleep(0.005)
 
@@ -213,26 +216,17 @@ def test_executefailure_to_dict():
 
 @gen_cluster(client=True)
 async def test_fetch_to_compute(c, s, a, b):
-    # Block ensure_communicating to ensure we indeed know that the task is in
-    # fetch and doesn't leave it accidentally
-    old_out_connections, b.total_out_connections = b.total_out_connections, 0
-    old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0
-
-    f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
-    f2 = c.submit(inc, f1, workers=[b.address], key="f2")
-
-    await wait_for_state(f1.key, "fetch", b)
-    await a.close()
-
-    b.total_out_connections = old_out_connections
-    b.comm_threshold_bytes = old_comm_threshold
+    with freeze_data_fetching(b):
+        f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+        f2 = c.submit(inc, f1, workers=[b.address], key="f2")
+        await wait_for_state(f1.key, "fetch", b)
+        await a.close()
 
     await f2
 
     assert_story(
         b.log,
-        # FIXME: This log should be replaced with an
-        # StateMachineEvent/Instruction log
+        # FIXME: This log should be replaced with a StateMachineEvent log
         [
             (f2.key, "compute-task", "released"),
             # This is a "please fetch" request. We don't have anything like
@@ -251,23 +245,180 @@ async def test_fetch_to_compute(c, s, a, b):
 
 @gen_cluster(client=True)
 async def test_fetch_via_amm_to_compute(c, s, a, b):
-    # Block ensure_communicating to ensure we indeed know that the task is in
-    # fetch and doesn't leave it accidentally
-    old_out_connections, b.total_out_connections = b.total_out_connections, 0
-    old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0
-
-    f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+    with freeze_data_fetching(b):
+        f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+        await f1
+        s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
+        await wait_for_state(f1.key, "fetch", b)
+        await a.close()
 
     await f1
-    s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
 
-    await wait_for_state(f1.key, "fetch", b)
-    await a.close()
+    assert_story(
+        b.log,
+        # FIXME: This log should be replaced with a StateMachineEvent log
+        [
+            (f1.key, "ensure-task-exists", "released"),
+            (f1.key, "released", "fetch", "fetch", {}),
+            (f1.key, "compute-task", "fetch"),
+            (f1.key, "put-in-memory"),
+        ],
+    )
+
 
-    b.total_out_connections = old_out_connections
-    b.comm_threshold_bytes = old_comm_threshold
+@pytest.mark.parametrize("as_deps", [False, True])
+@gen_cluster(client=True, nthreads=[("", 1)] * 3)
+async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps):
+    """
+    as_deps=True
+    0. task x is a dependency of y1 and y2
+    1. scheduler calls handle_compute("y1", who_has={"x": [w2, w3]}) on w1
+    2. x transitions released -> fetch
+    3. the network stack is busy, so x does not transition to flight yet.
+    4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1
+    5. when x finally reaches the top of the data_needed heap, w1 will not try
+       contacting w2
+
+    as_deps=False
+    1. scheduler calls handle_acquire_replicas(who_has={"x": [w2, w3]}) on w1
+    2. x transitions released -> fetch
+    3. the network stack is busy, so x does not transition to flight yet.
+    4. scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1
+    5. when x finally reaches the top of the data_needed heap, w1 will not try
+       contacting w2
+    """
+    x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[
+        "x"
+    ]
 
-    await f1
+    # Make sure find_missing is not involved
+    w1.periodic_callbacks["find-missing"].stop()
+
+    with freeze_data_fetching(w1, jump_start=True):
+        if as_deps:
+            y1 = c.submit(inc, x, key="y1", workers=[w1.address])
+        else:
+            s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
+
+        await wait_for_state("x", "fetch", w1)
+        assert w1.tasks["x"].who_has == {w2.address, w3.address}
+
+        assert len(s.tasks["x"].who_has) == 2
+        await w2.close()
+        while len(s.tasks["x"].who_has) > 1:
+            await asyncio.sleep(0.01)
+
+        if as_deps:
+            y2 = c.submit(inc, x, key="y2", workers=[w1.address])
+        else:
+            s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
+
+        while w1.tasks["x"].who_has != {w3.address}:
+            await asyncio.sleep(0.01)
+
+    await wait_for_state("x", "memory", w1)
+    assert_story(
+        w1.story("request-dep"),
+        [("request-dep", w3.address, {"x"})],
+        # This tests that there has been no attempt to contact w2.
+        # If the assumption being tested breaks, this will fail 50% of the times.
+        strict=True,
+    )
+
+
+@gen_cluster(client=True, nthreads=[("", 1)] * 2)
+async def test_fetch_to_missing(c, s, a, b):
+    """
+    1. task x is a dependency of y
+    2. scheduler calls handle_compute("y", who_has={"x": [b]}) on a
+    3. x transitions released -> fetch -> flight; a connects to b
+    4. b responds it's busy. x transitions flight -> fetch
+    5. The busy state triggers an RPC call to Scheduler.who_has
+    6. the scheduler responds {"x": []}, because it has in the meantime lost track of b's replica.
+    7. x is transitioned fetch -> missing
+    """
+    x = await c.scatter({"x": 1}, workers=[b.address])
+    b.total_in_connections = 0
+    # Crucially, unlike with `c.submit(inc, x, workers=[a.address])`, the scheduler
+    # doesn't keep track of acquire-replicas requests, so it won't proactively inform a
+    # when we call remove_worker later on
+    s.request_acquire_replicas(a.address, ["x"], stimulus_id="test")
+
+    # state will flip-flop between fetch and flight every 150ms, which is the retry
+    # period for busy workers.
+    await wait_for_state("x", "fetch", a)
+    assert b.address in a.busy_workers
+
+    # Sever connection between b and s, but not between b and a.
+    # If a tries fetching from b after this, b will keep responding {status: busy}.
+    b.periodic_callbacks["heartbeat"].stop()
+    await s.remove_worker(b.address, close=False, stimulus_id="test")
+
+    await wait_for_state("x", "missing", a)
+
+    assert_story(
+        a.story("x"),
+        [
+            ("x", "ensure-task-exists", "released"),
+            ("x", "released", "fetch", "fetch", {}),
+            ("gather-dependencies", b.address, {"x"}),
+            ("x", "fetch", "flight", "flight", {}),
+            ("request-dep", b.address, {"x"}),
+            ("busy-gather", b.address, {"x"}),
+            ("x", "flight", "fetch", "fetch", {}),
+            ("x", "fetch", "missing", "missing", {}),
+        ],
+        # There may be a round of find_missing() after this.
+        # Due to timings, there also may be multiple attempts to connect from a to b.
+        strict=False,
+    )
+
+
+@pytest.mark.skip(reason="https://github.com/dask/distributed/issues/6446")
+@gen_cluster(client=True)
+async def test_new_replica_while_all_workers_in_flight(c, s, w1, w2):
+    """A task is stuck in 'fetch' state because all workers that hold a replica are in
+    flight. While in this state, a new replica appears on a different worker and the
+    scheduler informs the waiting worker through a new acquire-replicas or
+    compute-task op.
+
+    In real life, this will typically happen when the Active Memory Manager replicates a
+    key to multiple workers and some workers are much faster than others to acquire it,
+    due to unrelated tasks being in flight, so 2 seconds later the AMM reiterates the
+    request, passing a larger who_has.
+
+    Test that, when this happens, the task is immediately acquired from the new worker,
+    without waiting for the original replica holders to get out of flight.
+    """
+    # Make sure find_missing is not involved
+    w1.periodic_callbacks["find-missing"].stop()
+
+    async with BlockedGetData(s.address) as w3:
+        x = c.submit(inc, 1, key="x", workers=[w3.address])
+        y = c.submit(inc, 2, key="y", workers=[w3.address])
+        await wait([x, y])
+        s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
+        await w3.in_get_data.wait()
+        assert w1.tasks["x"].state == "flight"
+        s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test")
+        # This cannot progress beyond fetch because w3 is already in flight
+        await wait_for_state("y", "fetch", w1)
+
+        # Simulate that the AMM also requires that w2 acquires a replica of y.
+        # The replica lands on w2 soon afterwards, while w3->w1 comms remain blocked by
+        # unrelated transfers (x in our case).
+        w2.update_data({"y": 3}, report=True)
+        ws2 = s.workers[w2.address]
+        while ws2 not in s.tasks["y"].who_has:
+            await asyncio.sleep(0.01)
+
+        # 2 seconds later, the AMM reiterates that w1 should acquire a replica of y
+        s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test")
+        await wait_for_state("y", "memory", w1)
+
+        # Finally let the other worker get out of flight
+        w3.block_get_data.set()
+        await wait_for_state("x", "memory", w1)
 
 
 @gen_cluster(client=True)
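The tests above replace the manual total_out_connections / comm_threshold_bytes juggling with the freeze_data_fetching context manager. A rough sketch of what such a helper amounts to, inferred from the manual pattern removed above; the actual helper lives in distributed/utils_test.py and may differ, in particular in how jump_start=True nudges the worker once comms are unfrozen:

from contextlib import contextmanager

from distributed import Worker


@contextmanager
def freeze_data_fetching_sketch(w: Worker, *, jump_start: bool = False):
    # Zero the comm limits so no task can move from fetch to flight,
    # mimicking the pattern the new tests no longer write by hand.
    old_out_connections = w.total_out_connections
    old_comm_threshold = w.comm_threshold_bytes
    w.total_out_connections = 0
    w.comm_threshold_bytes = 0
    try:
        yield
    finally:
        w.total_out_connections = old_out_connections
        w.comm_threshold_bytes = old_comm_threshold
        if jump_start:
            # Hypothetical: the real helper presumably pokes the worker here so
            # that tasks stuck in "fetch" are fetched immediately after the block.
            pass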

distributed/utils_test.py

Lines changed: 57 additions & 0 deletions
@@ -2263,13 +2263,70 @@ def wait_for_log_line(
         i += 1
 
 
+class BlockedGatherDep(Worker):
+    """A Worker that sets event `in_gather_dep` the first time it enters the gather_dep
+    method and then does not initiate any comms, thus leaving the task(s) in flight
+    indefinitely, until the test sets `block_gather_dep`
+
+    Example
+    -------
+    .. code-block:: python
+
+        @gen_test()
+        async def test1(s, a, b):
+            async with BlockedGatherDep(s.address) as x:
+                # [do something to cause x to fetch data from a or b]
+                await x.in_gather_dep.wait()
+                # [do something that must happen while the tasks are in flight]
+                x.block_gather_dep.set()
+                # [from this moment on, x is a regular worker]
+
+    See also
+    --------
+    BlockedGetData
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.in_gather_dep = asyncio.Event()
+        self.block_gather_dep = asyncio.Event()
+        super().__init__(*args, **kwargs)
+
+    async def gather_dep(self, *args, **kwargs):
+        self.in_gather_dep.set()
+        await self.block_gather_dep.wait()
+        return await super().gather_dep(*args, **kwargs)
+
+
+class BlockedGetData(Worker):
+    """A Worker that sets event `in_get_data` the first time it enters the get_data
+    method and then does not answer the comms, thus leaving the task(s) in flight
+    indefinitely, until the test sets `block_get_data`
+
+    See also
+    --------
+    BlockedGatherDep
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.in_get_data = asyncio.Event()
+        self.block_get_data = asyncio.Event()
+        super().__init__(*args, **kwargs)
+
+    async def get_data(self, comm, *args, **kwargs):
+        self.in_get_data.set()
+        await self.block_get_data.wait()
+        return await super().get_data(comm, *args, **kwargs)
+
+
 @contextmanager
 def freeze_data_fetching(w: Worker, *, jump_start: bool = False):
     """Prevent any task from transitioning from fetch to flight on the worker while
     inside the context, simulating a situation where the worker's network comms are
     saturated.
+
     This is not the same as setting the worker to Status=paused, which would also
     inform the Scheduler and prevent further tasks to be enqueued on the worker.
+
     Parameters
     ----------
     w: Worker
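Unlike BlockedGatherDep, the BlockedGetData docstring carries no usage example. A minimal sketch, modelled on test_new_replica_while_all_workers_in_flight above; the test name, keys, and values are illustrative only:

from distributed.utils_test import BlockedGetData, gen_cluster, inc


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_block_get_data(c, s, a):
    # Worker that accepts get_data requests but does not answer until released
    async with BlockedGetData(s.address) as b:
        x = c.submit(inc, 1, key="x", workers=[b.address])
        await x
        # Ask a to fetch a replica of x from b; the transfer stalls in flight
        s.request_acquire_replicas(a.address, ["x"], stimulus_id="example")
        await b.in_get_data.wait()
        assert a.tasks["x"].state == "flight"
        # [do something that must happen while the transfer is stalled]
        b.block_get_data.set()  # from here on, b answers get_data normally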
