 
 import pytest
 
+from distributed import Worker, wait
 from distributed.protocol.serialize import Serialize
 from distributed.utils import recursive_to_dict
 from distributed.utils_test import (
+    BlockedGetData,
     _LockedCommPool,
     assert_story,
     freeze_data_fetching,
@@ -21,11 +23,12 @@
     RescheduleMsg,
     StateMachineEvent,
     TaskState,
+    TaskStateState,
     merge_recs_instructions,
 )
 
 
-async def wait_for_state(key, state, dask_worker):
+async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -> None:
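+    """Poll dask_worker every 5 ms until task ``key`` reaches ``state``."""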
     while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
         await asyncio.sleep(0.005)
 
@@ -213,26 +216,17 @@ def test_executefailure_to_dict():
 
 @gen_cluster(client=True)
 async def test_fetch_to_compute(c, s, a, b):
-    # Block ensure_communicating to ensure we indeed know that the task is in
-    # fetch and doesn't leave it accidentally
-    old_out_connections, b.total_out_connections = b.total_out_connections, 0
-    old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0
-
-    f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
-    f2 = c.submit(inc, f1, workers=[b.address], key="f2")
-
-    await wait_for_state(f1.key, "fetch", b)
-    await a.close()
-
-    b.total_out_connections = old_out_connections
-    b.comm_threshold_bytes = old_comm_threshold
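+    # freeze_data_fetching() packages up the hand-rolled dance removed above: it
+    # temporarily zeroes b.total_out_connections and b.comm_threshold_bytes so
+    # that the task verifiably stays in fetch, and restores both on exit.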
+    with freeze_data_fetching(b):
+        f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+        f2 = c.submit(inc, f1, workers=[b.address], key="f2")
+        await wait_for_state(f1.key, "fetch", b)
+        await a.close()
 
     await f2
 
     assert_story(
         b.log,
-        # FIXME: This log should be replaced with an
-        # StateMachineEvent/Instruction log
+        # FIXME: This log should be replaced with a StateMachineEvent log
         [
             (f2.key, "compute-task", "released"),
             # This is a "please fetch" request. We don't have anything like
@@ -251,23 +245,180 @@ async def test_fetch_to_compute(c, s, a, b):
 
 @gen_cluster(client=True)
 async def test_fetch_via_amm_to_compute(c, s, a, b):
-    # Block ensure_communicating to ensure we indeed know that the task is in
-    # fetch and doesn't leave it accidentally
-    old_out_connections, b.total_out_connections = b.total_out_connections, 0
-    old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0
-
-    f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+    with freeze_data_fetching(b):
+        f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
+        await f1
+        s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
+        await wait_for_state(f1.key, "fetch", b)
+        await a.close()
 
     await f1
-    s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")
-
-    await wait_for_state(f1.key, "fetch", b)
-    await a.close()
+    assert_story(
+        b.log,
+        # FIXME: This log should be replaced with a StateMachineEvent log
+        [
+            (f1.key, "ensure-task-exists", "released"),
+            (f1.key, "released", "fetch", "fetch", {}),
+            (f1.key, "compute-task", "fetch"),
+            (f1.key, "put-in-memory"),
+        ],
+    )
+
 
-    b.total_out_connections = old_out_connections
-    b.comm_threshold_bytes = old_comm_threshold
+@pytest.mark.parametrize("as_deps", [False, True])
+@gen_cluster(client=True, nthreads=[("", 1)] * 3)
+async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps):
272+ """
273+ as_deps=True
274+ 0. task x is a dependency of y1 and y2
275+ 1. scheduler calls handle_compute("y1", who_has={"x": [w2, w3]}) on w1
276+ 2. x transitions released -> fetch
277+ 3. the network stack is busy, so x does not transition to flight yet.
278+ 4. scheduler calls handle_compute("y2", who_has={"x": [w3]}) on w1
279+ 5. when x finally reaches the top of the data_needed heap, w1 will not try
280+ contacting w2
281+
282+ as_deps=False
283+ 1. scheduler calls handle_acquire_replicas(who_has={"x": [w2, w3]}) on w1
284+ 2. x transitions released -> fetch
285+ 3. the network stack is busy, so x does not transition to flight yet.
286+ 4. scheduler calls handle_acquire_replicas(who_has={"x": [w3]}) on w1
287+ 5. when x finally reaches the top of the data_needed heap, w1 will not try
288+ contacting w2
289+ """
+    x = (await c.scatter({"x": 1}, workers=[w2.address, w3.address], broadcast=True))[
+        "x"
+    ]
 
-    await f1
+    # Make sure find_missing is not involved
+    w1.periodic_callbacks["find-missing"].stop()
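+    # (find-missing is a periodic callback that would independently poll the
+    # scheduler for lost dependencies and could unstick x on its own, masking
+    # what this test asserts.)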
+
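+    # jump_start=True is assumed to retrigger ensure_communicating when the
+    # context manager exits, so the frozen fetch resumes immediately.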
+    with freeze_data_fetching(w1, jump_start=True):
+        if as_deps:
+            y1 = c.submit(inc, x, key="y1", workers=[w1.address])
+        else:
+            s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
+
+        await wait_for_state("x", "fetch", w1)
+        assert w1.tasks["x"].who_has == {w2.address, w3.address}
+
+        assert len(s.tasks["x"].who_has) == 2
+        await w2.close()
+        while len(s.tasks["x"].who_has) > 1:
+            await asyncio.sleep(0.01)
+
+        if as_deps:
+            y2 = c.submit(inc, x, key="y2", workers=[w1.address])
+        else:
+            s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
+
+        while w1.tasks["x"].who_has != {w3.address}:
+            await asyncio.sleep(0.01)
+
+    await wait_for_state("x", "memory", w1)
+    assert_story(
+        w1.story("request-dep"),
+        [("request-dep", w3.address, {"x"})],
+        # This tests that there has been no attempt to contact w2.
+        # If the assumption being tested breaks, this will fail 50% of the time.
+        strict=True,
+    )
+
+
+@gen_cluster(client=True, nthreads=[("", 1)] * 2)
+async def test_fetch_to_missing(c, s, a, b):
331+ """
332+ 1. task x is a dependency of y
333+ 2. scheduler calls handle_compute("y", who_has={"x": [b]}) on a
334+ 3. x transitions released -> fetch -> flight; a connects to b
335+ 4. b responds it's busy. x transitions flight -> fetch
336+ 5. The busy state triggers an RPC call to Scheduler.who_has
337+ 6. the scheduler responds {"x": []}, because w1 in the meantime has lost the key.
338+ 7. x is transitioned fetch -> missing
339+ """
+    x = await c.scatter({"x": 1}, workers=[b.address])
+    b.total_in_connections = 0
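+    # With total_in_connections=0, b will answer every incoming get_data
+    # request with {status: busy} (step 4 in the docstring above).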
+    # Crucially, unlike with `c.submit(inc, x, workers=[a.address])`, the scheduler
+    # doesn't keep track of acquire-replicas requests, so it won't proactively
+    # inform a when we call remove_worker later on
+    s.request_acquire_replicas(a.address, ["x"], stimulus_id="test")
+
+    # The state will flip-flop between fetch and flight every 150ms, which is
+    # the retry period for busy workers.
+    await wait_for_state("x", "fetch", a)
+    assert b.address in a.busy_workers
+
+    # Sever the connection between b and s, but not between b and a.
+    # If a tries fetching from b after this, b will keep responding {status: busy}.
+    b.periodic_callbacks["heartbeat"].stop()
+    await s.remove_worker(b.address, close=False, stimulus_id="test")
+
+    await wait_for_state("x", "missing", a)
+
+    assert_story(
+        a.story("x"),
+        [
+            ("x", "ensure-task-exists", "released"),
+            ("x", "released", "fetch", "fetch", {}),
+            ("gather-dependencies", b.address, {"x"}),
+            ("x", "fetch", "flight", "flight", {}),
+            ("request-dep", b.address, {"x"}),
+            ("busy-gather", b.address, {"x"}),
+            ("x", "flight", "fetch", "fetch", {}),
+            ("x", "fetch", "missing", "missing", {}),
+        ],
+        # There may be a round of find_missing() after this.
+        # Due to timings, there may also be multiple attempts to connect from a to b.
+        strict=False,
+    )
+
+
+@pytest.mark.skip(reason="https://github.com/dask/distributed/issues/6446")
+@gen_cluster(client=True)
+async def test_new_replica_while_all_workers_in_flight(c, s, w1, w2):
380+ """A task is stuck in 'fetch' state because all workers that hold a replica are in
381+ flight. While in this state, a new replica appears on a different worker and the
382+ scheduler informs the waiting worker through a new acquire-replicas or
383+ compute-task op.
384+
385+ In real life, this will typically happen when the Active Memory Manager replicates a
386+ key to multiple workers and some workers are much faster than others to acquire it,
387+ due to unrelated tasks being in flight, so 2 seconds later the AMM reiterates the
388+ request, passing a larger who_has.
389+
390+ Test that, when this happens, the task is immediately acquired from the new worker,
391+ without waiting for the original replica holders to get out of flight.
392+ """
+    # Make sure find_missing is not involved
+    w1.periodic_callbacks["find-missing"].stop()
+
+    async with BlockedGetData(s.address) as w3:
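+        # BlockedGetData (imported from distributed.utils_test above) is a Worker
+        # whose get_data blocks: in_get_data fires when a request arrives, and the
+        # request only completes once block_get_data is set (see its use below).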
+        x = c.submit(inc, 1, key="x", workers=[w3.address])
+        y = c.submit(inc, 2, key="y", workers=[w3.address])
+        await wait([x, y])
+        s.request_acquire_replicas(w1.address, ["x"], stimulus_id="test")
+        await w3.in_get_data.wait()
+        assert w1.tasks["x"].state == "flight"
+        s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test")
+        # This cannot progress beyond fetch because w3 is already in flight
+        await wait_for_state("y", "fetch", w1)
+
+        # Simulate that the AMM also requires that w2 acquires a replica of y.
+        # The replica lands on w2 soon afterwards, while w3->w1 comms remain
+        # blocked by unrelated transfers (x in our case).
+        w2.update_data({"y": 3}, report=True)
+        ws2 = s.workers[w2.address]
+        while ws2 not in s.tasks["y"].who_has:
+            await asyncio.sleep(0.01)
+
+        # 2 seconds later, the AMM reiterates that w1 should acquire a replica of y
+        s.request_acquire_replicas(w1.address, ["y"], stimulus_id="test")
+        await wait_for_state("y", "memory", w1)
+
+        # Finally let the other worker get out of flight
+        w3.block_get_data.set()
+        await wait_for_state("x", "memory", w1)
 
 
 @gen_cluster(client=True)