21 changes: 13 additions & 8 deletions distributed/scheduler.py
@@ -4672,9 +4672,9 @@ async def add_nanny(self, comm: Comm, address: str) -> None:
def _find_lost_dependencies(
self,
dsk: dict[Key, T_runspec],
dependencies: dict[Key, set[Key]],
keys: set[Key],
) -> set[Key]:
# FIXME: There is typically no need to walk the entire graph
lost_keys = set()
seen: set[Key] = set()
sadd = seen.add
@@ -4696,8 +4696,6 @@ def _find_lost_dependencies(
k,
d,
)
dependencies.pop(d, None)
keys.discard(k)
continue
wupdate(dsk[d].dependencies)
return lost_keys
@@ -4909,6 +4907,8 @@ async def update_graph(
start = time()
stimulus_id = stimulus_id or f"update-graph-{start}"
self._active_graph_updates += 1
evt_msg: dict[str, Any]

try:
logger.debug("Received new graph. Deserializing...")
try:
@@ -4954,20 +4954,25 @@
# has to happen in the same event loop.
# *************************************

lost_keys = self._find_lost_dependencies(dsk, dependencies, keys)

if lost_keys:
if self._find_lost_dependencies(dsk, keys):
self.report(
{
"op": "cancelled-keys",
"keys": lost_keys,
"keys": keys,
"reason": "lost dependencies",
},
client=client,
)
self.client_releases_keys(
keys=lost_keys, client=client, stimulus_id=stimulus_id
keys=keys, client=client, stimulus_id=stimulus_id
)
evt_msg = {
"action": "update-graph",
"stimulus_id": stimulus_id,
"status": "cancelled",
}
self.log_event(["scheduler", client], evt_msg)
return

before = len(self.tasks)

44 changes: 41 additions & 3 deletions distributed/tests/test_client.py
@@ -43,6 +43,7 @@
import dask
import dask.bag as db
from dask import delayed
from dask._task_spec import Task, TaskRef
from dask.tokenize import TokenizationError, tokenize
from dask.utils import get_default_shuffle_method, parse_timedelta, tmpfile

@@ -3109,6 +3110,46 @@ async def test_submit_on_cancelled_future(c, s, a, b):
await c.submit(inc, x)


@pytest.mark.parametrize("validate", [True, False])
@gen_cluster(client=True)
async def test_compute_partially_forgotten(c, s, *workers, validate):
if not validate:
s.validate = False
# (CPython impl detail)
# While it is not possible to know what the iteration order of a set will
# be, it is deterministic and only depends on the hash of the inserted
# elements. Therefore, converting the set to a list will always yield the
# same order. We're initializing the keys in this very specific order to
# ensure that the scheduler internally arranges the keys in this way

# We'll need the list to be
# ['key', 'lost_dep_of_key']
# At the time of writing, it is unclear why the lost_dep_of_key is part of
# keys but this triggers an observed error
keys = key, lost_dep_of_key = list({"foo", "bar"})
Member Author

The behavior of this test depends on the ordering of the keys. In particular, _find_lost_dependencies misbehaves depending on the order produced by this line

for k in list(keys):

With this test setup we always hit the case where the one key is not discarded as it should be.
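
A minimal standalone sketch of the ordering assumption described above (not part of the PR; it only mirrors the variable names used in the test):

# CPython detail: for a fixed set of elements, the iteration order is
# reproducible within a process because it depends only on the elements'
# hashes and how the set was built, even though user code cannot choose it.
keys = list({"foo", "bar"})
assert keys == list({"foo", "bar"})  # same order every time in this process

# The test pins which element plays the submitted key and which one is the
# lost dependency based on that order.
key, lost_dep_of_key = keys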

Copilot AI May 7, 2025

[nitpick] This tuple assignment simultaneously defining 'key', 'lost_dep_of_key' and 'keys' can be confusing. Consider splitting the assignment into two separate statements for better clarity.

Suggested change
keys = key, lost_dep_of_key = list({"foo", "bar"})
keys = list({"foo", "bar"})
key, lost_dep_of_key = keys

# Ordinarily this is not submitted as a graph but it could be if a persist
# was leading up to this
task = Task(key, inc, TaskRef(lost_dep_of_key))
# Only happens if it is submitted twice. The first submission leaves a
# zombie task around after triggering the "lost deps" exception. That zombie
# causes the second one to trigger the transition error.
res = c.get({task.key: task}, keys, sync=False)

res = c.get({task.key: task}, keys, sync=False)
assert res[1].key == lost_dep_of_key
with pytest.raises(CancelledError, match="lost dependencies"):
await res[1].result()

while (
sum(msg[1]["action"] == "update-graph" for msg in s.get_events("scheduler"))
< 2
):
await asyncio.sleep(0.01)
assert not s.get_events("transitions")
assert not s.tasks


@gen_cluster(
client=True,
nthreads=[("127.0.0.1", 1)] * 10,
@@ -8226,16 +8267,13 @@ def f(x):
[
(True, True),
(False, True),
(False, False),
],
)
def test_worker_clients_do_not_claim_ownership_of_serialize_futures(
c, do_wait, store_variable
):
da = pytest.importorskip("dask.array", exc_type=ImportError)

if not store_variable and not do_wait:
pytest.skip("This test is not making sense")
# Note: sending collections like this should be considered an anti-pattern
# but it is possible. As long as the user ensures the futures stay alive
# this is fine but the cluster will not take over this responsibility. The