6 changes: 2 additions & 4 deletions distributed/diagnostics/graph_layout.py
@@ -46,17 +46,15 @@
priority=priority,
)

def update_graph(
self, scheduler, *, dependencies=None, priority=None, tasks=None, **kwargs
):
def update_graph(self, scheduler, *, priority=None, tasks=None, **kwargs):
stack = sorted(
tasks, key=lambda k: TupleComparable(priority.get(k, 0)), reverse=True
)
while stack:
key = stack.pop()
if key in self.x or key not in scheduler.tasks:
continue
deps = dependencies.get(key, ())
deps = [ts.key for ts in scheduler.tasks[key].dependencies]

Codecov / codecov/patch — check warning on line 57 (distributed/diagnostics/graph_layout.py#L57): added line not covered by tests.
if deps:
if not all(dep in self.y for dep in deps):
stack.append(key)
3 changes: 0 additions & 3 deletions distributed/diagnostics/plugin.py
@@ -100,7 +100,6 @@ def update_graph(
tasks: list[Key],
annotations: dict[str, dict[Key, Any]],
priority: dict[Key, tuple[int | float, ...]],
dependencies: dict[Key, set[Key]],
stimulus_id: str,
**kwargs: Any,
) -> None:
@@ -128,8 +127,6 @@ def update_graph(
}
priority:
Task calculated priorities as assigned to the tasks.
dependencies:
A mapping that maps a key to its dependencies.
stimulus_id:
ID of the stimulus causing the graph update
**kwargs:
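Dropping the `dependencies` keyword is a breaking change for third-party scheduler plugins that overrode `update_graph` with an explicit `dependencies` parameter. A minimal migration sketch (the plugin class below is illustrative, not part of this PR): accept the remaining keyword arguments and, if dependency information is still needed, rebuild it from `scheduler.tasks`, which is exactly what the built-in GraphLayout plugin above now does.

from distributed.diagnostics.plugin import SchedulerPlugin

class DependencyAwarePlugin(SchedulerPlugin):
    # Illustrative plugin written against the new hook signature.
    def update_graph(self, scheduler, *, tasks, priority, stimulus_id, **kwargs):
        # The dependencies mapping is no longer passed in; derive it on demand
        # from the scheduler's TaskState objects for the touched tasks.
        deps = {
            key: {dts.key for dts in scheduler.tasks[key].dependencies}
            for key in tasks
            if key in scheduler.tasks
        }
        # ... use `deps` for diagnostics or telemetry ...

Such a plugin is registered as before, e.g. via `scheduler.add_plugin(DependencyAwarePlugin())`.
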
7 changes: 0 additions & 7 deletions distributed/diagnostics/tests/test_scheduler_plugin.py
@@ -492,7 +492,6 @@ def update_graph( # type: ignore
tasks,
annotations,
priority,
dependencies,
stimulus_id,
**kwargs,
) -> None:
@@ -505,7 +504,6 @@ def update_graph( # type: ignore
assert annotations == {}
assert len(priority) == 1
assert isinstance(priority["foo"], tuple)
assert dependencies == {"foo": set()}
assert stimulus_id is not None
self.success = True

@@ -534,7 +532,6 @@ def update_graph( # type: ignore
tasks,
annotations,
priority,
dependencies,
stimulus_id,
**kwargs,
) -> None:
@@ -552,10 +549,6 @@ def update_graph( # type: ignore
}
assert len(priority) == len(tasks), priority
assert priority["f2"][0] == -13
for k in keys:
assert k in dependencies
assert dependencies["f1"] == set()
assert dependencies["sum"] == {"f1", "f3"}
assert stimulus_id is not None

self.success = True
105 changes: 25 additions & 80 deletions distributed/scheduler.py
@@ -62,9 +62,8 @@
from tornado.ioloop import IOLoop

import dask
import dask.utils
from dask._expr import LLGExpr
from dask._task_spec import DependenciesMapping, GraphNode, convert_legacy_graph
from dask._task_spec import GraphNode, convert_legacy_graph
from dask.core import istask, validate_key
from dask.typing import Key, no_default
from dask.utils import (
@@ -4705,7 +4704,6 @@ def _create_taskstate_from_graph(
*,
start: float,
dsk: dict[Key, T_runspec],
dependencies: dict,
keys: set[Key],
ordered: dict[Key, int],
client: str,
@@ -4744,14 +4742,12 @@ def _create_taskstate_from_graph(
# annotations.
computation.annotations.update(global_annotations)
(
runnable,
touched_tasks,
new_tasks,
colliding_task_count,
) = self._generate_taskstates(
keys=keys,
dsk=dsk,
dependencies=dependencies,
computation=computation,
)

@@ -4773,7 +4769,7 @@ def _create_taskstate_from_graph(
user_priority=user_priority,
fifo_timeout=fifo_timeout,
start=start,
tasks=runnable,
tasks=touched_tasks,
)

self.client_desires_keys(keys=keys, client=client)
@@ -4787,19 +4783,17 @@ def _create_taskstate_from_graph(

# Compute recommendations
recommendations: Recs = {}
priority = dict()
for ts in sorted(
runnable,
filter(
lambda ts: ts.state == "released",
map(self.tasks.__getitem__, keys),
),
key=operator.attrgetter("priority"),
reverse=True,
):
assert ts.priority # mypy
priority[ts.key] = ts.priority
assert ts.run_spec
if ts.state == "released":
recommendations[ts.key] = "waiting"
recommendations[ts.key] = "waiting"

for ts in runnable:
for ts in touched_tasks:
for dts in ts.dependencies:
if dts.exception_blame:
ts.exception_blame = dts.exception_blame
@@ -4820,7 +4814,7 @@ def _create_taskstate_from_graph(
# TaskState may have also been created by client_desires_keys or scatter,
# and only later gained a run_spec.
span_annotations = spans_ext.observe_tasks(
runnable, span_metadata=span_metadata, code=code
touched_tasks, span_metadata=span_metadata, code=code
)
# In case of TaskGroup collision, spans may have changed
# FIXME: Is this used anywhere besides tests?
@@ -4829,16 +4823,17 @@ def _create_taskstate_from_graph(
else:
annotations_for_plugin.pop("span", None)

tasks_for_plugin = [ts.key for ts in touched_tasks]
priorities_for_plugin = {ts.key: ts.priority for ts in touched_tasks}
for plugin in list(self.plugins.values()):
try:
plugin.update_graph(
self,
client=client,
tasks=[ts.key for ts in touched_tasks],
tasks=tasks_for_plugin,
keys=keys,
dependencies=dependencies,
annotations=dict(annotations_for_plugin),
priority=priority,
annotations=annotations_for_plugin,
priority=priorities_for_plugin,
stimulus_id=stimulus_id,
)
except Exception as e:
@@ -4852,42 +4847,6 @@ def _create_taskstate_from_graph(

return metrics

def _remove_done_tasks_from_dsk(
self,
dsk: dict[Key, T_runspec],
dependencies: dict[Key, set[Key]],
) -> None:
# Avoid computation that is already finished
done = set() # tasks that are already done
for k, v in dependencies.items():
if v and k in self.tasks:
ts = self.tasks[k]
if ts.state in ("memory", "erred"):
done.add(k)
if done:
dependents = dask.core.reverse_dict(dependencies)
stack = list(done)
while stack: # remove unnecessary dependencies
key = stack.pop()
try:
deps = dependencies[key]
except KeyError:
deps = {ts.key for ts in self.tasks[key].dependencies}
for dep in deps:
if dep in dependents:
child_deps = dependents[dep]
elif dep in self.tasks:
child_deps = {ts.key for ts in self.tasks[key].dependencies}
else:
child_deps = set()
if all(d in done for d in child_deps):
if dep in self.tasks and dep not in done:
done.add(dep)
stack.append(dep)
for anc in done:
dsk.pop(anc, None)
dependencies.pop(anc, None)

@log_errors
async def update_graph(
self,
@@ -4924,7 +4883,6 @@ async def update_graph(
raise RuntimeError(textwrap.dedent(msg)) from e
(
dsk,
dependencies,
annotations_by_type,
) = await offload(
_materialize_graph,
@@ -4976,12 +4934,9 @@ async def update_graph(

before = len(self.tasks)

self._remove_done_tasks_from_dsk(dsk, dependencies)

metrics = self._create_taskstate_from_graph(
dsk=dsk,
client=client,
dependencies=dependencies,
keys=set(keys),
ordered=internal_priority or {},
submitting_task=submitting_task,
@@ -5045,17 +5000,16 @@ def _generate_taskstates(
self,
keys: set[Key],
dsk: dict[Key, T_runspec],
dependencies: dict[Key, set[Key]],
computation: Computation,
) -> tuple:
# Get or create task states
runnable = list()
new_tasks = []
stack = list(keys)
touched_keys = set()
touched_tasks = []
tgs_with_bad_run_spec = set()
colliding_task_count = 0
collisions = set()
while stack:
k = stack.pop()
if k in touched_keys:
@@ -5078,18 +5032,13 @@ def _generate_taskstates(
elif k in dsk:
# Check dependency names.
deps_lhs = {dts.key for dts in ts.dependencies}
deps_rhs = dependencies[k]
deps_rhs = dsk[k].dependencies

# FIXME It would be a really healthy idea to change this to a hard
# failure. However, this is not possible at the moment because of
# https://github.com/dask/dask/issues/9888
if deps_lhs != deps_rhs:
# Retain old run_spec and dependencies; rerun them if necessary.
# This sweeps the issue of collision under the carpet as long as the
# old and new task produce the same output - such as in
# dask/dask#9888.
dependencies[k] = deps_lhs

collisions.add(k)
colliding_task_count += 1
if ts.group not in tgs_with_bad_run_spec:
tgs_with_bad_run_spec.add(ts.group)
@@ -5120,18 +5069,17 @@ def _generate_taskstates(
"two consecutive calls to `update_graph`."
)

if ts.run_spec:
runnable.append(ts)
touched_keys.add(k)
touched_tasks.append(ts)
stack.extend(dependencies.get(k, ()))
if tspec := dsk.get(k, ()):
stack.extend(tspec.dependencies)

# Add dependencies
for key, deps in dependencies.items():
for key, tspec in dsk.items():
ts = self.tasks.get(key)
if ts is None or ts.dependencies:
if ts is None or key in collisions:
continue
for dep in deps:
for dep in tspec.dependencies:
dts = self.tasks[dep]
ts.add_dependency(dts)

@@ -5141,7 +5089,7 @@ def _generate_taskstates(
len(touched_tasks),
len(keys),
)
return runnable, touched_tasks, new_tasks, colliding_task_count
return touched_tasks, new_tasks, colliding_task_count

def _apply_annotations(
self,
@@ -9509,7 +9457,7 @@ def transition(
def _materialize_graph(
expr: Expr,
validate: bool,
) -> tuple[dict[Key, T_runspec], dict[Key, set[Key]], dict[str, dict[Key, Any]]]:
) -> tuple[dict[Key, T_runspec], dict[str, dict[Key, Any]]]:
dsk: dict = expr.__dask_graph__()
if validate:
for k in dsk:
@@ -9520,10 +9468,7 @@ def _materialize_graph(
annotations_by_type[annotations_type].update(value)

dsk2 = convert_legacy_graph(dsk)
# FIXME: There should be no need to fully materialize and copy this but some
# sections in the scheduler are mutating it.
dependencies = {k: set(v) for k, v in DependenciesMapping(dsk2).items()}
return dsk2, dependencies, annotations_by_type
return dsk2, annotations_by_type


def _cull(dsk: dict[Key, GraphNode], keys: set[Key]) -> dict[Key, GraphNode]:
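On the scheduler side, dependency information now comes straight from the task-spec objects produced by `convert_legacy_graph`: `_generate_taskstates` reads `dsk[k].dependencies` instead of consulting the separately materialized `DependenciesMapping`, and `_materialize_graph` no longer builds or copies that mapping. A small sketch of what those objects expose, assuming the `dask._task_spec` helpers imported in this diff (the toy graph is illustrative; exact GraphNode subclasses may vary across dask versions):

from dask._task_spec import convert_legacy_graph

# Legacy-style graph: "y" depends on "x".
dsk = {"x": 1, "y": (lambda a: a + 1, "x")}

dsk2 = convert_legacy_graph(dsk)  # maps each key to a GraphNode
for key, node in dsk2.items():
    # Each GraphNode carries its own dependency set, so no separate
    # key -> dependencies mapping has to be materialized and kept in sync.
    print(key, set(node.dependencies))
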
11 changes: 8 additions & 3 deletions distributed/tests/test_client.py
@@ -3111,8 +3111,9 @@ async def test_submit_on_cancelled_future(c, s, a, b):


@pytest.mark.parametrize("validate", [True, False])
@pytest.mark.parametrize("swap_keys", [True, False])
@gen_cluster(client=True)
async def test_compute_partially_forgotten(c, s, *workers, validate):
async def test_compute_partially_forgotten(c, s, *workers, validate, swap_keys):
if not validate:
s.validate = False
# (CPython impl detail)
@@ -3127,6 +3128,8 @@ async def test_compute_partially_forgotten(c, s, *workers, validate):
# At the time of writing, it is unclear why the lost_dep_of_key is part of
# keys but this triggers an observed error
keys = key, lost_dep_of_key = list({"foo", "bar"})
if swap_keys:
keys = lost_dep_of_key, key = [key, lost_dep_of_key]

# Ordinarily this is not submitted as a graph but it could be if a persist
# was leading up to this
@@ -3135,13 +3138,15 @@ async def test_compute_partially_forgotten(c, s, *workers, validate):
# zombie task around after triggering the "lost deps" exception. That zombie
# causes the second one to trigger the transition error.
res = c.get({task.key: task}, keys, sync=False)

res = c.get({task.key: task}, keys, sync=False)
assert res[1].key == lost_dep_of_key
with pytest.raises(CancelledError, match="lost dependencies"):
await res[1].result()
with pytest.raises(CancelledError, match="lost dependencies"):
await res[0].result()

# No transition errors
while (
# This waits until update-graph is truly finished
len([msg[1]["action"] == "update-graph" for msg in s.get_events("scheduler")])
< 2
):