Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ensure_communicating #6165

Merged
merged 18 commits on May 11, 2022
Merged
8 changes: 1 addition & 7 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ async def handle_comm(self, comm):
"Failed while closing connection to %r: %s", address, e
)

async def handle_stream(self, comm, extra=None, every_cycle=()):
async def handle_stream(self, comm, extra=None):
extra = extra or {}
logger.info("Starting established connection")

Expand Down Expand Up @@ -653,12 +653,6 @@ async def handle_stream(self, comm, extra=None, every_cycle=()):
logger.error("odd message %s", msg)
await asyncio.sleep(0)

for func in every_cycle:
if is_coroutine_function(func):
self.loop.add_callback(func)
else:
func()

except OSError:
pass
except Exception as e:
Expand Down
13 changes: 5 additions & 8 deletions distributed/tests/test_stories.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,11 @@ async def test_worker_story_with_deps(c, s, a, b):

story = a.story("res")
assert story == []
story = b.story("res")

# Story now includes randomized stimulus_ids and timestamps.
stimulus_ids = {ev[-2] for ev in story}
# Compute dep
# Success dep
# Compute res
assert len(stimulus_ids) == 3
story = b.story("res")
stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
assert stimulus_ids == {"compute-task", "task-finished"}

# This is a simple transition log
expected = [
Expand All @@ -155,8 +152,8 @@ async def test_worker_story_with_deps(c, s, a, b):
assert_story(story, expected, strict=True)

story = b.story("dep")
stimulus_ids = {ev[-2] for ev in story}
assert len(stimulus_ids) == 2, stimulus_ids
stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story}
assert stimulus_ids == {"compute-task"}
expected = [
("dep", "ensure-task-exists", "released"),
("dep", "released", "fetch", "fetch", {}),
Expand Down
45 changes: 22 additions & 23 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,11 +631,13 @@ async def test_clean(c, s, a, b):

@gen_cluster(client=True)
async def test_message_breakup(c, s, a, b):
n = 100000
n = 100_000
a.target_message_size = 10 * n
b.target_message_size = 10 * n
xs = [c.submit(mul, b"%d" % i, n, workers=a.address) for i in range(30)]
y = c.submit(lambda *args: None, xs, workers=b.address)
xs = [
c.submit(mul, b"%d" % i, n, key=f"x{i}", workers=[a.address]) for i in range(30)
]
y = c.submit(lambda _: None, xs, key="y", workers=[b.address])
await y

assert 2 <= len(b.incoming_transfer_log) <= 20
Expand Down Expand Up @@ -714,27 +716,32 @@ async def test_clean_nbytes(c, s, a, b):
)


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 20)
async def test_gather_many_small(c, s, a, *workers):
@pytest.mark.parametrize("as_deps", [True, False])
@gen_cluster(client=True, nthreads=[("", 1)] * 20)
async def test_gather_many_small(c, s, a, *workers, as_deps):
"""If the dependencies of a given task are very small, do not limit the
number of concurrent outgoing connections
"""
a.total_out_connections = 2
futures = await c._scatter(list(range(100)))

futures = await c.scatter(
{f"x{i}": i for i in range(100)},
workers=[w.address for w in workers],
)
assert all(w.data for w in workers)

def f(*args):
return 10

future = c.submit(f, *futures, workers=a.address)
await wait(future)
if as_deps:
future = c.submit(lambda _: None, futures, key="y", workers=[a.address])
await wait(future)
else:
s.request_acquire_replicas(a.address, list(futures), stimulus_id="test")
while len(a.data) < 100:
await asyncio.sleep(0.01)

types = list(pluck(0, a.log))
req = [i for i, t in enumerate(types) if t == "request-dep"]
recv = [i for i, t in enumerate(types) if t == "receive-dep"]
assert len(req) == len(recv) == 19
assert min(recv) > max(req)

assert a.comm_nbytes == 0


Expand Down Expand Up @@ -1424,21 +1431,13 @@ def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None:
assert_story(
w_to.story(key),
[
(key, "ensure-task-exists", "released"),
(key, "released", "fetch", "fetch", {}),
("gather-dependencies", w_from.address, lambda set_: key in set_),
(key, "fetch", "flight", "flight", {}),
("request-dep", w_from.address, lambda set_: key in set_),
("receive-dep", w_from.address, lambda set_: key in set_),
(key, "put-in-memory"),
(key, "flight", "memory", "memory", {}),
],
# There may be additional ('missing', 'fetch', 'fetch') events if transfers
# are slow enough that the Active Memory Manager ends up requesting them a
# second time. Here we're asserting that no matter how slow CI is, all
# transfers will be completed within 2 seconds (hardcoded interval in
# Scheduler.retire_worker when AMM is not enabled).
strict=True,
strict=False,
)
assert key in w_to.data
# The key may or may not still be in w_from.data, depending if the AMM had the
Expand Down Expand Up @@ -3054,7 +3053,7 @@ async def test_missing_released_zombie_tasks_2(c, s, b):
await asyncio.sleep(0)

ts = b.tasks[f1.key]
assert ts.state == "fetch"
assert ts.state == "flight"

while ts.state != "missing":
# If we sleep for a longer time, the worker will spin into an
Expand Down
Loading