Skip to content

Commit ee8b864

Browse files
committed
fix(aiomysql): avoid wrapping pooled connections multiple times
1 parent fcaeb10 commit ee8b864

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

newrelic/hooks/database_aiomysql.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ async def _wrap_pool__acquire(wrapped, instance, args, kwargs):
7878
with FunctionTrace(name=callable_name(wrapped), terminal=True, rollup=rollup, source=wrapped):
7979
connection = await wrapped(*args, **kwargs)
8080
connection_kwargs = getattr(instance, "_conn_kwargs", {})
81+
82+
if isinstance(connection, AsyncConnectionWrapper):
83+
return connection
84+
8185
return AsyncConnectionWrapper(connection, dbapi2_module, (((), connection_kwargs)))
8286

8387
return _wrap_pool__acquire
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import asyncio
2+
3+
from newrelic.hooks.database_aiomysql import AsyncConnectionWrapper, wrap_pool__acquire
4+
5+
6+
class DummyPool:
7+
_conn_kwargs = {"host": "localhost"}
8+
9+
10+
class DummyDBModule:
11+
_nr_database_product = "MySQL"
12+
13+
14+
def test_pool_acquire_does_not_double_wrap():
15+
async def run():
16+
connection_store = {"value": object()}
17+
18+
async def underlying_acquire(*_args, **_kwargs):
19+
return connection_store["value"]
20+
21+
wrapper = wrap_pool__acquire(DummyDBModule)
22+
pool = DummyPool()
23+
24+
first = await wrapper(underlying_acquire, pool, (), {})
25+
assert isinstance(first, AsyncConnectionWrapper)
26+
inner = first._nr_next_object
27+
assert not isinstance(inner, AsyncConnectionWrapper)
28+
29+
# Simulate connection being returned to the pool. The pool will now hand
30+
# back the already wrapped connection when acquire() is invoked again.
31+
connection_store["value"] = first
32+
33+
second = await wrapper(underlying_acquire, pool, (), {})
34+
assert second is first
35+
assert not isinstance(second._nr_next_object, AsyncConnectionWrapper)
36+
asyncio.run(run())

0 commit comments

Comments
 (0)