Skip to content

Commit

Permalink
Merge branch 'main' into ensure_communicating
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 6, 2022
2 parents ef56014 + 70e5c90 commit 361b734
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 24 deletions.
12 changes: 6 additions & 6 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2202,12 +2202,6 @@ def transition_memory_released(self, key, stimulus_id, safe: bool = False):
worker_msgs,
) # don't try to recreate

for dts in ts.waiters:
if dts.state in ("no-worker", "processing"):
recommendations[dts.key] = "waiting"
elif dts.state == "waiting":
dts.waiting_on.add(ts)

# XXX factor this out?
worker_msg = {
"op": "free-keys",
Expand All @@ -2232,6 +2226,12 @@ def transition_memory_released(self, key, stimulus_id, safe: bool = False):
elif ts.who_wants or ts.waiters:
recommendations[key] = "waiting"

for dts in ts.waiters:
if dts.state in ("no-worker", "processing"):
recommendations[dts.key] = "waiting"
elif dts.state == "waiting":
dts.waiting_on.add(ts)

if self.validate:
assert not ts.waiting_on

Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ async def test_worker_stream_died_during_comm(c, s, a, b):
assert any("receive-dep-failed" in msg for msg in b.log)


@gen_cluster(client=True, nthreads=[("", 1)], timeout=4)
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_flight_to_executing_via_cancelled_resumed(c, s, b):

block_get_data = asyncio.Lock()
Expand Down
5 changes: 4 additions & 1 deletion distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3161,7 +3161,10 @@ async def test_task_flight_compute_oserror(c, s, a, b):
# inc is lost and needs to be recomputed. Therefore, sum is released
("free-keys", ("f1",)),
("f1", "release-key"),
("f1", "waiting", "released", "released", {"f1": "forgotten"}),
# The recommendations here are hard to predict. Whatever key is
# currently scheduled to be fetched, if any, will be recommended to be
# released.
("f1", "waiting", "released", "released", lambda msg: msg["f1"] == "forgotten"),
("f1", "released", "forgotten", "forgotten", {}),
# Now, we actually compute the task *once*. This must not cycle back
("f1", "compute-task"),
Expand Down
66 changes: 66 additions & 0 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
from itertools import chain

import pytest

from distributed.protocol.serialize import Serialize
from distributed.utils import recursive_to_dict
from distributed.utils_test import assert_story, gen_cluster, inc
from distributed.worker_state_machine import (
ExecuteFailureEvent,
ExecuteSuccessEvent,
Expand All @@ -19,6 +21,11 @@
)


async def wait_for_state(key, state, dask_worker):
while key not in dask_worker.tasks or dask_worker.tasks[key].state != state:
await asyncio.sleep(0.005)


def test_TaskState_get_nbytes():
assert TaskState("x", nbytes=123).get_nbytes() == 123
# Default to distributed.scheduler.default-data-size
Expand Down Expand Up @@ -236,3 +243,62 @@ def test_executefailure_to_dict():
assert ev3.traceback is None
assert ev3.exception_text == "exc text"
assert ev3.traceback_text == "tb text"


@gen_cluster(client=True)
async def test_fetch_to_compute(c, s, a, b):
# Block ensure_communicating to ensure we indeed know that the task is in
# fetch and doesn't leave it accidentally
old_out_connections, b.total_out_connections = b.total_out_connections, 0
old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0

f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)
f2 = c.submit(inc, f1, workers=[b.address], key="f2")

await wait_for_state(f1.key, "fetch", b)
await a.close()

b.total_out_connections = old_out_connections
b.comm_threshold_bytes = old_comm_threshold

await f2

assert_story(
b.log,
# FIXME: This log should be replaced with an
# StateMachineEvent/Instruction log
[
(f2.key, "compute-task"),
# This is a "please fetch" request. We don't have anything like
# this, yet. We don't see the request-dep signal in here because we
# do not wait for the key to be actually scheduled
(f1.key, "ensure-task-exists", "released"),
# After the worker failed, we're instructed to forget f2 before
# something new comes in
("free-keys", (f2.key,)),
(f1.key, "compute-task"),
(f1.key, "put-in-memory"),
(f2.key, "compute-task"),
],
)


@gen_cluster(client=True)
async def test_fetch_via_amm_to_compute(c, s, a, b):
# Block ensure_communicating to ensure we indeed know that the task is in
# fetch and doesn't leave it accidentally
old_out_connections, b.total_out_connections = b.total_out_connections, 0
old_comm_threshold, b.comm_threshold_bytes = b.comm_threshold_bytes, 0

f1 = c.submit(inc, 1, workers=[a.address], key="f1", allow_other_workers=True)

await f1
s.request_acquire_replicas(b.address, [f1.key], stimulus_id="test")

await wait_for_state(f1.key, "fetch", b)
await a.close()

b.total_out_connections = old_out_connections
b.comm_threshold_bytes = old_comm_threshold

await f1
46 changes: 30 additions & 16 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@ def __init__(
self._transitions_table = {
("cancelled", "fetch"): self.transition_cancelled_fetch,
("cancelled", "released"): self.transition_cancelled_released,
("cancelled", "missing"): self.transition_cancelled_released,
("cancelled", "waiting"): self.transition_cancelled_waiting,
("cancelled", "forgotten"): self.transition_cancelled_forgotten,
("cancelled", "memory"): self.transition_cancelled_memory,
Expand All @@ -628,7 +629,7 @@ def __init__(
("executing", "released"): self.transition_executing_released,
("executing", "rescheduled"): self.transition_executing_rescheduled,
("fetch", "flight"): self.transition_fetch_flight,
("fetch", "missing"): self.transition_fetch_missing,
("fetch", "missing"): self.transition_generic_missing,
("fetch", "released"): self.transition_generic_released,
("flight", "error"): self.transition_flight_error,
("flight", "fetch"): self.transition_flight_fetch,
Expand All @@ -648,6 +649,7 @@ def __init__(
("ready", "released"): self.transition_generic_released,
("released", "error"): self.transition_generic_error,
("released", "fetch"): self.transition_released_fetch,
("released", "missing"): self.transition_released_fetch,
("released", "forgotten"): self.transition_released_forgotten,
("released", "memory"): self.transition_released_memory,
("released", "waiting"): self.transition_released_waiting,
Expand Down Expand Up @@ -2028,6 +2030,7 @@ def transition_missing_fetch(
if self.validate:
assert ts.state == "missing"
assert ts.priority is not None
assert ts.who_has

self._missing_dep_flight.discard(ts)
ts.state = "fetch"
Expand All @@ -2053,7 +2056,7 @@ def transition_flight_missing(
ts.done = False
return {}, []

def transition_fetch_missing(
def transition_generic_missing(
self, ts: TaskState, *, stimulus_id: str
) -> RecsInstrs:
ts.state = "missing"
Expand All @@ -2067,6 +2070,8 @@ def transition_released_fetch(
if self.validate:
assert ts.state == "released"
assert ts.priority is not None
if not ts.who_has:
return {ts: "missing"}, []
ts.state = "fetch"
ts.done = False
return self._add_to_data_needed(ts)
Expand Down Expand Up @@ -2654,18 +2659,25 @@ def _transition(
recs, instructions = self._transition(
ts, "released", stimulus_id=stimulus_id
)
v = recs.get(ts, (finish, *args))
v_state: str
v_args: list | tuple
if isinstance(v, tuple):
v_state, *v_args = v
else:
v_state, v_args = v, ()
b_recs, b_instructions = self._transition(
ts, v_state, *v_args, stimulus_id=stimulus_id
while v := recs.pop(ts, None):
if isinstance(v, tuple):
v_state, *v_args = v
else:
v_state, v_args = v, ()
if v_state == "forgotten":
# We do not want to forget. The purpose of this
# transition path is to get to `finish`
continue
recs, instructions = merge_recs_instructions(
(recs, instructions),
self._transition(ts, v_state, *v_args, stimulus_id=stimulus_id),
)
recs, instructions = merge_recs_instructions(
(recs, instructions),
self._transition(ts, finish, *args, stimulus_id=stimulus_id),
)
recs.update(b_recs)
instructions += b_instructions
except InvalidTransition:
self.log_event(
"invalid-worker-transition",
Expand Down Expand Up @@ -3241,7 +3253,12 @@ async def gather_dep(
for d in has_what:
ts = self.tasks[d]
ts.who_has.remove(worker)
if not ts.who_has and ts.state not in ("released", "memory"):
if not ts.who_has and ts.state in (
"fetch",
"flight",
"resumed",
"cancelled",
):
recommendations[ts] = "missing"
self.log.append(
("missing-who-has", worker, ts.key, stimulus_id, time())
Expand Down Expand Up @@ -3302,10 +3319,7 @@ async def gather_dep(
"stimulus_id": stimulus_id,
}
)
if ts.who_has:
recommendations[ts] = "fetch"
elif ts.state not in ("released", "memory"):
recommendations[ts] = "missing"
recommendations[ts] = "fetch"
del data, response
self.transitions(recommendations, stimulus_id=stimulus_id)

Expand Down

0 comments on commit 361b734

Please sign in to comment.