Skip to content
77 changes: 74 additions & 3 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Connection(metaclass=ConnectionMeta):
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
'_log_listeners', '_termination_listeners', '_cancellations',
'_source_traceback', '__weakref__')
'_source_traceback', '_query_loggers', '__weakref__')

def __init__(self, protocol, transport, loop,
addr,
Expand Down Expand Up @@ -84,6 +84,7 @@ def __init__(self, protocol, transport, loop,
self._log_listeners = set()
self._cancellations = set()
self._termination_listeners = set()
self._query_loggers = set()

settings = self._protocol.get_settings()
ver_string = settings.server_version
Expand Down Expand Up @@ -221,6 +222,30 @@ def remove_termination_listener(self, callback):
"""
self._termination_listeners.discard(_Callback.from_callable(callback))

def add_query_logger(self, callback):
"""Add a logger that will be called when queries are executed.

:param callable callback:
A callable or a coroutine function receiving two arguments:
**connection**: a Connection the callback is registered with.
**query**: a LoggedQuery containing the query, args, timeout, and
elapsed.

.. versionadded:: 0.28.0
"""
self._query_loggers.add(_Callback.from_callable(callback))

def remove_query_logger(self, callback):
"""Remove a query logger callback.

:param callable callback:
The callable or coroutine function that was passed to
:meth:`Connection.add_query_logger`.

.. versionadded:: 0.28.0
"""
self._query_loggers.discard(_Callback.from_callable(callback))

def get_server_pid(self):
"""Return the PID of the Postgres server the connection is bound to."""
return self._protocol.get_server_pid()
Expand Down Expand Up @@ -314,7 +339,11 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
self._check_open()

if not args:
return await self._protocol.query(query, timeout)
start = time.monotonic()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a context manager would probably be nicer/reduce duplication.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, change incoming.

result = await self._protocol.query(query, timeout)
elapsed = time.monotonic() - start
self._log_query(query, args, timeout, elapsed)
return result

_, status, _ = await self._execute(
query,
Expand Down Expand Up @@ -1667,6 +1696,20 @@ async def _execute(
)
return result

def logger(self, callback):
return _LoggingContext(self, callback)

def _log_query(self, query, args, timeout, elapsed):
if not self._query_loggers:
return
con_ref = self._unwrap()
record = LoggedQuery(query, args, timeout, elapsed)
for cb in self._query_loggers:
if cb.is_async:
self._loop.create_task(cb.cb(con_ref, record))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing a con_ref is probably unnecessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I likely wouldn't use it, so happy to remove it, but I put it there so you could potentially log queries by host.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A concern here is potentially retaining references to free-d connections. Other callbacks take it, of course, but that's an API decision I've come to regret. Perhaps we can pass connection's addr and params instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works for me!

else:
self._loop.call_soon(cb.cb, con_ref, record)

async def __execute(
self,
query,
Expand All @@ -1681,20 +1724,27 @@ async def __execute(
executor = lambda stmt, timeout: self._protocol.bind_execute(
stmt, args, '', limit, return_status, timeout)
timeout = self._protocol._get_timeout(timeout)
return await self._do_execute(
start = time.monotonic()
result, stmt = await self._do_execute(
query,
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
elapsed = time.monotonic() - start
self._log_query(query, args, timeout, elapsed)
return result, stmt

async def _executemany(self, query, args, timeout):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
stmt, args, '', timeout)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
start = time.monotonic()
result, _ = await self._do_execute(query, executor, timeout)
elapsed = time.monotonic() - start
self._log_query(query, args, timeout, elapsed)
return result

async def _do_execute(
Expand Down Expand Up @@ -2323,6 +2373,27 @@ class _ConnectionProxy:
__slots__ = ()


LoggedQuery = collections.namedtuple(
'LoggedQuery',
['query', 'args', 'timeout', 'elapsed'])
LoggedQuery.__doc__ = 'Log record of an executed query.'


class _LoggingContext:
__slots__ = ('_conn', '_cb')

def __init__(self, conn, callback):
self._conn = conn
self._cb = callback

def __enter__(self):
self._conn.add_query_logger(self._cb)
return self

def __exit__(self, *exc_info):
self._conn.remove_query_logger(self._cb)


ServerCapabilities = collections.namedtuple(
'ServerCapabilities',
['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',
Expand Down
20 changes: 20 additions & 0 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import asyncio

from asyncpg import _testbase as tb


class TestQueryLogging(tb.ConnectedTestCase):

async def test_logging_context(self):
queries = asyncio.Queue()

def query_saver(conn, record):
queries.put_nowait(record)

with self.con.logger(query_saver):
self.assertEqual(len(self.con._query_loggers), 1)
await self.con.execute("SELECT 1")

record = await queries.get()
self.assertEqual(record.query, "SELECT 1")
self.assertEqual(len(self.con._query_loggers), 0)