Skip to content

Commit 3bf6103

Browse files
committed
pool: Track connections and prohibit using them after release.
Connection pool now wraps all connections in `PooledConnectionProxy` objects to raise `InterfaceError` if they are used after being released back to the pool. We also check if connection passed to `pool.release` actually belong to the pool and correctly handle multiple calls to `pool.release` with the same connection object. `PooledConnectionProxy` transparently wraps Connection instances, exposing all Connection public API. `isinstance(asyncpg.connection.Connection)` is `True` for Instances of `PooledConnectionProxy` class.
1 parent 537c8c9 commit 3bf6103

File tree

4 files changed

+217
-22
lines changed

4 files changed

+217
-22
lines changed

asyncpg/connection.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,22 @@
1414
import urllib.parse
1515

1616
from . import cursor
17+
from . import exceptions
1718
from . import introspection
1819
from . import prepared_stmt
1920
from . import protocol
2021
from . import serverversion
2122
from . import transaction
2223

2324

24-
class Connection:
25+
class ConnectionMeta(type):
26+
27+
def __instancecheck__(cls, instance):
28+
mro = type(instance).__mro__
29+
return Connection in mro or _ConnectionProxy in mro
30+
31+
32+
class Connection(metaclass=ConnectionMeta):
2533
"""A representation of a database session.
2634
2735
Connections are created by calling :func:`~asyncpg.connection.connect`.
@@ -32,7 +40,7 @@ class Connection:
3240
'_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
3341
'_addr', '_opts', '_command_timeout', '_listeners',
3442
'_server_version', '_server_caps', '_intro_query',
35-
'_reset_query')
43+
'_reset_query', '_proxy')
3644

3745
def __init__(self, protocol, transport, loop, addr, opts, *,
3846
statement_cache_size, command_timeout):
@@ -70,6 +78,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
7078
self._intro_query = introspection.INTRO_LOOKUP_TYPES
7179

7280
self._reset_query = None
81+
self._proxy = None
7382

7483
async def add_listener(self, channel, callback):
7584
"""Add a listener for Postgres notifications.
@@ -478,9 +487,18 @@ def _notify(self, pid, channel, payload):
478487
if channel not in self._listeners:
479488
return
480489

490+
if self._proxy is None:
491+
con_ref = self
492+
else:
493+
# `_proxy` is not None when the connection is a member
494+
# of a connection pool. Which means that the user is working
495+
# with a PooledConnectionProxy instance, and expects to see it
496+
# (and not the actual Connection) in their event callbacks.
497+
con_ref = self._proxy
498+
481499
for cb in self._listeners[channel]:
482500
try:
483-
cb(self, pid, channel, payload)
501+
cb(con_ref, pid, channel, payload)
484502
except Exception as ex:
485503
self._loop.call_exception_handler({
486504
'message': 'Unhandled exception in asyncpg notification '
@@ -517,6 +535,14 @@ def _get_reset_query(self):
517535

518536
return _reset_query
519537

538+
def _set_proxy(self, proxy):
539+
if self._proxy is not None and proxy is not None:
540+
# Should not happen unless there is a bug in `Pool`.
541+
raise exceptions.InterfaceError(
542+
'internal asyncpg error: connection is already proxied')
543+
544+
self._proxy = proxy
545+
520546

521547
async def connect(dsn=None, *,
522548
host=None, port=None,
@@ -526,7 +552,7 @@ async def connect(dsn=None, *,
526552
timeout=60,
527553
statement_cache_size=100,
528554
command_timeout=None,
529-
connection_class=Connection,
555+
__connection_class__=Connection,
530556
**opts):
531557
"""A coroutine to establish a connection to a PostgreSQL server.
532558
@@ -564,11 +590,7 @@ async def connect(dsn=None, *,
564590
:param float command_timeout: the default timeout for operations on
565591
this connection (the default is no timeout).
566592
567-
:param builtins.type connection_class: A class used to represent
568-
the connection.
569-
Defaults to :class:`~asyncpg.connection.Connection`.
570-
571-
:return: A *connection_class* instance.
593+
:return: A :class:`~asyncpg.connection.Connection` instance.
572594
573595
Example:
574596
@@ -582,10 +604,6 @@ async def connect(dsn=None, *,
582604
... print(types)
583605
>>> asyncio.get_event_loop().run_until_complete(run())
584606
[<Record typname='bool' typnamespace=11 ...
585-
586-
587-
.. versionadded:: 0.10.0
588-
*connection_class* argument.
589607
"""
590608
if loop is None:
591609
loop = asyncio.get_event_loop()
@@ -629,13 +647,18 @@ async def connect(dsn=None, *,
629647
tr.close()
630648
raise
631649

632-
con = connection_class(pr, tr, loop, addr, opts,
633-
statement_cache_size=statement_cache_size,
634-
command_timeout=command_timeout)
650+
con = __connection_class__(pr, tr, loop, addr, opts,
651+
statement_cache_size=statement_cache_size,
652+
command_timeout=command_timeout)
635653
pr.set_connection(con)
636654
return con
637655

638656

657+
class _ConnectionProxy:
658+
# Base class to enable `isinstance(Connection)` check.
659+
__slots__ = ()
660+
661+
639662
def _parse_connect_params(*, dsn, host, port, user,
640663
password, database, opts):
641664

asyncpg/pool.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,80 @@
66

77

88
import asyncio
9+
import functools
910

1011
from . import connection
1112
from . import exceptions
1213

1314

15+
class PooledConnectionProxyMeta(type):
16+
17+
def __new__(mcls, name, bases, dct, *, wrap=False):
18+
if wrap:
19+
def get_wrapper(methname):
20+
meth = getattr(connection.Connection, methname)
21+
22+
def wrapper(self, *args, **kwargs):
23+
return self._dispatch(meth, args, kwargs)
24+
25+
return wrapper
26+
27+
for attrname in dir(connection.Connection):
28+
if attrname.startswith('_') or attrname in dct:
29+
continue
30+
wrapper = get_wrapper(attrname)
31+
wrapper = functools.update_wrapper(
32+
wrapper, getattr(connection.Connection, attrname))
33+
dct[attrname] = wrapper
34+
35+
if '__doc__' not in dct:
36+
dct['__doc__'] = connection.Connection.__doc__
37+
38+
return super().__new__(mcls, name, bases, dct)
39+
40+
def __init__(cls, name, bases, dct, *, wrap=False):
41+
# Needed for Python 3.5 to handle `wrap` class keyword argument.
42+
super().__init__(name, bases, dct)
43+
44+
45+
class PooledConnectionProxy(connection._ConnectionProxy,
46+
metaclass=PooledConnectionProxyMeta,
47+
wrap=True):
48+
49+
__slots__ = ('_con', '_owner')
50+
51+
def __init__(self, owner: 'Pool', con: connection.Connection):
52+
self._con = con
53+
self._owner = owner
54+
con._set_proxy(self)
55+
56+
def _unwrap(self) -> connection.Connection:
57+
if self._con is None:
58+
raise exceptions.InterfaceError(
59+
'internal asyncpg error: cannot unwrap pooled connection')
60+
61+
con, self._con = self._con, None
62+
con._set_proxy(None)
63+
return con
64+
65+
def _dispatch(self, meth, args, kwargs):
66+
if self._con is None:
67+
raise exceptions.InterfaceError(
68+
'cannot call Connection.{}(): '
69+
'connection has been released back to the pool'.format(
70+
meth.__name__))
71+
72+
return meth(self._con, *args, **kwargs)
73+
74+
def __repr__(self):
75+
if self._con is None:
76+
return '<{classname} [released] {id:#x}>'.format(
77+
classname=self.__class__.__name__, id=id(self))
78+
else:
79+
return '<{classname} {con!r} {id:#x}>'.format(
80+
classname=self.__class__.__name__, con=self._con, id=id(self))
81+
82+
1483
class Pool:
1584
"""A connection pool.
1685
@@ -168,6 +237,8 @@ async def _acquire_impl(self):
168237
else:
169238
con = await self._queue.get()
170239

240+
con = PooledConnectionProxy(self, con)
241+
171242
if self._setup is not None:
172243
try:
173244
await self._setup(con)
@@ -179,6 +250,20 @@ async def _acquire_impl(self):
179250

180251
async def release(self, connection):
181252
"""Release a database connection back to the pool."""
253+
254+
if (connection.__class__ is not PooledConnectionProxy or
255+
connection._owner is not self):
256+
raise exceptions.InterfaceError(
257+
'Pool.release() received invalid connection: '
258+
'{connection!r} is not a member of this pool'.format(
259+
connection=connection))
260+
261+
if connection._con is None:
262+
# Already released, do nothing.
263+
return
264+
265+
connection = connection._unwrap()
266+
182267
# Use asyncio.shield() to guarantee that task cancellation
183268
# does not prevent the connection from being returned to the
184269
# pool properly.
@@ -325,6 +410,10 @@ def create_pool(dsn=None, *,
325410
:param loop: An asyncio event loop instance. If ``None``, the default
326411
event loop will be used.
327412
:return: An instance of :class:`~asyncpg.pool.Pool`.
413+
414+
.. versionchanged:: 0.10.0
415+
An :exc:`~asyncpg.exceptions.InterfaceError` will be raised on any
416+
attempted operation on a released connection.
328417
"""
329418
return Pool(dsn,
330419
min_size=min_size, max_size=max_size,

tests/test_connect.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import asyncpg
1515
from asyncpg import _testbase as tb
16-
from asyncpg.connection import _parse_connect_params
16+
from asyncpg import connection
1717
from asyncpg.serverversion import split_server_version_string
1818

1919
_system = platform.uname().system
@@ -355,7 +355,7 @@ def run_testcase(self, testcase):
355355
if expected_error:
356356
es.enter_context(self.assertRaisesRegex(*expected_error))
357357

358-
result = _parse_connect_params(
358+
result = connection._parse_connect_params(
359359
dsn=dsn, host=host, port=port, user=user, password=password,
360360
database=database, opts=opts)
361361

@@ -411,3 +411,11 @@ def test_test_connect_params_run_testcase(self):
411411
def test_connect_params(self):
412412
for testcase in self.TESTS:
413413
self.run_testcase(testcase)
414+
415+
416+
class TestConnection(tb.ConnectedTestCase):
417+
418+
async def test_connection_isinstance(self):
419+
self.assertTrue(isinstance(self.con, connection.Connection))
420+
self.assertTrue(isinstance(self.con, object))
421+
self.assertFalse(isinstance(self.con, list))

0 commit comments

Comments
 (0)