Skip to content

Commit df8e7df

Browse files
committed
Refactor worker state machine
1 parent 3d73623 commit df8e7df

17 files changed

+1741
-913
lines changed

distributed/cfexecutor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ def map(self, fn, *iterables, **kwargs):
127127
raise TypeError("unexpected arguments to map(): %s" % sorted(kwargs))
128128

129129
fs = self._client.map(fn, *iterables, **self._kwargs)
130+
if isinstance(fs, list):
131+
# The result iterator below relies on `fs` being a generator so that it
132+
# can cancel the remaining futures
133+
fs = (val for val in fs)
130134

131135
# Yield must be hidden in closure so that the tasks are submitted
132136
# before the first iterator value is required.

distributed/diagnostics/plugin.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -157,24 +157,6 @@ def transition(self, key, start, finish, **kwargs):
157157
kwargs : More options passed when transitioning
158158
"""
159159

160-
def release_key(self, key, state, cause, reason, report):
161-
"""
162-
Called when the worker releases a task.
163-
164-
Parameters
165-
----------
166-
key : string
167-
state : string
168-
State of the released task.
169-
One of waiting, ready, executing, long-running, memory, error.
170-
cause : string or None
171-
Additional information on what triggered the release of the task.
172-
reason : None
173-
Not used.
174-
report : bool
175-
Whether the worker should report the released task to the scheduler.
176-
"""
177-
178160

179161
class NannyPlugin:
180162
"""Interface to extend the Nanny

distributed/diagnostics/tests/test_worker_plugin.py

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,6 @@ def transition(self, key, start, finish, **kwargs):
3434
{"key": key, "start": start, "finish": finish}
3535
)
3636

37-
def release_key(self, key, state, cause, reason, report):
38-
self.observed_notifications.append({"key": key, "state": state})
39-
4037

4138
@gen_cluster(client=True, nthreads=[])
4239
async def test_create_with_client(c, s):
@@ -107,11 +104,12 @@ async def test_create_on_construction(c, s, a, b):
107104
@gen_cluster(nthreads=[("127.0.0.1", 1)], client=True)
108105
async def test_normal_task_transitions_called(c, s, w):
109106
expected_notifications = [
110-
{"key": "task", "start": "new", "finish": "waiting"},
107+
{"key": "task", "start": "released", "finish": "waiting"},
111108
{"key": "task", "start": "waiting", "finish": "ready"},
112109
{"key": "task", "start": "ready", "finish": "executing"},
113110
{"key": "task", "start": "executing", "finish": "memory"},
114-
{"key": "task", "state": "memory"},
111+
{"key": "task", "start": "memory", "finish": "released"},
112+
{"key": "task", "start": "released", "finish": "forgotten"},
115113
]
116114

117115
plugin = MyPlugin(1, expected_notifications=expected_notifications)
@@ -127,11 +125,12 @@ def failing(x):
127125
raise Exception()
128126

129127
expected_notifications = [
130-
{"key": "task", "start": "new", "finish": "waiting"},
128+
{"key": "task", "start": "released", "finish": "waiting"},
131129
{"key": "task", "start": "waiting", "finish": "ready"},
132130
{"key": "task", "start": "ready", "finish": "executing"},
133131
{"key": "task", "start": "executing", "finish": "error"},
134-
{"key": "task", "state": "error"},
132+
{"key": "task", "start": "error", "finish": "released"},
133+
{"key": "task", "start": "released", "finish": "forgotten"},
135134
]
136135

137136
plugin = MyPlugin(1, expected_notifications=expected_notifications)
@@ -147,11 +146,12 @@ def failing(x):
147146
)
148147
async def test_superseding_task_transitions_called(c, s, w):
149148
expected_notifications = [
150-
{"key": "task", "start": "new", "finish": "waiting"},
149+
{"key": "task", "start": "released", "finish": "waiting"},
151150
{"key": "task", "start": "waiting", "finish": "constrained"},
152151
{"key": "task", "start": "constrained", "finish": "executing"},
153152
{"key": "task", "start": "executing", "finish": "memory"},
154-
{"key": "task", "state": "memory"},
153+
{"key": "task", "start": "memory", "finish": "released"},
154+
{"key": "task", "start": "released", "finish": "forgotten"},
155155
]
156156

157157
plugin = MyPlugin(1, expected_notifications=expected_notifications)
@@ -166,16 +166,18 @@ async def test_dependent_tasks(c, s, w):
166166
dsk = {"dep": 1, "task": (inc, "dep")}
167167

168168
expected_notifications = [
169-
{"key": "dep", "start": "new", "finish": "waiting"},
169+
{"key": "dep", "start": "released", "finish": "waiting"},
170170
{"key": "dep", "start": "waiting", "finish": "ready"},
171171
{"key": "dep", "start": "ready", "finish": "executing"},
172172
{"key": "dep", "start": "executing", "finish": "memory"},
173-
{"key": "task", "start": "new", "finish": "waiting"},
173+
{"key": "task", "start": "released", "finish": "waiting"},
174174
{"key": "task", "start": "waiting", "finish": "ready"},
175175
{"key": "task", "start": "ready", "finish": "executing"},
176176
{"key": "task", "start": "executing", "finish": "memory"},
177-
{"key": "dep", "state": "memory"},
178-
{"key": "task", "state": "memory"},
177+
{"key": "dep", "start": "memory", "finish": "released"},
178+
{"key": "dep", "start": "released", "finish": "forgotten"},
179+
{"key": "task", "start": "memory", "finish": "released"},
180+
{"key": "task", "start": "released", "finish": "forgotten"},
179181
]
180182

181183
plugin = MyPlugin(1, expected_notifications=expected_notifications)
@@ -218,3 +220,50 @@ class MyCustomPlugin(WorkerPlugin):
218220
await c.register_worker_plugin(MyCustomPlugin())
219221
assert len(w.plugins) == 1
220222
assert next(iter(w.plugins)).startswith("MyCustomPlugin-")
223+
224+
225+
def test_release_key_deprecated():
226+
class ReleaseKeyDeprecated(WorkerPlugin):
227+
def __init__(self):
228+
self._called = False
229+
230+
def release_key(self, key, state, cause, reason, report):
231+
# Ensure that the handler still works
232+
self._called = True
233+
assert state == "memory"
234+
assert key == "task"
235+
236+
def teardown(self, worker):
237+
assert self._called
238+
return super().teardown(worker)
239+
240+
@gen_cluster(client=True, nthreads=[("", 1)])
241+
async def test(c, s, a):
242+
243+
await c.register_worker_plugin(ReleaseKeyDeprecated())
244+
fut = await c.submit(inc, 1, key="task")
245+
assert fut == 2
246+
247+
with pytest.deprecated_call(
248+
match="The `WorkerPlugin.release_key` hook is depreacted"
249+
):
250+
test()
251+
252+
253+
def test_assert_no_warning_no_overload():
254+
"""Assert we do not receive a deprecation warning if we do not overload any
255+
methods
256+
"""
257+
258+
class Dummy(WorkerPlugin):
259+
pass
260+
261+
@gen_cluster(client=True, nthreads=[("", 1)])
262+
async def test(c, s, a):
263+
264+
await c.register_worker_plugin(Dummy())
265+
fut = await c.submit(inc, 1, key="task")
266+
assert fut == 2
267+
268+
with pytest.warns(None):
269+
test()

distributed/scheduler.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1969,6 +1969,7 @@ def __init__(
19691969
("processing", "erred"): self.transition_processing_erred,
19701970
("no-worker", "released"): self.transition_no_worker_released,
19711971
("no-worker", "waiting"): self.transition_no_worker_waiting,
1972+
("no-worker", "memory"): self.transition_no_worker_memory,
19721973
("released", "forgotten"): self.transition_released_forgotten,
19731974
("memory", "forgotten"): self.transition_memory_forgotten,
19741975
("erred", "released"): self.transition_erred_released,
@@ -2215,7 +2216,7 @@ def _transition(self, key, finish: str, *args, **kwargs):
22152216
self._transition_counter += 1
22162217
recommendations, client_msgs, worker_msgs = a
22172218
elif "released" not in start_finish:
2218-
assert not args and not kwargs
2219+
assert not args and not kwargs, start_finish
22192220
a_recs: dict
22202221
a_cmsgs: dict
22212222
a_wmsgs: dict
@@ -2614,6 +2615,42 @@ def transition_waiting_processing(self, key):
26142615
pdb.set_trace()
26152616
raise
26162617

2618+
def transition_no_worker_memory(
2619+
self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs
2620+
):
2621+
try:
2622+
ws: WorkerState = self._workers_dv[worker]
2623+
ts: TaskState = self._tasks[key]
2624+
recommendations: dict = {}
2625+
client_msgs: dict = {}
2626+
worker_msgs: dict = {}
2627+
2628+
if self._validate:
2629+
assert not ts._processing_on
2630+
assert not ts._waiting_on
2631+
assert ts._state == "no-worker"
2632+
2633+
self._unrunnable.remove(ts)
2634+
2635+
if nbytes is not None:
2636+
ts.set_nbytes(nbytes)
2637+
2638+
self.check_idle_saturated(ws)
2639+
2640+
_add_to_memory(
2641+
self, ts, ws, recommendations, client_msgs, type=type, typename=typename
2642+
)
2643+
ts.state = "memory"
2644+
2645+
return recommendations, client_msgs, worker_msgs
2646+
except Exception as e:
2647+
logger.exception(e)
2648+
if LOG_PDB:
2649+
import pdb
2650+
2651+
pdb.set_trace()
2652+
raise
2653+
26172654
def transition_waiting_memory(
26182655
self, key, nbytes=None, type=None, typename: str = None, worker=None, **kwargs
26192656
):
@@ -5353,6 +5390,8 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs):
53535390

53545391
def release_worker_data(self, comm=None, keys=None, worker=None):
53555392
parent: SchedulerState = cast(SchedulerState, self)
5393+
if worker not in parent._workers_dv:
5394+
return
53565395
ws: WorkerState = parent._workers_dv[worker]
53575396
tasks: set = {parent._tasks[k] for k in keys if k in parent._tasks}
53585397
removed_tasks: set = tasks.intersection(ws._has_what)
@@ -6610,7 +6649,7 @@ def add_keys(self, comm=None, worker=None, keys=()):
66106649
if worker not in parent._workers_dv:
66116650
return "not found"
66126651
ws: WorkerState = parent._workers_dv[worker]
6613-
superfluous_data = []
6652+
redundant_replicas = []
66146653
for key in keys:
66156654
ts: TaskState = parent._tasks.get(key)
66166655
if ts is not None and ts._state == "memory":
@@ -6619,14 +6658,15 @@ def add_keys(self, comm=None, worker=None, keys=()):
66196658
ws._has_what[ts] = None
66206659
ts._who_has.add(ws)
66216660
else:
6622-
superfluous_data.append(key)
6623-
if superfluous_data:
6661+
redundant_replicas.append(key)
6662+
6663+
if redundant_replicas:
66246664
self.worker_send(
66256665
worker,
66266666
{
6627-
"op": "superfluous-data",
6628-
"keys": superfluous_data,
6629-
"reason": f"Add keys which are not in-memory {superfluous_data}",
6667+
"op": "remove-replicas",
6668+
"keys": redundant_replicas,
6669+
"stimulus_id": f"redundant-replicas-{time()}",
66306670
},
66316671
)
66326672

@@ -7734,6 +7774,8 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->
77347774
"key": ts._key,
77357775
"priority": ts._priority,
77367776
"duration": duration,
7777+
"stimulus_id": f"compute-task-{time()}",
7778+
"who_has": {},
77377779
}
77387780
if ts._resource_restrictions:
77397781
msg["resource_restrictions"] = ts._resource_restrictions
@@ -7758,6 +7800,8 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->
77587800

77597801
if ts._annotations:
77607802
msg["annotations"] = ts._annotations
7803+
7804+
assert "stimulus_id" in msg
77617805
return msg
77627806

77637807

distributed/stealing.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,15 @@ async def move_task_confirm(self, key=None, worker=None, state=None):
230230
return
231231

232232
# Victim had already started execution, reverse stealing
233-
if state in ("memory", "executing", "long-running", None):
233+
if state in (
234+
"memory",
235+
"executing",
236+
"long-running",
237+
"released",
238+
"cancelled",
239+
"resumed",
240+
None,
241+
):
234242
self.log(("already-computing", key, victim.address, thief.address))
235243
self.scheduler.check_idle_saturated(thief)
236244
self.scheduler.check_idle_saturated(victim)
@@ -256,7 +264,7 @@ async def move_task_confirm(self, key=None, worker=None, state=None):
256264
await self.scheduler.remove_worker(thief.address)
257265
self.log(("confirm", key, victim.address, thief.address))
258266
else:
259-
raise ValueError("Unexpected task state: %s" % state)
267+
raise ValueError(f"Unexpected task state: {ts}")
260268
except Exception as e:
261269
logger.exception(e)
262270
if LOG_PDB:

0 commit comments

Comments
 (0)