Skip to content

Commit

Permalink
Refactor Scheduler.is_idle
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jun 4, 2023
1 parent 57639c1 commit 669429d
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 7 deletions.
27 changes: 20 additions & 7 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,22 @@ def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict[str, Any]:
"""
return recursive_to_dict(self, exclude=exclude, members=True)

@property
def done(self) -> bool:
"""Return True if all computations for this group have completed; False
otherwise.
Notes
-----
This property may transition from True to False, e.g. when a worker that
contained the only replica of a task in memory crashes and the task need to be
recomputed.
"""
return all(
count == 0 or state in {"memory", "erred", "released", "forgotten"}
for state, count in self.states.items()
)


class TaskState:
"""A simple object holding information about a task.
Expand Down Expand Up @@ -1803,16 +1819,13 @@ def _clear_task_state(self) -> None:
def is_idle(self) -> bool:
"""Return True iff there are no tasks that haven't finished computing.
Unlike testing ``self.total_occupancy``, this property returns False if there are
long-running tasks, no-worker, or queued tasks (due to not having any workers).
Unlike testing ``self.total_occupancy``, this property returns False if there
are long-running tasks, no-worker, or queued tasks (due to not having any
workers).
Not to be confused with ``idle``.
"""
return all(
count == 0 or state in {"memory", "error", "released", "forgotten"}
for tg in self.task_groups.values()
for state, count in tg.states.items()
)
return all(tg.done for tg in self.task_groups.values())

@property
def total_occupancy(self) -> float:
Expand Down
86 changes: 86 additions & 0 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2661,6 +2661,92 @@ async def test_task_groups_update_start_stop(c, s, a):
assert t0 < tg.start < t1 < t2 < t3 < tg.stop < t4


@gen_cluster(client=True)
async def test_task_group_done(c, s, a, b):
"""TaskGroup.done is True iff all of its tasks are in memory, erred, released, or
forgotten state
"""
x0 = c.submit(inc, 1, key=("x", 0)) # memory
x1 = c.submit(lambda: 1 / 0, key=("x", 1)) # erred
x2 = c.submit(inc, 3, key=("x", 2))

x3 = c.submit(inc, 3, key=("x", 3)) # released
y = c.submit(inc, x3, key="y")
del x3

await wait([x0, x1, x2, y])
del x2 # forgotten
await async_poll_for(lambda: str(("x", 2)) not in s.tasks, timeout=5)

tg = s.task_groups["x"]
assert tg.states == {
"erred": 1,
"forgotten": 1,
"memory": 1,
"no-worker": 0,
"processing": 0,
"queued": 0,
"released": 1,
"waiting": 0,
}
assert tg.done


@gen_cluster(client=True)
async def test_task_group_not_done_waiting(c, s, a, b):
"""TaskGroup.done is False if any of its tasks are in waiting state"""
ev = Event()
x = c.submit(ev.wait, key="x")
y0 = c.submit(lambda x: x, x, key=("y", 0))
y1 = c.submit(inc, 1, key=("y", 1))
await wait_for_state(y0.key, "waiting", s)
await y1

tg = s.task_groups["y"]
assert not tg.done
await ev.set()


@gen_cluster(client=True)
async def test_task_group_not_done_noworker(c, s, a, b):
"""TaskGroup.done is False if any of its tasks are in no-worker state"""
x0 = c.submit(inc, 1, key=("x", 0), resources={"X": 1})
x1 = c.submit(inc, 1, key=("x", 1))
await wait_for_state(x0.key, "no-worker", s)
await x1

tg = s.task_groups["x"]
assert not tg.done


@gen_cluster(
client=True,
nthreads=[],
config={"distributed.scheduler.worker-saturation": 1.0},
timeout=3,
)
async def test_task_group_not_done_queued(c, s):
"""TaskGroup.done is False if any of its tasks are in queued state"""
futs = c.map(inc, range(4))
await wait_for_state(futs[0].key, "queued", s)
tg = s.task_groups["inc"]
assert not tg.done


@gen_cluster(client=True)
async def test_task_group_not_done_processing(c, s, a, b):
"""TaskGroup.done is False if any of its tasks are in processing state"""
ev = Event()
x0 = c.submit(ev.wait, key=("x", 0))
x1 = c.submit(inc, 1, key=("x", 1))
await wait_for_state(x0.key, "processing", s)
await x1

tg = s.task_groups["x"]
assert not tg.done
await ev.set()


@gen_cluster(client=True)
async def test_task_prefix(c, s, a, b):
da = pytest.importorskip("dask.array")
Expand Down

0 comments on commit 669429d

Please sign in to comment.