Skip to content

Commit 676bb18

Browse files
committed
Ensure Cancellation of distributed.comm.core.connect always raises CancelledError
1 parent fc5b460 commit 676bb18

File tree

7 files changed

+169
-81
lines changed

7 files changed

+169
-81
lines changed

distributed/comm/core.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from distributed.metrics import time
1919
from distributed.protocol import pickle
2020
from distributed.protocol.compression import get_default_compression
21+
from distributed.utils import ensure_cancellation
2122

2223
logger = logging.getLogger(__name__)
2324

@@ -286,8 +287,11 @@ def time_left():
286287
active_exception = None
287288
while time_left() > 0:
288289
try:
290+
task = ensure_cancellation(
291+
connector.connect(loc, deserialize=deserialize, **connection_args)
292+
)
289293
comm = await asyncio.wait_for(
290-
connector.connect(loc, deserialize=deserialize, **connection_args),
294+
task,
291295
timeout=min(intermediate_cap, time_left()),
292296
)
293297
break

distributed/core.py

Lines changed: 4 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1086,6 +1086,7 @@ async def _connect(self, addr, timeout=None):
10861086
deserialize=self.deserialize,
10871087
**self.connection_args,
10881088
)
1089+
10891090
comm.name = "ConnectionPool"
10901091
comm._pool = weakref.ref(self)
10911092
comm.allow_offload = self.allow_offload
@@ -1099,8 +1100,6 @@ async def _connect(self, addr, timeout=None):
10991100
raise
11001101
finally:
11011102
self._connecting_count -= 1
1102-
except asyncio.CancelledError:
1103-
raise CommClosedError("ConnectionPool closing.")
11041103
finally:
11051104
self._pending_count -= 1
11061105

@@ -1121,30 +1120,15 @@ async def connect(self, addr, timeout=None):
11211120
if self.semaphore.locked():
11221121
self.collect()
11231122

1124-
# This construction is there to ensure that cancellation requests from
1125-
# the outside can be distinguished from cancellations of our own.
1126-
# Once the CommPool closes, we'll cancel the connect_attempt which will
1127-
# raise an OSError
1128-
# If the ``connect`` is cancelled from the outside, the Event.wait will
1129-
# be cancelled instead which we'll reraise as a CancelledError and allow
1130-
# it to propagate
11311123
connect_attempt = asyncio.create_task(self._connect(addr, timeout))
1132-
done = asyncio.Event()
11331124
self._connecting.add(connect_attempt)
1134-
connect_attempt.add_done_callback(lambda _: done.set())
11351125
connect_attempt.add_done_callback(self._connecting.discard)
1136-
11371126
try:
1138-
await done.wait()
1127+
return await connect_attempt
11391128
except asyncio.CancelledError:
1140-
# This is an outside cancel attempt
1141-
connect_attempt.cancel()
1142-
try:
1143-
await connect_attempt
1144-
except CommClosedError:
1145-
pass
1129+
if self.status == Status.closed:
1130+
raise CommClosedError("ConnectionPool closed.")
11461131
raise
1147-
return await connect_attempt
11481132

11491133
def reuse(self, addr, comm):
11501134
"""

distributed/tests/test_client.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5743,12 +5743,13 @@ async def test_client_active_bad_port():
57435743
application = tornado.web.Application([(r"/", tornado.web.RequestHandler)])
57445744
http_server = tornado.httpserver.HTTPServer(application)
57455745
http_server.listen(8080)
5746-
with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}):
5747-
c = Client("127.0.0.1:8080", asynchronous=True)
5748-
with pytest.raises((TimeoutError, IOError)):
5749-
await c
5750-
await c._close(fast=True)
5751-
http_server.stop()
5746+
try:
5747+
with dask.config.set({"distributed.comm.timeouts.connect": "10ms"}):
5748+
with pytest.raises((TimeoutError, IOError)):
5749+
async with Client("127.0.0.1:8080", asynchronous=True) as c:
5750+
pass
5751+
finally:
5752+
http_server.stop()
57525753

57535754

57545755
@pytest.mark.parametrize("direct", [True, False])

distributed/tests/test_core.py

Lines changed: 71 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
import dask
1212

1313
from distributed.comm.core import CommClosedError
14+
from distributed.comm.tcp import TCPBackend, TCPConnector
1415
from distributed.core import (
1516
ConnectionPool,
1617
Server,
@@ -591,26 +592,50 @@ async def ping(comm, delay=0.1):
591592
await asyncio.gather(*[server.close() for server in servers])
592593

593594

594-
@gen_test()
595-
async def test_connection_pool_close_while_connecting(monkeypatch):
596-
"""
597-
Ensure a closed connection pool guarantees to have no connections left open
598-
even if it is closed mid-connecting
599-
"""
600-
from distributed.comm.registry import backends
601-
from distributed.comm.tcp import TCPBackend, TCPConnector
595+
class WrongCancelConnector(TCPConnector):
596+
async def connect(self, address, deserialize, **connection_args):
597+
try:
598+
await asyncio.sleep(10000)
599+
except asyncio.CancelledError:
600+
raise OSError("muhaha")
602601

603-
class SlowConnector(TCPConnector):
604-
async def connect(self, address, deserialize, **connection_args):
602+
603+
class WrongCancelBackend(TCPBackend):
604+
_connector_class = WrongCancelConnector
605+
606+
607+
class SlowConnector(TCPConnector):
608+
async def connect(self, address, deserialize, **connection_args):
609+
try:
605610
await asyncio.sleep(10000)
606-
return await super().connect(
607-
address, deserialize=deserialize, **connection_args
608-
)
611+
except BaseException:
612+
raise
609613

610-
class SlowBackend(TCPBackend):
611-
_connector_class = SlowConnector
612614

613-
monkeypatch.setitem(backends, "tcp", SlowBackend())
615+
class SlowBackend(TCPBackend):
616+
_connector_class = SlowConnector
617+
618+
619+
@pytest.mark.parametrize(
620+
"backend",
621+
[
622+
SlowBackend,
623+
WrongCancelBackend,
624+
],
625+
)
626+
@pytest.mark.parametrize(
627+
"closing",
628+
[
629+
True,
630+
False,
631+
],
632+
)
633+
@gen_test()
634+
async def test_connection_pool_cancellation(monkeypatch, closing, backend):
635+
# Ensure cancellation errors are properly reraised
636+
from distributed.comm.registry import backends
637+
638+
monkeypatch.setitem(backends, "tcp", backend())
614639

615640
async with Server({}) as server:
616641
await server.listen("tcp://")
@@ -623,53 +648,59 @@ async def connect_to_server():
623648

624649
# #tasks > limit
625650
tasks = [asyncio.create_task(connect_to_server()) for _ in range(5)]
626-
627-
while not pool._connecting:
651+
# Ensure the pool is saturated and some connection attempts are pending to
652+
# connect
653+
while pool._pending_count != len(tasks):
628654
await asyncio.sleep(0.01)
629655

630-
await pool.close()
631-
for t in tasks:
632-
with pytest.raises(CommClosedError):
633-
await t
656+
if closing:
657+
await pool.close()
658+
for t in tasks:
659+
with pytest.raises(CommClosedError):
660+
await t
661+
else:
662+
for t in tasks:
663+
t.cancel()
664+
await asyncio.wait(tasks)
665+
assert all(t.cancelled() for t in tasks)
666+
634667
assert not pool.open
635668
assert not pool._n_connecting
636669

637670

638671
@gen_test()
639-
async def test_connection_pool_outside_cancellation(monkeypatch):
640-
# Ensure cancellation errors are properly reraised
641-
from distributed.comm.registry import backends
642-
from distributed.comm.tcp import TCPBackend, TCPConnector
672+
async def test_connect_properly_raising(monkeypatch):
673+
_connecting = 0
643674

644675
class SlowConnector(TCPConnector):
645676
async def connect(self, address, deserialize, **connection_args):
646-
await asyncio.sleep(10000)
647-
return await super().connect(
648-
address, deserialize=deserialize, **connection_args
649-
)
677+
try:
678+
nonlocal _connecting
679+
_connecting += 1
680+
await asyncio.sleep(10000)
681+
except BaseException:
682+
raise OSError
650683

651684
class SlowBackend(TCPBackend):
652685
_connector_class = SlowConnector
653686

687+
# Ensure cancellation errors are properly reraised
688+
from distributed.comm.registry import backends
689+
654690
monkeypatch.setitem(backends, "tcp", SlowBackend())
655691

656692
async with Server({}) as server:
657693
await server.listen("tcp://")
658-
pool = await ConnectionPool(limit=2)
659-
660-
async def connect_to_server():
661-
comm = await pool.connect(server.address)
662-
pool.reuse(server.address, comm)
663694

664695
# #tasks > limit
665-
tasks = [asyncio.create_task(connect_to_server()) for _ in range(5)]
666-
while not pool._connecting:
667-
await asyncio.sleep(0.01)
696+
tasks = [asyncio.create_task(connect(server.address)) for _ in range(5)]
697+
698+
while _connecting != len(tasks):
699+
await asyncio.sleep(0.1)
668700

669701
for t in tasks:
670702
t.cancel()
671-
672-
done, _ = await asyncio.wait(tasks)
703+
await asyncio.wait(tasks)
673704
assert all(t.cancelled() for t in tasks)
674705

675706

distributed/tests/test_utils.py

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
TimeoutError,
2828
_maybe_complex,
2929
ensure_bytes,
30+
ensure_cancellation,
3031
ensure_ip,
3132
format_dashboard_link,
3233
get_ip_interface,
@@ -782,3 +783,47 @@ def __repr__(self):
782783
],
783784
}
784785
assert recursive_to_dict(info) == expect
786+
787+
788+
def test_ensure_cancellation():
789+
# Do not use gen_test to allow us to test on CancelledErrors
790+
async def _():
791+
ev = asyncio.Event()
792+
793+
async def f():
794+
await asyncio.sleep(0)
795+
ev.set()
796+
raise ValueError("foo")
797+
798+
async def g():
799+
ev.set()
800+
await asyncio.sleep(1000000)
801+
802+
task = asyncio.create_task(f())
803+
await ev.wait()
804+
await asyncio.sleep(0)
805+
with pytest.raises(ValueError, match="foo"):
806+
await task
807+
ev.clear()
808+
809+
task = asyncio.create_task(ensure_cancellation(f()))
810+
await ev.wait()
811+
await asyncio.sleep(0)
812+
task.cancel()
813+
with pytest.raises(asyncio.CancelledError):
814+
await task
815+
816+
ev.clear()
817+
task = asyncio.create_task(ensure_cancellation(g()))
818+
await ev.wait()
819+
task.cancel()
820+
with pytest.raises(asyncio.CancelledError):
821+
await task
822+
823+
async def h():
824+
await asyncio.sleep(0)
825+
return 1
826+
827+
assert await ensure_cancellation(h()) == 1
828+
829+
asyncio.run(_())

distributed/tests/test_worker.py

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -321,18 +321,17 @@ async def test_worker_port_range(s):
321321
pass
322322

323323

324-
@pytest.mark.slow
325-
@gen_test(timeout=60)
326-
async def test_worker_waits_for_scheduler():
327-
w = Worker("127.0.0.1:8724")
328-
try:
329-
await asyncio.wait_for(w, 3)
330-
except TimeoutError:
331-
pass
332-
else:
333-
assert False
334-
assert w.status not in (Status.closed, Status.running, Status.paused)
335-
await w.close(timeout=0.1)
324+
@pytest.mark.parametrize("connect_timeout", ["1s", "5s"])
325+
@gen_test()
326+
async def test_worker_waits_for_scheduler(connect_timeout):
327+
with dask.config.set({"distributed.comm.timeouts.connect": connect_timeout}):
328+
w = Worker("127.0.0.1:8724")
329+
330+
with pytest.raises(TimeoutError):
331+
await asyncio.wait_for(w, 3)
332+
333+
assert w.status not in (Status.closed, Status.running, Status.paused)
334+
await w.close()
336335

337336

338337
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])

distributed/utils.py

Lines changed: 26 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,9 @@
2727
from hashlib import md5
2828
from importlib.util import cache_from_source
2929
from time import sleep
30-
from types import ModuleType
30+
from types import CoroutineType, ModuleType
3131
from typing import Any as AnyType
32-
from typing import ClassVar
32+
from typing import ClassVar, TypeVar
3333

3434
import click
3535
import tblib.pickling_support
@@ -1621,3 +1621,27 @@ def is_python_shutting_down() -> bool:
16211621
from distributed import _python_shutting_down
16221622

16231623
return _python_shutting_down
1624+
1625+
1626+
T = TypeVar("T")
1627+
1628+
1629+
async def ensure_cancellation(coro: CoroutineType[None, None, T]) -> T:
1630+
"""Ensure that the wrapped coro will raise a CancelledError even if its
1631+
result is already set.
1632+
1633+
See https://github.com/python/cpython/issues/86296
1634+
"""
1635+
watcher = asyncio.Event()
1636+
1637+
task = asyncio.create_task(coro)
1638+
task.add_done_callback(lambda _: watcher.set())
1639+
1640+
try:
1641+
await watcher.wait()
1642+
except asyncio.CancelledError:
1643+
task.cancel()
1644+
await watcher.wait()
1645+
raise
1646+
1647+
return task.result()

0 commit comments

Comments
 (0)