Skip to content

Commit 2931a1a

Browse files
committed
fix(pool): manage CancelledError in some exception handling path
If a CancelledError was raised during check the connection would have been lost. The exception would have bubbled up but likely users are using some framework swallowing it because nobody reporting the "lost connections" issue actually reported the CancelledError. Close #1123 Close #1208
1 parent 4f6d792 commit 2931a1a

File tree

5 files changed

+78
-16
lines changed

5 files changed

+78
-16
lines changed

docs/news_pool.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@
77
``psycopg_pool`` release notes
88
==============================
99

10+
Future releases
11+
---------------
12+
13+
psycopg_pool 3.2.8 (unreleased)
14+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
15+
16+
- Don't lose connections if a `~asyncio.CancelledError` is raised in a check
17+
(:tickets:`#1123, #1208`)
18+
19+
1020
Current release
1121
---------------
1222

psycopg_pool/psycopg_pool/pool.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from time import monotonic
1616
from types import TracebackType
1717
from typing import Any, Generic, Iterator, cast
18+
from asyncio import CancelledError
1819
from weakref import ref
1920
from contextlib import contextmanager
2021

@@ -209,7 +210,7 @@ def _getconn_with_check_loop(self, deadline: float) -> CT:
209210
conn = self._getconn_unchecked(deadline - monotonic())
210211
try:
211212
self._check_connection(conn)
212-
except Exception:
213+
except (Exception, CancelledError):
213214
self._putconn(conn, from_getconn=True)
214215
else:
215216
logger.info("connection given by %r", self.name)
@@ -247,7 +248,7 @@ def _getconn_unchecked(self, timeout: float) -> CT:
247248
if not conn:
248249
try:
249250
conn = pos.wait(timeout=timeout)
250-
except Exception:
251+
except BaseException:
251252
self._stats[self._REQUESTS_ERRORS] += 1
252253
raise
253254
finally:
@@ -283,7 +284,7 @@ def _check_connection(self, conn: CT) -> None:
283284
return
284285
try:
285286
self._check(conn)
286-
except Exception as e:
287+
except BaseException as e:
287288
logger.info("connection failed check: %s", e)
288289
raise
289290

@@ -515,7 +516,7 @@ def check(self) -> None:
515516
# Check for broken connections
516517
try:
517518
self.check_connection(conn)
518-
except Exception:
519+
except (Exception, CancelledError):
519520
self._stats[self._CONNECTIONS_LOST] += 1
520521
logger.warning("discarding broken connection: %s", conn)
521522
self.run_task(AddConnection(self))
@@ -576,7 +577,7 @@ def worker(cls, q: Queue[MaintenanceTask]) -> None:
576577
# Run the task. Make sure don't die in the attempt.
577578
try:
578579
task.run()
579-
except Exception as ex:
580+
except (Exception, CancelledError) as ex:
580581
logger.warning(
581582
"task run %s failed: %s: %s", task, ex.__class__.__name__, ex
582583
)
@@ -591,7 +592,7 @@ def _connect(self, timeout: float | None = None) -> CT:
591592
t0 = monotonic()
592593
try:
593594
conn = self.connection_class.connect(self.conninfo, **kwargs)
594-
except Exception:
595+
except (Exception, CancelledError):
595596
self._stats[self._CONNECTIONS_ERRORS] += 1
596597
raise
597598
else:
@@ -628,7 +629,7 @@ def _add_connection(
628629

629630
try:
630631
conn = self._connect()
631-
except Exception as ex:
632+
except (Exception, CancelledError) as ex:
632633
logger.warning("error connecting in %r: %s", self.name, ex)
633634
if attempt.time_to_give_up(now):
634635
logger.warning(
@@ -774,7 +775,7 @@ def _reset_connection(self, conn: CT) -> None:
774775
raise e.ProgrammingError(
775776
f"connection left in status {sname} by reset function {self._reset}: discarded"
776777
)
777-
except Exception as ex:
778+
except (Exception, CancelledError) as ex:
778779
logger.warning("error resetting connection: %s", ex)
779780
self._close_connection(conn)
780781

psycopg_pool/psycopg_pool/pool_async.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from time import monotonic
1313
from types import TracebackType
1414
from typing import Any, AsyncIterator, Generic, cast
15+
from asyncio import CancelledError
1516
from weakref import ref
1617
from contextlib import asynccontextmanager
1718

@@ -237,7 +238,7 @@ async def _getconn_with_check_loop(self, deadline: float) -> ACT:
237238
conn = await self._getconn_unchecked(deadline - monotonic())
238239
try:
239240
await self._check_connection(conn)
240-
except Exception:
241+
except (Exception, CancelledError):
241242
await self._putconn(conn, from_getconn=True)
242243
else:
243244
logger.info("connection given by %r", self.name)
@@ -275,7 +276,7 @@ async def _getconn_unchecked(self, timeout: float) -> ACT:
275276
if not conn:
276277
try:
277278
conn = await pos.wait(timeout=timeout)
278-
except Exception:
279+
except BaseException:
279280
self._stats[self._REQUESTS_ERRORS] += 1
280281
raise
281282
finally:
@@ -312,7 +313,7 @@ async def _check_connection(self, conn: ACT) -> None:
312313
return
313314
try:
314315
await self._check(conn)
315-
except Exception as e:
316+
except BaseException as e:
316317
logger.info("connection failed check: %s", e)
317318
raise
318319

@@ -551,7 +552,7 @@ async def check(self) -> None:
551552
# Check for broken connections
552553
try:
553554
await self.check_connection(conn)
554-
except Exception:
555+
except (Exception, CancelledError):
555556
self._stats[self._CONNECTIONS_LOST] += 1
556557
logger.warning("discarding broken connection: %s", conn)
557558
self.run_task(AddConnection(self))
@@ -630,7 +631,7 @@ async def worker(cls, q: AQueue[MaintenanceTask]) -> None:
630631
# Run the task. Make sure don't die in the attempt.
631632
try:
632633
await task.run()
633-
except Exception as ex:
634+
except (Exception, CancelledError) as ex:
634635
logger.warning(
635636
"task run %s failed: %s: %s", task, ex.__class__.__name__, ex
636637
)
@@ -645,7 +646,7 @@ async def _connect(self, timeout: float | None = None) -> ACT:
645646
t0 = monotonic()
646647
try:
647648
conn = await self.connection_class.connect(self.conninfo, **kwargs)
648-
except Exception:
649+
except (Exception, CancelledError):
649650
self._stats[self._CONNECTIONS_ERRORS] += 1
650651
raise
651652
else:
@@ -683,7 +684,7 @@ async def _add_connection(
683684

684685
try:
685686
conn = await self._connect()
686-
except Exception as ex:
687+
except (Exception, CancelledError) as ex:
687688
logger.warning("error connecting in %r: %s", self.name, ex)
688689
if attempt.time_to_give_up(now):
689690
logger.warning(
@@ -835,7 +836,7 @@ async def _reset_connection(self, conn: ACT) -> None:
835836
f"connection left in status {sname} by reset function"
836837
f" {self._reset}: discarded"
837838
)
838-
except Exception as ex:
839+
except (Exception, CancelledError) as ex:
839840
logger.warning("error resetting connection: %s", ex)
840841
await self._close_connection(conn)
841842

tests/pool/test_pool_common.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import logging
77
from time import time
88
from typing import Any
9+
from asyncio import CancelledError
910

1011
import pytest
1112

@@ -694,6 +695,29 @@ def worker(i):
694695
assert cur.fetchone() == (1,)
695696

696697

698+
@skip_sync
699+
def test_cancel_on_check(pool_cls, dsn):
700+
do_cancel = True
701+
702+
def check(conn):
703+
nonlocal do_cancel
704+
if do_cancel:
705+
do_cancel = False
706+
raise CancelledError()
707+
708+
pool_cls.check_connection(conn)
709+
710+
with pool_cls(dsn, min_size=min_size(pool_cls, 1), check=check, timeout=1.0) as p:
711+
try:
712+
with p.connection() as conn:
713+
conn.execute("select 1")
714+
except CancelledError:
715+
pass
716+
717+
with p.connection() as conn:
718+
conn.execute("select 1")
719+
720+
697721
def min_size(pool_cls, num=1):
698722
"""Return the minimum min_size supported by the pool class."""
699723
if pool_cls is pool.ConnectionPool:

tests/pool/test_pool_common_async.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
from time import time
55
from typing import Any
6+
from asyncio import CancelledError
67

78
import pytest
89

@@ -705,6 +706,31 @@ async def worker(i):
705706
assert await cur.fetchone() == (1,)
706707

707708

709+
@skip_sync
710+
async def test_cancel_on_check(pool_cls, dsn):
711+
do_cancel = True
712+
713+
async def check(conn):
714+
nonlocal do_cancel
715+
if do_cancel:
716+
do_cancel = False
717+
raise CancelledError()
718+
719+
await pool_cls.check_connection(conn)
720+
721+
async with pool_cls(
722+
dsn, min_size=min_size(pool_cls, 1), check=check, timeout=1.0
723+
) as p:
724+
try:
725+
async with p.connection() as conn:
726+
await conn.execute("select 1")
727+
except CancelledError:
728+
pass
729+
730+
async with p.connection() as conn:
731+
await conn.execute("select 1")
732+
733+
708734
def min_size(pool_cls, num=1):
709735
"""Return the minimum min_size supported by the pool class."""
710736
if pool_cls is pool.AsyncConnectionPool:

0 commit comments

Comments
 (0)