197 changes: 119 additions & 78 deletions distributed/tests/test_scheduler.py
@@ -1,5 +1,4 @@
import asyncio
import gc
import json
import logging
import operator
@@ -2458,134 +2457,176 @@ def test_memorystate_adds_up(process, unmanaged_old, managed, managed_spilled):
assert m.optimistic + m.unmanaged_recent == m.process


_test_leak = []


def leaking(out_mib, leak_mib, sleep_time):
if leak_mib:
global __test_leak
__test_leak = "x" * (leak_mib * 2 ** 20)
out = "x" * (out_mib * 2 ** 20)
_test_leak.append("x" * (leak_mib * 2 ** 20))
sleep(sleep_time)
return out
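# e.g. leaking(50, 100, 2) returns a 50 MiB string (counted as managed memory
# once the task result is stored on the worker), keeps a further 100 MiB alive
# through _test_leak (counted as unmanaged memory), and takes ~2s to run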


def clear_leak():
global __test_leak
del __test_leak
gc.collect()
_test_leak.clear()


async def assert_memory(scheduler_or_workerstate, attr: str, min_, max_, timeout=10):
async def assert_memory(
scheduler_or_workerstate,
attr: str,
min_mib: float,
max_mib: float,
*,
timeout: float = 10,
) -> None:
t0 = time()
while True:
minfo = scheduler_or_workerstate.memory
nmib = getattr(minfo, attr) / 2 ** 20
if min_ <= nmib <= max_:
if min_mib <= nmib <= max_mib:
return
if time() - t0 > timeout:
raise TimeoutError(
f"Expected {min_} MiB <= {attr} <= {max_} MiB; got:\n{minfo!r}"
raise AssertionError(
f"Expected {min_mib} MiB <= {attr} <= {max_mib} MiB; got:\n{minfo!r}"
)
await asyncio.sleep(0.1)
await asyncio.sleep(0.01)


# ~31s runtime, or distributed.worker.memory.recent-to-old-time + 1s.
# On Windows, it can take ~65s due to worker memory needing to stabilize first.
@pytest.mark.slow
@pytest.mark.flaky(condition=LINUX, reason="see comments", reruns=10, reruns_delay=5)
@gen_cluster(
client=True, Worker=Nanny, worker_kwargs={"memory_limit": "500 MiB"}, timeout=120
client=True,
Worker=Nanny,
config={
"distributed.worker.memory.recent-to-old-time": "4s",
"distributed.worker.memory.spill": 0.7,
},
worker_kwargs={
"heartbeat_interval": "20ms",
"memory_limit": "700 MiB",
},
)
async def test_memory(c, s, *_):
async def test_memory(c, s, *nannies):
# WorkerState objects, as opposed to the Nanny objects passed by gen_cluster
a, b = s.workers.values()

def print_memory_info(msg: str) -> None:
print(f"==== {msg} ====")
print(f"---- a ----\n{a.memory}")
print(f"---- b ----\n{b.memory}")
print(f"---- s ----\n{s.memory}")

s_m0 = s.memory
assert s_m0.process == a.memory.process + b.memory.process
assert s_m0.managed == 0
assert a.memory.managed == 0
assert b.memory.managed == 0

# When a worker first goes online, its RAM is immediately counted as unmanaged_old.
# On Windows, however, there is somehow enough time between the worker start and
# this line for 2 heartbeats and the memory keeps growing substantially for a while.
# Sometimes there is a single heartbeat but on the consecutive test we observe
# a large unexplained increase in unmanaged_recent memory.
# Wait for the situation to stabilize.
if WINDOWS:
await asyncio.sleep(10)
initial_timeout = 40
else:
initial_timeout = 0
# Trigger potential imports inside WorkerPlugin.transition
await c.submit(inc, 0, workers=[a.address])
await c.submit(inc, 1, workers=[b.address])
# Wait for the memory readings to stabilize after workers go online
await asyncio.sleep(2)
await asyncio.gather(
assert_memory(a, "unmanaged_recent", 0, 5, timeout=10),
assert_memory(b, "unmanaged_recent", 0, 5, timeout=10),
assert_memory(s, "unmanaged_recent", 0, 10, timeout=10.1),
)

await assert_memory(s, "unmanaged_recent", 0, 40, timeout=initial_timeout)
await assert_memory(a, "unmanaged_recent", 0, 20, timeout=initial_timeout)
await assert_memory(b, "unmanaged_recent", 0, 20, timeout=initial_timeout)
print()
print_memory_info("Starting memory")

f1 = c.submit(leaking, 100, 50, 10, pure=False, workers=[a.name])
f2 = c.submit(leaking, 100, 50, 10, pure=False, workers=[b.name])
await assert_memory(s, "unmanaged_recent", 300, 380)
await assert_memory(a, "unmanaged_recent", 150, 190)
await assert_memory(b, "unmanaged_recent", 150, 190)
# 50 MiB heap + 100 MiB leak
# Note that runtime=2s is less than recent-to-old-time=4s
f1 = c.submit(leaking, 50, 100, 2, key="f1", workers=[a.name])
f2 = c.submit(leaking, 50, 100, 2, key="f2", workers=[b.name])

await asyncio.gather(
assert_memory(a, "unmanaged_recent", 150, 170, timeout=1.8),
assert_memory(b, "unmanaged_recent", 150, 170, timeout=1.8),
assert_memory(s, "unmanaged_recent", 300, 340, timeout=1.9),
)
await wait([f1, f2])

# On each worker, we now have 100 MiB managed + 50 MiB fresh leak
await assert_memory(s, "managed_in_memory", 200, 201)
await assert_memory(a, "managed_in_memory", 100, 101)
await assert_memory(b, "managed_in_memory", 100, 101)
await assert_memory(s, "unmanaged_recent", 100, 180)
await assert_memory(a, "unmanaged_recent", 50, 90)
await assert_memory(b, "unmanaged_recent", 50, 90)

# Force the output of f1 and f2 to spill to disk.
# With spill=0.7 and memory_limit=500 MiB, we'll start spilling at 350 MiB process
# memory per worker, or up to 20 iterations of the below depending on how much RAM
# the interpreter is using.
more_futs = []
while not s.memory.managed_spilled:
if a.memory.process < 0.7 * 500 * 2 ** 20:
more_futs.append(c.submit(leaking, 10, 0, 0, pure=False, workers=[a.name]))
if b.memory.process < 0.7 * 500 * 2 ** 20:
more_futs.append(c.submit(leaking, 10, 0, 0, pure=False, workers=[b.name]))
await wait(more_futs)
await asyncio.sleep(1)
# On each worker, we now have 50 MiB managed + 100 MiB fresh leak
await asyncio.gather(
assert_memory(a, "managed_in_memory", 50, 51, timeout=0),
assert_memory(b, "managed_in_memory", 50, 51, timeout=0),
assert_memory(s, "managed_in_memory", 100, 101, timeout=0),
assert_memory(a, "unmanaged_recent", 100, 120, timeout=0),
assert_memory(b, "unmanaged_recent", 100, 120, timeout=0),
assert_memory(s, "unmanaged_recent", 200, 240, timeout=0),
)

# Wait for the spilling to finish. Note that this does not make the test take
# longer as we're waiting for recent-to-old-time anyway.
await asyncio.sleep(10)
# Force the output of f1 and f2 to spill to disk
print_memory_info("Before spill")
a_leak = round(700 * 0.7 - a.memory.process / 2 ** 20)
b_leak = round(700 * 0.7 - b.memory.process / 2 ** 20)
assert a_leak > 50 and b_leak > 50
a_leak += 10
b_leak += 10
print(f"Leaking additional memory: a_leak={a_leak}; b_leak={b_leak}")
await wait(
[
c.submit(leaking, 0, a_leak, 0, pure=False, workers=[a.name]),
c.submit(leaking, 0, b_leak, 0, pure=False, workers=[b.name]),
]
)
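# Illustrative arithmetic with hypothetical numbers: with memory_limit=700 MiB
# and spill=0.7, spilling starts at 490 MiB of process memory per worker. If a
# worker currently sits at ~320 MiB, then a_leak = round(490 - 320) + 10 = 180 MiB,
# which pushes the process to ~500 MiB and forces its 50 MiB managed key to spill.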

# Timeout needs to be enough to spill 100 MiB to disk
await asyncio.gather(
assert_memory(a, "managed_spilled", 50, 51, timeout=10),
assert_memory(b, "managed_spilled", 50, 51, timeout=10),
assert_memory(s, "managed_spilled", 100, 101, timeout=10.1),
)
# FIXME on Windows and MacOS we occasionally observe managed_in_memory = 49 bytes
await asyncio.gather(
assert_memory(a, "managed_in_memory", 0, 0.1, timeout=0),
assert_memory(b, "managed_in_memory", 0, 0.1, timeout=0),
assert_memory(s, "managed_in_memory", 0, 0.1, timeout=0),
)

print_memory_info("After spill")

# Delete spilled keys
prev = s.memory
del f1
del f2
await assert_memory(s, "managed_spilled", 0, prev.managed_spilled / 2 ** 20 - 19)

# Empty the cluster, with the exception of leaked memory
del more_futs
await assert_memory(s, "managed", 0, 0)
await asyncio.gather(
assert_memory(a, "managed_spilled", 0, 0, timeout=3),
assert_memory(b, "managed_spilled", 0, 0, timeout=3),
assert_memory(s, "managed_spilled", 0, 0, timeout=3.1),
)

orig_unmanaged = s_m0.unmanaged / 2 ** 20
orig_old = s_m0.unmanaged_old / 2 ** 20
print_memory_info("After clearing spilled keys")

# Wait until 30s have passed since the spill to observe unmanaged_recent
# Wait until 4s have passed since the spill to observe unmanaged_recent
# transition into unmanaged_old
await c.run(gc.collect)
await assert_memory(s, "unmanaged_recent", 0, 90, timeout=40)
await assert_memory(s, "unmanaged_old", orig_old + 90, 9999, timeout=40)
await asyncio.gather(
assert_memory(a, "unmanaged_recent", 0, 5, timeout=4.5),
assert_memory(b, "unmanaged_recent", 0, 5, timeout=4.5),
assert_memory(s, "unmanaged_recent", 0, 10, timeout=4.6),
)

# When the leaked memory is cleared, unmanaged and unmanaged_old drop.
# On MacOS and Windows, the process memory of the Python interpreter does not shrink
# as fast as on Linux. Note that this behaviour is heavily impacted by OS tweaks,
# meaning that what you observe on your local host may behave differently on CI.
# Even on Linux, this occasionally glitches - hence why there is a flaky marker on
# this test.
if not LINUX:
return

orig_unmanaged = s.memory.unmanaged / 2 ** 20
orig_old = s.memory.unmanaged_old / 2 ** 20
print_memory_info("Before clearing memory leak")

prev_unmanaged_a = a.memory.unmanaged / 2 ** 20
prev_unmanaged_b = b.memory.unmanaged / 2 ** 20
await c.run(clear_leak)
await assert_memory(s, "unmanaged", 0, orig_unmanaged - 60)
await assert_memory(s, "unmanaged_old", 0, orig_old - 60)
await assert_memory(s, "unmanaged_recent", 0, 90)

await asyncio.gather(
assert_memory(a, "unmanaged", 0, prev_unmanaged_a - 50, timeout=10),
assert_memory(b, "unmanaged", 0, prev_unmanaged_b - 50, timeout=10),
)
await asyncio.gather(
assert_memory(a, "unmanaged_recent", 0, 5, timeout=0),
assert_memory(b, "unmanaged_recent", 0, 5, timeout=0),
)


@gen_cluster(client=True, worker_kwargs={"memory_limit": 0})
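For orientation, a minimal sketch of the memory bookkeeping that the assertions above rely on. The field names mirror the MemoryState objects returned by s.memory and WorkerState.memory, but the class below is a simplified stand-in written for illustration, not the actual implementation:

from dataclasses import dataclass


@dataclass
class MemoryStateSketch:
    managed_in_memory: int  # sizeof() of task outputs held in RAM
    managed_spilled: int    # task outputs that have been spilled to disk
    unmanaged_old: int      # unmanaged RAM older than recent-to-old-time
    unmanaged_recent: int   # unmanaged RAM that appeared within recent-to-old-time

    @property
    def managed(self) -> int:
        return self.managed_in_memory + self.managed_spilled

    @property
    def unmanaged(self) -> int:
        return self.unmanaged_old + self.unmanaged_recent

    @property
    def process(self) -> int:
        # total RSS: managed keys in RAM plus all unmanaged memory
        return self.managed_in_memory + self.unmanaged

    @property
    def optimistic(self) -> int:
        # hence the assertion above: optimistic + unmanaged_recent == process
        return self.managed_in_memory + self.unmanaged_old

Under this breakdown, a worker that has just finished leaking(50, 100, 2) should report roughly 50 MiB of managed_in_memory and an extra 100 MiB of unmanaged_recent, which is what the 50-51 MiB and 100-120 MiB windows in the test check for.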
6 changes: 5 additions & 1 deletion distributed/worker.py
@@ -534,6 +534,7 @@ class Worker(ServerNode):
name: Any
scheduler_delay: float
stream_comms: dict[str, BatchedSend]
heartbeat_interval: float
heartbeat_active: bool
_ipython_kernel: Any | None = None
services: dict[str, Any] = {}
@@ -572,6 +573,7 @@ def __init__(
preload_argv: list[str] | list[list[str]] | None = None,
security: Security | dict[str, Any] | None = None,
contact_address: str | None = None,
heartbeat_interval: Any = "1s",
memory_monitor_interval: Any = "200ms",
memory_target_fraction: float | Literal[False] | None = None,
memory_spill_fraction: float | Literal[False] | None = None,
@@ -947,8 +949,10 @@ def __init__(
"worker": self,
}

pc = PeriodicCallback(self.heartbeat, 1000)
self.heartbeat_interval = parse_timedelta(heartbeat_interval, default="ms")
pc = PeriodicCallback(self.heartbeat, self.heartbeat_interval * 1000)
self.periodic_callbacks["heartbeat"] = pc

pc = PeriodicCallback(
lambda: self.batched_stream.send({"op": "keep-alive"}), 60000
)
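A minimal sketch of how the new heartbeat_interval parameter is consumed; the helper function name below is made up for illustration, but the conversion mirrors the hunk above:

from dask.utils import parse_timedelta


def heartbeat_period_ms(heartbeat_interval="1s") -> float:
    # parse_timedelta accepts strings such as "20ms" or "1s" (and bare numbers,
    # interpreted as milliseconds because of default="ms") and returns seconds
    seconds = parse_timedelta(heartbeat_interval, default="ms")
    # PeriodicCallback takes its period in milliseconds
    return seconds * 1000


assert heartbeat_period_ms("20ms") == 20.0
assert heartbeat_period_ms("1s") == 1000.0

This is what lets the test above pass worker_kwargs={"heartbeat_interval": "20ms"}, so that scheduler-side memory readings refresh quickly enough for the short assert_memory timeouts.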