Prevent infinite transition loops (dask#6318)
crusaderky committed May 11, 2022
1 parent 1937be7 commit deb8ad1
Showing 9 changed files with 168 additions and 30 deletions.
7 changes: 7 additions & 0 deletions distributed/distributed-schema.yaml
@@ -1027,6 +1027,13 @@ properties:
type: boolean
description: Enter Python Debugger on scheduling error

transition-counter-max:
oneOf:
- enum: [false]
- type: integer
description: Cause the scheduler or workers to break if they reach this
number of transitions

system-monitor:
type: object
description: |
4 changes: 4 additions & 0 deletions distributed/distributed.yaml
@@ -277,6 +277,10 @@ distributed:
log-length: 10000 # default length of logs to keep in memory
log-format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
pdb-on-err: False # enter debug mode on scheduling error
# Cause scheduler and workers to break if they reach this many transitions.
# Used to debug infinite transition loops.
# Note: setting this will cause healthy long-running services to eventually break.
transition-counter-max: False
system-monitor:
interval: 500ms
event-loop: tornado
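The new setting can also be enabled programmatically for a debugging session. A minimal sketch, assuming a LocalCluster and an arbitrary threshold; the value must be set before the scheduler and workers are constructed, since the scheduler reads it at init time (see the scheduler.py hunk below):

import dask
from distributed import Client, LocalCluster

# Arbitrary threshold for illustration. As the comment above warns, a
# healthy long-running cluster will eventually reach any finite cap, so
# enable this only while hunting a suspected transition loop.
dask.config.set({"distributed.admin.transition-counter-max": 1_000_000})

client = Client(LocalCluster(n_workers=2))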
29 changes: 22 additions & 7 deletions distributed/scheduler.py
@@ -1262,6 +1262,7 @@ class SchedulerState:
"validate",
"workers",
"transition_counter",
"transition_counter_max",
"plugins",
"UNKNOWN_TASK_DURATION",
"MEMORY_RECENT_TO_OLD_TIME",
@@ -1354,6 +1355,9 @@ def __init__(
/ 2.0
)
self.transition_counter = 0
self.transition_counter_max = dask.config.get(
"distributed.admin.transition-counter-max"
)

@property
def memory(self) -> MemoryState:
@@ -1430,16 +1434,24 @@ def _transition(self, key, finish: str, stimulus_id: str, *args, **kwargs):
Scheduler.transitions : transitive version of this function
"""
try:
recommendations = {} # type: ignore
worker_msgs = {} # type: ignore
client_msgs = {} # type: ignore

ts: TaskState = self.tasks.get(key) # type: ignore
if ts is None:
return recommendations, client_msgs, worker_msgs
return {}, {}, {}
start = ts._state
if start == finish:
return recommendations, client_msgs, worker_msgs
return {}, {}, {}

# Notes:
# - in case of transition through released, this counter is incremented by 2
# - this increase happens before the actual transitions, so that it can
# catch potential infinite recursions
self.transition_counter += 1
if self.validate and self.transition_counter_max:
assert self.transition_counter < self.transition_counter_max

recommendations = {} # type: ignore
worker_msgs = {} # type: ignore
client_msgs = {} # type: ignore

if self.plugins:
dependents = set(ts.dependents)
@@ -1451,7 +1463,7 @@ def _transition(self, key, stimulus_id, *args, **kwargs):
recommendations, client_msgs, worker_msgs = func(
key, stimulus_id, *args, **kwargs
) # type: ignore
self.transition_counter += 1

elif "released" not in start_finish:
assert not args and not kwargs, (args, kwargs, start_finish)
a_recs: dict
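The "increase happens before the actual transitions" note above is the heart of the fix. A toy sketch (an assumed simplification, not the actual scheduler code) of how a pathological transition function that keeps recommending its own reversal would loop forever, and how counting before acting turns that into a loud assertion:

counter = 0
counter_max = 50_000  # the default that @gen_cluster sets in utils_test.py below
recommendations = {"x": "released"}
while recommendations:
    # Increment *before* applying the transition, so the cycle is caught
    # even though each step leaves no observable change in task state.
    counter += 1
    assert counter < counter_max, "possible infinite transition loop"
    key, finish = recommendations.popitem()
    # Buggy transition: every state change recommends the opposite one.
    recommendations[key] = "waiting" if finish == "released" else "released"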
@@ -3173,6 +3185,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
info = super()._to_dict(exclude=exclude)
extra = {
"transition_log": self.transition_log,
"transition_counter": self.transition_counter,
"log": self.log,
"tasks": self.tasks,
"task_groups": self.task_groups,
@@ -4496,6 +4509,8 @@ def validate_state(self, allow_overlap: bool = False) -> None:
actual_total_occupancy,
self.total_occupancy,
)
if self.transition_counter_max:
assert self.transition_counter < self.transition_counter_max

###################
# Manage Messages #
62 changes: 60 additions & 2 deletions distributed/tests/test_scheduler.py
@@ -21,6 +21,7 @@
from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename

from distributed import (
CancelledError,
Client,
Event,
Lock,
@@ -3215,11 +3216,67 @@ async def test_computations_futures(c, s, a, b):
assert "inc" in str(computation.groups)


@gen_cluster(client=True)
async def test_transition_counter(c, s, a, b):
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_transition_counter(c, s, a):
assert s.transition_counter == 0
assert a.transition_counter == 0
await c.submit(inc, 1)
assert s.transition_counter > 1
assert a.transition_counter > 1


@pytest.mark.slow
@gen_cluster(client=True)
async def test_transition_counter_max_scheduler(c, s, a, b):
# This is set by @gen_cluster; it's False in production
assert s.transition_counter_max > 0
s.transition_counter_max = 1
with captured_logger("distributed.scheduler") as logger:
with pytest.raises(CancelledError):
await c.submit(inc, 2)
assert s.transition_counter > 1
with pytest.raises(AssertionError):
s.validate_state()
assert "transition_counter_max" in logger.getvalue()
# Scheduler state is corrupted. Avoid test failure on gen_cluster teardown.
s.validate = False


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_transition_counter_max_worker(c, s, a):
# This is set by @gen_cluster; it's False in production
assert s.transition_counter_max > 0
a.transition_counter_max = 1
with captured_logger("distributed.core") as logger:
fut = c.submit(inc, 2)
while True:
try:
a.validate_state()
except AssertionError:
break
await asyncio.sleep(0.01)

assert "TransitionCounterMaxExceeded" in logger.getvalue()
# Worker state is corrupted. Avoid test failure on gen_cluster teardown.
a.validate = False


@gen_cluster(
client=True,
nthreads=[("", 1)],
config={"distributed.admin.transition-counter-max": False},
)
async def test_disable_transition_counter_max(c, s, a):
"""Test that the cluster can run indefinitely if transition_counter_max is disabled.
This is the default outside of @gen_cluster.
"""
assert s.transition_counter_max is False
assert a.transition_counter_max is False
assert await c.submit(inc, 1) == 2
assert s.transition_counter > 1
assert a.transition_counter > 1
s.validate_state()
a.validate_state()


@gen_cluster(
@@ -3339,6 +3396,7 @@ async def test_Scheduler__to_dict(c, s, a):
"status",
"thread_id",
"transition_log",
"transition_counter",
"log",
"memory",
"tasks",
14 changes: 12 additions & 2 deletions distributed/tests/test_stress.py
@@ -228,7 +228,12 @@ async def test_stress_steal(c, s, *workers):


@pytest.mark.slow
@gen_cluster(nthreads=[("127.0.0.1", 1)] * 10, client=True, timeout=180)
@gen_cluster(
nthreads=[("", 1)] * 10,
client=True,
timeout=180,
config={"distributed.admin.transition-counter-max": 500_000},
)
async def test_close_connections(c, s, *workers):
da = pytest.importorskip("dask.array")
x = da.random.random(size=(1000, 1000), chunks=(1000, 1))
@@ -291,7 +296,12 @@ async def test_no_delay_during_large_transfer(c, s, w):


@pytest.mark.slow
@gen_cluster(client=True, Worker=Nanny, nthreads=[("127.0.0.1", 2)] * 6)
@gen_cluster(
client=True,
Worker=Nanny,
nthreads=[("", 2)] * 6,
config={"distributed.admin.transition-counter-max": 500_000},
)
async def test_chaos_rechunk(c, s, *workers):
s.allowed_failures = 10000

1 change: 1 addition & 0 deletions distributed/tests/test_worker.py
@@ -3434,6 +3434,7 @@ async def test_Worker__to_dict(c, s, a):
"busy_workers",
"log",
"stimulus_log",
"transition_counter",
"tasks",
"logs",
"config",
36 changes: 22 additions & 14 deletions distributed/utils_test.py
@@ -27,14 +27,6 @@
from time import sleep
from typing import Any, Generator, Literal

from distributed.compatibility import MACOS
from distributed.scheduler import Scheduler

try:
import ssl
except ImportError:
ssl = None # type: ignore

import pytest
import yaml
from tlz import assoc, memoize, merge
@@ -43,12 +35,12 @@

import dask

from distributed import system
from distributed import Scheduler, system
from distributed import versions as version_module
from distributed.client import Client, _global_clients, default_client
from distributed.comm import Comm
from distributed.comm.tcp import TCP
from distributed.compatibility import WINDOWS
from distributed.compatibility import MACOS, WINDOWS
from distributed.config import initialize_logging
from distributed.core import (
CommClosedError,
@@ -79,6 +71,11 @@
)
from distributed.worker import WORKER_ANY_RUNNING, InvalidTransition, Worker

try:
import ssl
except ImportError:
ssl = None # type: ignore

try:
import dask.array # register config
except ImportError:
@@ -447,8 +444,6 @@ async def background_read():

def run_scheduler(q, nputs, config, port=0, **kwargs):
with dask.config.set(config):
from distributed import Scheduler

# On Python 2.7 and Unix, fork() is used to spawn child processes,
# so avoid inheriting the parent's IO loop.
with pristine_loop() as loop:
@@ -999,6 +994,7 @@ async def test_foo(scheduler, worker1, worker2, pytest_fixture_a, pytest_fixture
worker_kwargs = merge(
{"memory_limit": system.MEMORY_LIMIT, "death_timeout": 15}, worker_kwargs
)
config = merge({"distributed.admin.transition-counter-max": 50_000}, config)

def _(func):
if not iscoroutinefunction(func):
@@ -1052,8 +1048,7 @@ async def coro():
task = asyncio.create_task(coro)
coro2 = asyncio.wait_for(asyncio.shield(task), timeout)
result = await coro2
if s.validate:
s.validate_state()
validate_state(s, *workers)

except asyncio.TimeoutError:
assert task
@@ -1073,6 +1068,10 @@
while not task.cancelled():
await asyncio.sleep(0.01)

# Hopefully, the hang has been caused by inconsistent state,
# which should be much more meaningful than the timeout
validate_state(s, *workers)

# Remove as much of the traceback as possible; it's
# uninteresting boilerplate from utils_test and asyncio and
# not from the code being tested.
@@ -1205,6 +1204,15 @@ async def dump_cluster_state(
print(f"Dumped cluster state to {fname}")


def validate_state(*servers: Scheduler | Worker | Nanny) -> None:
"""Run validate_state() on the Scheduler and all the Workers of the cluster.
Excludes workers wrapped by Nannies and workers manually started by the test.
"""
for s in servers:
if s.validate and hasattr(s, "validate_state"):
s.validate_state() # type: ignore
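A usage sketch (hypothetical test code, not part of this diff): @gen_cluster now calls this helper both on a clean exit and on a timeout, and a test that deliberately corrupts state opts out by clearing the validate flag first:

a.validate = False    # corrupted on purpose; validate_state() now skips this worker
validate_state(s, a)  # still checks the scheduler; raises AssertionError on inconsistency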


def raises(func, exc=Exception):
try:
func()