Skip to content

Commit

Permalink
Homogeneously schedule P2P's unpack tasks (#8873)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Sep 13, 2024
1 parent ec3f4ec commit 4f3ac26
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 65 deletions.
18 changes: 5 additions & 13 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,14 +1425,8 @@ class TaskState:
#: be rejected.
run_id: int | None

#: Whether to consider this task rootish in the context of task queueing
#: True
#: Always consider this task rootish
#: False
#: Never consider this task rootish
#: None
#: Use a heuristic to determine whether this task should be considered rootish
_rootish: bool | None
#: Whether to allow queueing this task if it is rootish
_queueable: bool

#: Cached hash of :attr:`~TaskState.client_key`
_hash: int
Expand Down Expand Up @@ -1489,7 +1483,7 @@ def __init__(
self.metadata = None
self.annotations = None
self.erred_on = None
self._rootish = None
self._queueable = True
self.run_id = None
self.group = group
group.add(self)
Expand Down Expand Up @@ -2286,7 +2280,7 @@ def decide_worker_rootish_queuing_disabled(
"""
if self.validate:
# See root-ish-ness note below in `decide_worker_rootish_queuing_enabled`
assert math.isinf(self.WORKER_SATURATION)
assert math.isinf(self.WORKER_SATURATION) or not ts._queueable

pool = self.idle.values() if self.idle else self.running
if not pool:
Expand Down Expand Up @@ -2452,7 +2446,7 @@ def _transition_waiting_processing(self, key: Key, stimulus_id: str) -> RecsMsgs
# removed, there should only be one, which combines co-assignment and
# queuing. Eventually, special-casing root tasks might be removed entirely,
# with better heuristics.
if math.isinf(self.WORKER_SATURATION):
if math.isinf(self.WORKER_SATURATION) or not ts._queueable:
if not (ws := self.decide_worker_rootish_queuing_disabled(ts)):
return {ts.key: "no-worker"}, {}, {}
else:
Expand Down Expand Up @@ -3090,8 +3084,6 @@ def is_rootish(self, ts: TaskState) -> bool:
and have few or no dependencies. Tasks may also be explicitly marked as rootish
to override this heuristic.
"""
if ts._rootish is not None:
return ts._rootish
if ts.resource_restrictions or ts.worker_restrictions or ts.host_restrictions:
return False
tg = ts.group
Expand Down
2 changes: 1 addition & 1 deletion distributed/shuffle/_scheduler_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def _ensure_output_tasks_are_non_rootish(self, spec: ShuffleSpec) -> None:
"""
barrier = self.scheduler.tasks[barrier_key(spec.id)]
for dependent in barrier.dependents:
dependent._rootish = False
dependent._queueable = False

@log_errors()
def _set_restriction(self, ts: TaskState, worker: str) -> None:
Expand Down
36 changes: 36 additions & 0 deletions distributed/shuffle/tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import math
import random
import warnings
from collections import defaultdict

import pytest

from distributed.diagnostics.plugin import SchedulerPlugin

np = pytest.importorskip("numpy")
da = pytest.importorskip("dask.array")

Expand Down Expand Up @@ -1488,3 +1491,36 @@ def test_calculate_prechunking_splitting(old, new, expected):
# _calculate_prechunking does not concatenate on object
actual = _calculate_prechunking(old, new, np.dtype(object), None)
assert actual == expected


@gen_cluster(client=True, nthreads=[("", 1)] * 4, config={"array.chunk-size": "1 B"})
async def test_homogeneously_schedule_unpack(c, s, *ws):
class SchedulingTrackerPlugin(SchedulerPlugin):
async def start(self, scheduler):
self.scheduler = scheduler
self.counts = defaultdict(int)
self.seen = set()

def transition(self, key, start, finish, *args, stimulus_id, **kwargs):
if key in self.seen:
return

if not isinstance(key, tuple) or not isinstance(key[0], str):
return

if not key[0].startswith("rechunk-p2p"):
return

if start != "waiting" or finish != "processing":
return

self.seen.add(key)
self.counts[self.scheduler.tasks[key].processing_on.address] += 1

await c.register_plugin(SchedulingTrackerPlugin(), name="tracker")
res = da.random.random((100, 100), chunks=(1, -1)).rechunk((-1, 1))
await c.compute(res)
counts = s.plugins["tracker"].counts
min_count = min(counts.values())
max_count = max(counts.values())
assert min_count >= max_count, counts
27 changes: 0 additions & 27 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2685,33 +2685,6 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
return await super().barrier(id, run_id, consistent)


@gen_cluster(client=True)
async def test_unpack_is_non_rootish(c, s, a, b):
with pytest.warns(UserWarning):
scheduler_plugin = BlockedBarrierShuffleSchedulerPlugin(s)
df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-01-21",
dtypes={"x": float, "y": float},
freq="10 s",
)
df = df.shuffle("x")
result = c.compute(df)

await scheduler_plugin.in_barrier.wait()

unpack_tss = [ts for key, ts in s.tasks.items() if key_split(key) == UNPACK_PREFIX]
assert len(unpack_tss) == 20
assert not any(s.is_rootish(ts) for ts in unpack_tss)
del unpack_tss
scheduler_plugin.block_barrier.set()
result = await result

await assert_worker_cleanup(a)
await assert_worker_cleanup(b)
await assert_scheduler_cleanup(s)


class FlakyConnectionPool(ConnectionPool):
def __init__(self, *args, failing_connects=0, **kwargs):
self.attempts = 0
Expand Down
24 changes: 0 additions & 24 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,30 +284,6 @@ def random(**kwargs):
test_decide_worker_coschedule_order_neighbors_()


@gen_cluster(
client=True,
nthreads=[],
)
async def test_override_is_rootish(c, s):
x = c.submit(lambda x: x + 1, 1, key="x")
await async_poll_for(lambda: "x" in s.tasks, timeout=5)
ts_x = s.tasks["x"]
assert ts_x._rootish is None
assert s.is_rootish(ts_x)

ts_x._rootish = False
assert not s.is_rootish(ts_x)

y = c.submit(lambda y: y + 1, 1, key="y", workers=["not-existing"])
await async_poll_for(lambda: "y" in s.tasks, timeout=5)
ts_y = s.tasks["y"]
assert ts_y._rootish is None
assert not s.is_rootish(ts_y)

ts_y._rootish = True
assert s.is_rootish(ts_y)


@pytest.mark.skipif(
QUEUING_ON_BY_DEFAULT,
reason="Not relevant with queuing on; see https://github.com/dask/distributed/issues/7204",
Expand Down

0 comments on commit 4f3ac26

Please sign in to comment.