Skip to content

Commit 101a599

Browse files
committed
Make Pool.close() wait until all checked out connections are released
Currently, `pool.close()`, despite the "graceful" designation, closes all connections immediately regardless of whether they are acquired. With this change, pool will wait for connections to actually be released before closing. WARNING: This is a potentially incompatible behavior change, as sloppily written code which does not release acquired connections will now cause `pool.close()` to hang forever. Also, when `conn.close()` or `conn.terminate()` are called directly on an acquired connection, the associated pool item is released immediately. Closes: #290
1 parent cf523be commit 101a599

File tree

5 files changed

+206
-94
lines changed

5 files changed

+206
-94
lines changed

asyncpg/connection.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -995,30 +995,21 @@ async def close(self, *, timeout=None):
995995
.. versionchanged:: 0.14.0
996996
Added the *timeout* parameter.
997997
"""
998-
if self.is_closed():
999-
return
1000-
self._mark_stmts_as_closed()
1001-
self._listeners.clear()
1002-
self._log_listeners.clear()
1003-
self._aborted = True
1004998
try:
1005-
await self._protocol.close(timeout)
999+
if not self.is_closed():
1000+
await self._protocol.close(timeout)
10061001
except Exception:
10071002
# If we fail to close gracefully, abort the connection.
1008-
self._aborted = True
1009-
self._protocol.abort()
1003+
self._abort()
10101004
raise
10111005
finally:
1012-
self._clean_tasks()
1006+
self._cleanup()
10131007

10141008
def terminate(self):
10151009
"""Terminate the connection without waiting for pending data."""
1016-
self._mark_stmts_as_closed()
1017-
self._listeners.clear()
1018-
self._log_listeners.clear()
1019-
self._aborted = True
1020-
self._protocol.abort()
1021-
self._clean_tasks()
1010+
if not self.is_closed():
1011+
self._abort()
1012+
self._cleanup()
10221013

10231014
async def reset(self, *, timeout=None):
10241015
self._check_open()
@@ -1041,6 +1032,23 @@ async def reset(self, *, timeout=None):
10411032
if reset_query:
10421033
await self.execute(reset_query, timeout=timeout)
10431034

1035+
def _abort(self):
1036+
# Put the connection into the aborted state.
1037+
self._aborted = True
1038+
self._protocol.abort()
1039+
1040+
def _cleanup(self):
1041+
# Free the resources associated with this connection.
1042+
# This must be called when a connection is terminated.
1043+
self._mark_stmts_as_closed()
1044+
self._listeners.clear()
1045+
self._log_listeners.clear()
1046+
self._clean_tasks()
1047+
if self._proxy is not None:
1048+
# Connection is a member of a pool, but is getting
1049+
# aborted directly. Notify the pool about the fact.
1050+
self._proxy._holder._release_on_close()
1051+
10441052
def _clean_tasks(self):
10451053
# Wrap-up any remaining tasks associated with this connection.
10461054
if self._cancellations:

asyncpg/pool.py

Lines changed: 100 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __repr__(self):
9292

9393
class PoolConnectionHolder:
9494

95-
__slots__ = ('_con', '_pool', '_loop',
95+
__slots__ = ('_con', '_pool', '_loop', '_proxy',
9696
'_connect_args', '_connect_kwargs',
9797
'_max_queries', '_setup', '_init',
9898
'_max_inactive_time', '_in_use',
@@ -103,6 +103,7 @@ def __init__(self, pool, *, connect_args, connect_kwargs,
103103

104104
self._pool = pool
105105
self._con = None
106+
self._proxy = None
106107

107108
self._connect_args = connect_args
108109
self._connect_kwargs = connect_kwargs
@@ -111,7 +112,7 @@ def __init__(self, pool, *, connect_args, connect_kwargs,
111112
self._setup = setup
112113
self._init = init
113114
self._inactive_callback = None
114-
self._in_use = False
115+
self._in_use = None # type: asyncio.Future
115116
self._timeout = None
116117

117118
async def connect(self):
@@ -152,7 +153,7 @@ async def acquire(self) -> PoolConnectionProxy:
152153

153154
self._maybe_cancel_inactive_callback()
154155

155-
proxy = PoolConnectionProxy(self, self._con)
156+
self._proxy = proxy = PoolConnectionProxy(self, self._con)
156157

157158
if self._setup is not None:
158159
try:
@@ -163,31 +164,29 @@ async def acquire(self) -> PoolConnectionProxy:
163164
# we close it. A new connection will be created
164165
# when `acquire` is called again.
165166
try:
166-
proxy._detach()
167167
# Use `close` to close the connection gracefully.
168168
# An exception in `setup` isn't necessarily caused
169169
# by an IO or a protocol error.
170170
await self._con.close()
171171
finally:
172-
self._con = None
173172
raise ex
174173

175-
self._in_use = True
174+
self._in_use = self._pool._loop.create_future()
175+
176176
return proxy
177177

178178
async def release(self, timeout):
179-
assert self._in_use
180-
self._in_use = False
181-
self._timeout = None
179+
assert self._in_use is not None
182180

183181
if self._con.is_closed():
184-
self._con = None
182+
# When closing, pool connections perform the necessary
183+
# cleanup, so we don't have to do anything else here.
184+
return
185185

186-
elif self._con._protocol.queries_count >= self._max_queries:
187-
try:
188-
await self._con.close(timeout=timeout)
189-
finally:
190-
self._con = None
186+
self._timeout = None
187+
188+
if self._con._protocol.queries_count >= self._max_queries:
189+
await self._con.close(timeout=timeout)
191190

192191
else:
193192
try:
@@ -213,52 +212,71 @@ async def release(self, timeout):
213212
# an IO error, so terminate the connection.
214213
self._con.terminate()
215214
finally:
216-
self._con = None
217215
raise ex
218216

217+
self._release()
218+
219219
assert self._inactive_callback is None
220220
if self._max_inactive_time and self._con is not None:
221221
self._inactive_callback = self._pool._loop.call_later(
222222
self._max_inactive_time, self._deactivate_connection)
223223

224-
async def close(self):
225-
self._maybe_cancel_inactive_callback()
226-
if self._con is None:
227-
return
228-
if self._con.is_closed():
229-
self._con = None
224+
async def wait_until_released(self):
225+
if self._in_use is None:
230226
return
227+
else:
228+
await self._in_use
231229

232-
try:
230+
async def close(self):
231+
if self._con is not None:
232+
# Connection.close() will call _release_on_close() to
233+
# finish holder cleanup.
233234
await self._con.close()
234-
finally:
235-
self._con = None
236235

237236
def terminate(self):
238-
self._maybe_cancel_inactive_callback()
239-
if self._con is None:
240-
return
241-
if self._con.is_closed():
242-
self._con = None
243-
return
244-
245-
try:
237+
if self._con is not None:
238+
# Connection.terminate() will call _release_on_close() to
239+
# finish holder cleanup.
246240
self._con.terminate()
247-
finally:
248-
self._con = None
249241

250242
def _maybe_cancel_inactive_callback(self):
251243
if self._inactive_callback is not None:
252244
self._inactive_callback.cancel()
253245
self._inactive_callback = None
254246

255247
def _deactivate_connection(self):
256-
assert not self._in_use
257-
if self._con is None or self._con.is_closed():
258-
return
259-
self._con.terminate()
248+
assert self._in_use is None
249+
if self._con is not None:
250+
self._con.terminate()
251+
# Must call clear_connection, because _deactivate_connection
252+
# is called when the connection is *not* checked out, and
253+
# so terminate() above will not call the below.
254+
self._release_on_close()
255+
256+
def _release_on_close(self):
257+
self._maybe_cancel_inactive_callback()
258+
self._release()
260259
self._con = None
261260

261+
def _release(self):
262+
"""Release this connection holder."""
263+
if self._in_use is None:
264+
# The holder is not checked out.
265+
return
266+
267+
if not self._in_use.done():
268+
self._in_use.set_result(None)
269+
self._in_use = None
270+
271+
# Let go of the connection proxy.
272+
if self._proxy is not None:
273+
if self._proxy._con is not None:
274+
self._proxy._detach()
275+
self._proxy = None
276+
277+
# Put ourselves back to the pool queue.
278+
self._pool._queue.put_nowait(self)
279+
262280

263281
class Pool:
264282
"""A connection pool.
@@ -273,7 +291,7 @@ class Pool:
273291

274292
__slots__ = ('_queue', '_loop', '_minsize', '_maxsize',
275293
'_working_addr', '_working_config', '_working_params',
276-
'_holders', '_initialized', '_closed',
294+
'_holders', '_initialized', '_closing', '_closed',
277295
'_connection_class')
278296

279297
def __init__(self, *connect_args,
@@ -322,6 +340,7 @@ def __init__(self, *connect_args,
322340

323341
self._connection_class = connection_class
324342

343+
self._closing = False
325344
self._closed = False
326345

327346
for _ in range(max_size):
@@ -468,7 +487,10 @@ async def _acquire_impl():
468487
ch._timeout = timeout
469488
return proxy
470489

490+
if self._closing:
491+
raise exceptions.InterfaceError('pool is closing')
471492
self._check_init()
493+
472494
if timeout is None:
473495
return await _acquire_impl()
474496
else:
@@ -488,14 +510,6 @@ async def release(self, connection, *, timeout=None):
488510
.. versionchanged:: 0.14.0
489511
Added the *timeout* parameter.
490512
"""
491-
async def _release_impl(ch: PoolConnectionHolder, timeout: float):
492-
try:
493-
await ch.release(timeout)
494-
finally:
495-
self._queue.put_nowait(ch)
496-
497-
self._check_init()
498-
499513
if (type(connection) is not PoolConnectionProxy or
500514
connection._holder._pool is not self):
501515
raise exceptions.InterfaceError(
@@ -507,35 +521,64 @@ async def _release_impl(ch: PoolConnectionHolder, timeout: float):
507521
# Already released, do nothing.
508522
return
509523

510-
con = connection._detach()
524+
self._check_init()
525+
526+
con = connection._con
511527
con._on_release()
528+
ch = connection._holder
512529

513530
if timeout is None:
514-
timeout = connection._holder._timeout
531+
timeout = ch._timeout
515532

516533
# Use asyncio.shield() to guarantee that task cancellation
517534
# does not prevent the connection from being returned to the
518535
# pool properly.
519-
return await asyncio.shield(
520-
_release_impl(connection._holder, timeout), loop=self._loop)
536+
return await asyncio.shield(ch.release(timeout), loop=self._loop)
521537

522538
async def close(self):
523-
"""Gracefully close all connections in the pool."""
539+
"""Attempt to gracefully close all connections in the pool.
540+
541+
Wait until all pool connections are released, close them and
542+
shut down the pool. If any error (including cancellation) occurs
543+
in ``close()`` the pool will terminate by calling
544+
:meth:'Pool.terminate() <pool.Pool.terminate>`.
545+
546+
.. versionchanged:: 0.16.0
547+
``close()`` now waits until all pool connections are released
548+
before closing them and the pool. Errors raised in ``close()``
549+
will cause immediate pool termination.
550+
"""
524551
if self._closed:
525552
return
526553
self._check_init()
527-
self._closed = True
528-
coros = [ch.close() for ch in self._holders]
529-
await asyncio.gather(*coros, loop=self._loop)
554+
555+
self._closing = True
556+
557+
try:
558+
release_coros = [
559+
ch.wait_until_released() for ch in self._holders]
560+
await asyncio.gather(*release_coros, loop=self._loop)
561+
562+
close_coros = [
563+
ch.close() for ch in self._holders]
564+
await asyncio.gather(*close_coros, loop=self._loop)
565+
566+
except Exception:
567+
self.terminate()
568+
raise
569+
570+
finally:
571+
self._closed = True
572+
self._closing = False
530573

531574
def terminate(self):
532575
"""Terminate all connections in the pool."""
533576
if self._closed:
534577
return
535578
self._check_init()
536-
self._closed = True
537579
for ch in self._holders:
538580
ch.terminate()
581+
self._closed = True
539582

540583
def _check_init(self):
541584
if not self._initialized:

tests/test_adversity.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ async def test_pool_release_timeout(self):
3232
self.proxy.trigger_connectivity_loss()
3333
finally:
3434
self.proxy.restore_connectivity()
35-
await pool.close()
35+
pool.terminate()
3636

3737
@tb.with_timeout(30.0)
3838
async def test_pool_handles_abrupt_connection_loss(self):
@@ -57,8 +57,11 @@ def kill_connectivity():
5757
timeout=cmd_timeout, command_timeout=cmd_timeout)
5858

5959
with self.assertRunUnder(worst_runtime):
60-
async with new_pool as pool:
60+
pool = await new_pool
61+
try:
6162
workers = [worker(pool) for _ in range(concurrency)]
6263
self.loop.call_later(1, kill_connectivity)
6364
await asyncio.gather(
6465
*workers, loop=self.loop, return_exceptions=True)
66+
finally:
67+
pool.terminate()

tests/test_cache_invalidation.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ async def test_prepare_cache_invalidation_in_pool(self):
8585

8686
finally:
8787
await self.con.execute('DROP TABLE tab1')
88+
await pool.release(con2)
89+
await pool.release(con1)
8890
await pool.close()
8991

9092
async def test_type_cache_invalidation_in_transaction(self):
@@ -303,6 +305,9 @@ async def test_type_cache_invalidation_in_pool(self):
303305
finally:
304306
await self.con.execute('DROP TABLE tab1')
305307
await self.con.execute('DROP TYPE typ1')
308+
await pool.release(con2)
309+
await pool.release(con1)
306310
await pool.close()
311+
await pool_chk.release(con_chk)
307312
await pool_chk.close()
308313
await self.con.execute('DROP DATABASE testdb')

0 commit comments

Comments
 (0)