diff --git a/distributed/scheduler.py b/distributed/scheduler.py index cc8f310f6d..8732e66d78 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1128,6 +1128,22 @@ def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict[str, Any]: """ return recursive_to_dict(self, exclude=exclude, members=True) + @property + def done(self) -> bool: + """Return True if all computations for this group have completed; False + otherwise. + + Notes + ----- + This property may transition from True to False, e.g. when a worker that + contained the only replica of a task in memory crashes and the task need to be + recomputed. + """ + return all( + count == 0 or state in {"memory", "erred", "released", "forgotten"} + for state, count in self.states.items() + ) + class TaskState: """A simple object holding information about a task. @@ -1803,16 +1819,13 @@ def _clear_task_state(self) -> None: def is_idle(self) -> bool: """Return True iff there are no tasks that haven't finished computing. - Unlike testing ``self.total_occupancy``, this property returns False if there are - long-running tasks, no-worker, or queued tasks (due to not having any workers). + Unlike testing ``self.total_occupancy``, this property returns False if there + are long-running tasks, no-worker, or queued tasks (due to not having any + workers). Not to be confused with ``idle``. """ - return all( - count == 0 or state in {"memory", "error", "released", "forgotten"} - for tg in self.task_groups.values() - for state, count in tg.states.items() - ) + return all(tg.done for tg in self.task_groups.values()) @property def total_occupancy(self) -> float: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 0de203a928..d58b256d9f 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2661,6 +2661,92 @@ async def test_task_groups_update_start_stop(c, s, a): assert t0 < tg.start < t1 < t2 < t3 < tg.stop < t4 +@gen_cluster(client=True) +async def test_task_group_done(c, s, a, b): + """TaskGroup.done is True iff all of its tasks are in memory, erred, released, or + forgotten state + """ + x0 = c.submit(inc, 1, key=("x", 0)) # memory + x1 = c.submit(lambda: 1 / 0, key=("x", 1)) # erred + x2 = c.submit(inc, 3, key=("x", 2)) + + x3 = c.submit(inc, 3, key=("x", 3)) # released + y = c.submit(inc, x3, key="y") + del x3 + + await wait([x0, x1, x2, y]) + del x2 # forgotten + await async_poll_for(lambda: str(("x", 2)) not in s.tasks, timeout=5) + + tg = s.task_groups["x"] + assert tg.states == { + "erred": 1, + "forgotten": 1, + "memory": 1, + "no-worker": 0, + "processing": 0, + "queued": 0, + "released": 1, + "waiting": 0, + } + assert tg.done + + +@gen_cluster(client=True) +async def test_task_group_not_done_waiting(c, s, a, b): + """TaskGroup.done is False if any of its tasks are in waiting state""" + ev = Event() + x = c.submit(ev.wait, key="x") + y0 = c.submit(lambda x: x, x, key=("y", 0)) + y1 = c.submit(inc, 1, key=("y", 1)) + await wait_for_state(y0.key, "waiting", s) + await y1 + + tg = s.task_groups["y"] + assert not tg.done + await ev.set() + + +@gen_cluster(client=True) +async def test_task_group_not_done_noworker(c, s, a, b): + """TaskGroup.done is False if any of its tasks are in no-worker state""" + x0 = c.submit(inc, 1, key=("x", 0), resources={"X": 1}) + x1 = c.submit(inc, 1, key=("x", 1)) + await wait_for_state(x0.key, "no-worker", s) + await x1 + + tg = s.task_groups["x"] + assert not tg.done + + +@gen_cluster( + client=True, + nthreads=[], + config={"distributed.scheduler.worker-saturation": 1.0}, + timeout=3, +) +async def test_task_group_not_done_queued(c, s): + """TaskGroup.done is False if any of its tasks are in queued state""" + futs = c.map(inc, range(4)) + await wait_for_state(futs[0].key, "queued", s) + tg = s.task_groups["inc"] + assert not tg.done + + +@gen_cluster(client=True) +async def test_task_group_not_done_processing(c, s, a, b): + """TaskGroup.done is False if any of its tasks are in processing state""" + ev = Event() + x0 = c.submit(ev.wait, key=("x", 0)) + x1 = c.submit(inc, 1, key=("x", 1)) + await wait_for_state(x0.key, "processing", s) + await x1 + + tg = s.task_groups["x"] + assert not tg.done + await ev.set() + + @gen_cluster(client=True) async def test_task_prefix(c, s, a, b): da = pytest.importorskip("dask.array")