
Edge and impossible transitions to memory #7205


Merged
merged 13 commits on Nov 2, 2022
117 changes: 29 additions & 88 deletions distributed/scheduler.py
@@ -1850,7 +1850,12 @@ def _transition(

start = "released"
else:
raise RuntimeError(f"Impossible transition from {start} to {finish}")
# FIXME downcast antipattern
scheduler = cast(Scheduler, self)
raise RuntimeError(
f"Impossible transition from {start} to {finish} for {key!r}: "
f"{stimulus_id=}, {args=}, {kwargs=}, story={scheduler.story(ts)}"
)

if not stimulus_id:
stimulus_id = STIMULUS_ID_UNSET
@@ -2023,50 +2028,6 @@ def transition_no_worker_processing(self, key, stimulus_id):
pdb.set_trace()
raise

def transition_no_worker_memory(
self,
key: str,
stimulus_id: str,
*,
nbytes: int | None = None,
type: bytes | None = None,
typename: str | None = None,
worker: str,
**kwargs: Any,
):
try:
ws = self.workers[worker]
ts = self.tasks[key]
recommendations: dict = {}
client_msgs: dict = {}
worker_msgs: dict = {}

if self.validate:
assert not ts.processing_on
assert not ts.waiting_on
assert ts.state == "no-worker"

self.unrunnable.remove(ts)

if nbytes is not None:
ts.set_nbytes(nbytes)

self.check_idle_saturated(ws)

_add_to_memory(
self, ts, ws, recommendations, client_msgs, type=type, typename=typename
)
ts.state = "memory"

return recommendations, client_msgs, worker_msgs
except Exception as e:
logger.exception(e)
if LOG_PDB:
import pdb

pdb.set_trace()
raise

def decide_worker_rootish_queuing_disabled(
self, ts: TaskState
) -> WorkerState | None:
@@ -2292,35 +2253,23 @@ def transition_waiting_memory(
worker: str,
**kwargs: Any,
):
"""This transition exclusively happens in a race condition where the scheduler
believes that the only copy of a dependency task has just been lost, so it
transitions all dependents back to waiting, but actually a replica has already
been acquired by a worker computing the dependency - the scheduler just doesn't
know yet - and the execution finishes before the cancellation message from the
scheduler has a chance to reach the worker. Shortly, the cancellation request
will reach the worker, thus deleting the data from memory.
"""
try:
ws: WorkerState = self.workers[worker]
ts: TaskState = self.tasks[key]
recommendations: dict = {}
client_msgs: dict = {}
worker_msgs: dict = {}
ts = self.tasks[key]

if self.validate:
assert not ts.processing_on
assert ts.waiting_on
assert ts.state == "waiting"

ts.waiting_on.clear()

if nbytes is not None:
ts.set_nbytes(nbytes)

self.check_idle_saturated(ws)

_add_to_memory(
self, ts, ws, recommendations, client_msgs, type=type, typename=typename
)

if self.validate:
assert not ts.processing_on
assert not ts.waiting_on
assert ts.who_has

return recommendations, client_msgs, worker_msgs
return {}, {}, {}
except Exception as e:
logger.exception(e)
if LOG_PDB:
@@ -2365,21 +2314,15 @@ def transition_processing_memory(
if ws is None:
return {key: "released"}, {}, {}

if ws != ts.processing_on: # someone else has this task
logger.info(
"Unexpected worker completed task. Expected: %s, Got: %s, Key: %s",
ts.processing_on,
ws,
key,
)
if ws != ts.processing_on: # pragma: nocover
assert ts.processing_on
worker_msgs[ts.processing_on.address] = [
{
"op": "cancel-compute",
"key": key,
"stimulus_id": stimulus_id,
}
]
# FIXME downcast antipattern
scheduler = cast(Scheduler, self)
raise RuntimeError(
f"Task {ts.key!r} transitioned from processing to memory on worker "
f"{ws}, while it was expected from {ts.processing_on}. This should "
f"be impossible. {stimulus_id=}, story={scheduler.story(ts)}"
)

#############################
# Update Timing Information #
@@ -2650,7 +2593,7 @@ def transition_processing_released(self, key: str, stimulus_id: str):
}
]

_propagage_released(self, ts, recommendations)
_propagate_released(self, ts, recommendations)
return recommendations, {}, worker_msgs
except Exception as e:
logger.exception(e)
@@ -2874,7 +2817,7 @@ def transition_queued_released(self, key, stimulus_id):

self.queued.remove(ts)

_propagage_released(self, ts, recommendations)
_propagate_released(self, ts, recommendations)
return recommendations, client_msgs, worker_msgs
except Exception as e:
logger.exception(e)
@@ -3027,7 +2970,6 @@ def transition_released_forgotten(self, key, stimulus_id):
("processing", "erred"): transition_processing_erred,
("no-worker", "released"): transition_no_worker_released,
("no-worker", "processing"): transition_no_worker_processing,
("no-worker", "memory"): transition_no_worker_memory,
("released", "forgotten"): transition_released_forgotten,
("memory", "forgotten"): transition_memory_forgotten,
("erred", "released"): transition_erred_released,
@@ -7965,7 +7907,7 @@ def _add_to_memory(
)


def _propagage_released(
def _propagate_released(
state: SchedulerState,
ts: TaskState,
recommendations: Recs,
@@ -8319,10 +8261,9 @@ def heartbeat_interval(n: int) -> float:


def _task_slots_available(ws: WorkerState, saturation_factor: float) -> int:
"Number of tasks that can be sent to this worker without oversaturating it"
"""Number of tasks that can be sent to this worker without oversaturating it"""
assert not math.isinf(saturation_factor)
nthreads = ws.nthreads
return max(math.ceil(saturation_factor * nthreads), 1) - (
return max(math.ceil(saturation_factor * ws.nthreads), 1) - (
len(ws.processing) - len(ws.long_running)
)

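A quick worked example of the _task_slots_available formula in the hunk above, using arbitrary illustrative numbers that are not taken from this diff:

import math

# Hypothetical worker: 4 threads, oversaturation factor 1.5,
# 5 tasks currently processing, 1 of which is long-running.
saturation_factor = 1.5
nthreads = 4          # ws.nthreads
processing = 5        # len(ws.processing)
long_running = 1      # len(ws.long_running)

slots = max(math.ceil(saturation_factor * nthreads), 1) - (processing - long_running)
assert slots == 2     # ceil(6.0) = 6 target slots, minus 4 occupied threads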
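To connect the race described in the transition_waiting_memory docstring with the table changes above, here is a heavily simplified sketch (illustrative only, not the actual Scheduler code) of the dispatch pattern: a waiting->memory event can legitimately arrive from a stale task-finished message and is now ignored, while pairs with no handler, such as the removed no-worker->memory, hit the "Impossible transition" error.

from __future__ import annotations
from typing import Callable

Recs = dict
Msgs = dict

def _waiting_memory(key: str) -> tuple[Recs, Msgs, Msgs]:
    # Stale completion: the scheduler still thinks the task is waiting, so the
    # event is dropped; the queued free-keys will clean the worker up later.
    return {}, {}, {}

# Simplified stand-in for the (start, finish) -> handler table in the diff.
TRANSITIONS: dict[tuple[str, str], Callable[[str], tuple[Recs, Msgs, Msgs]]] = {
    ("waiting", "memory"): _waiting_memory,
}

def transition(key: str, start: str, finish: str) -> tuple[Recs, Msgs, Msgs]:
    func = TRANSITIONS.get((start, finish))
    if func is None:
        # e.g. ("no-worker", "memory"), whose handler this diff removes
        raise RuntimeError(f"Impossible transition from {start} to {finish} for {key!r}")
    return func(key)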
42 changes: 42 additions & 0 deletions distributed/tests/test_scheduler.py
@@ -45,11 +45,13 @@
NO_AMM,
BlockedGatherDep,
BrokenComm,
assert_story,
async_wait_for,
captured_logger,
cluster,
dec,
div,
freeze_batched_send,
freeze_data_fetching,
gen_cluster,
gen_test,
@@ -4099,3 +4101,43 @@ async def test_count_task_prefix(c, s, a, b):

assert s.task_prefixes["inc"].state_counts["memory"] == 20
assert s.task_prefixes["inc"].state_counts["erred"] == 0


@gen_cluster(client=True)
async def test_transition_waiting_memory(c, s, a, b):
"""Test race condition where a task transitions to memory while its state on the
scheduler is waiting:

1. worker a finishes x
2. y transitions to processing and is assigned to worker b
3. b fetches x and sends an add_keys message to the scheduler
4. In the meantime, a dies and causes x to be scheduled back to released/waiting.
5. Scheduler queues up a free-keys intended for b to cancel both x and y
6. Before free-keys arrives to b, the worker runs and completes y, sending a
finished-task message to the scheduler
7. {op: add-keys, keys=[x]} from b finally arrives to the scheduler. This triggers
a {op: remove-replicas, keys=[x]} message from the scheduler to worker b, because
add-keys when the task state is not memory triggers a cleanup of redundant
replicas (see Scheduler.add_keys) - in this, add-keys differs from task-finished!
8. {op: task-finished, key=y} from b arrives to the scheduler and it is ignored.
"""
x = c.submit(inc, 1, key="x", workers=[a.address])
y = c.submit(inc, x, key="y", workers=[b.address])
await wait_for_state("x", "memory", b, interval=0)
# Note interval=0 above. It means that x has just landed on b this instant and the
# scheduler doesn't know yet.
assert b.state.tasks["y"].state == "executing"
assert s.tasks["x"].who_has == {s.workers[a.address]}

with freeze_batched_send(b.batched_stream):
with freeze_batched_send(s.stream_comms[b.address]):
await s.remove_worker(a.address, stimulus_id="remove_a")
assert s.tasks["x"].state == "no-worker"
assert s.tasks["y"].state == "waiting"
await wait_for_state("y", "memory", b)

await async_wait_for(lambda: not b.state.tasks, timeout=5)

assert s.tasks["x"].state == "no-worker"
assert s.tasks["y"].state == "waiting"
assert_story(s.story("y"), [("y", "waiting", "waiting", {})])
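For readers unfamiliar with assert_story, the final assertion above checks the scheduler's transition log by prefix matching. To my understanding (an assumption about the log format, not something stated in this diff), Scheduler.story returns entries shaped roughly like (key, start, finish, recommendations, stimulus_id, timestamp), and each expected entry only needs to match the leading fields. A minimal sketch of that idea:

# Illustrative values only; the tuple layout is an assumption about the
# scheduler's transition log, not something asserted by this diff.
sample_story = [
    ("y", "waiting", "waiting", {}, "remove_a", 1667400000.0),
]
expected = ("y", "waiting", "waiting", {})

def prefix_matches(actual: tuple, expect: tuple) -> bool:
    # Every field given in the expectation must equal the corresponding
    # leading field of the actual log entry.
    return actual[: len(expect)] == expect

assert any(prefix_matches(entry, expected) for entry in sample_story)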