From 5eecaf9b9815c5faedf65d43fe525efdc36d7441 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Thu, 1 Feb 2018 08:52:41 +0100 Subject: [PATCH] Use custom classes to pass client signals parameters #2686 (#2699) * Use custom classes to pass client signals parameters #2686 This will allow aiohttp development add new parameters in the future without break the signals signature. * Fixed tracing uts * Fix import order * Improved documenatation --- CHANGES/2686.feature | 1 + aiohttp/connector.py | 14 +-- aiohttp/tracing.py | 164 +++++++++++++++++++++++++--------- docs/client_advanced.rst | 10 +-- docs/tracing_reference.rst | 167 +++++++++++++++++++++++++++++------ tests/test_client_session.py | 35 ++++---- tests/test_connector.py | 39 +++++--- tests/test_tracing.py | 101 +++++++++++++++++---- 8 files changed, 405 insertions(+), 126 deletions(-) create mode 100644 CHANGES/2686.feature diff --git a/CHANGES/2686.feature b/CHANGES/2686.feature new file mode 100644 index 00000000000..eab2eb89b6f --- /dev/null +++ b/CHANGES/2686.feature @@ -0,0 +1 @@ +Use custom classes to pass client signals parameters diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 7eed3168eca..acb898f7cd1 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -660,14 +660,14 @@ async def _resolve_host(self, host, port, traces=None): if traces: for trace in traces: - await trace.send_dns_resolvehost_start() + await trace.send_dns_resolvehost_start(host) res = (await self._resolver.resolve( host, port, family=self._family)) if traces: for trace in traces: - await trace.send_dns_resolvehost_end() + await trace.send_dns_resolvehost_end(host) return res @@ -678,26 +678,26 @@ async def _resolve_host(self, host, port, traces=None): if traces: for trace in traces: - await trace.send_dns_cache_hit() + await trace.send_dns_cache_hit(host) return self._cached_hosts.next_addrs(key) if key in self._throttle_dns_events: if traces: for trace in traces: - await trace.send_dns_cache_hit() + await trace.send_dns_cache_hit(host) await self._throttle_dns_events[key].wait() else: if traces: for trace in traces: - await trace.send_dns_cache_miss() + await trace.send_dns_cache_miss(host) self._throttle_dns_events[key] = \ EventResultOrError(self._loop) try: if traces: for trace in traces: - await trace.send_dns_resolvehost_start() + await trace.send_dns_resolvehost_start(host) addrs = await \ asyncio.shield(self._resolver.resolve(host, @@ -706,7 +706,7 @@ async def _resolve_host(self, host, port, traces=None): loop=self._loop) if traces: for trace in traces: - await trace.send_dns_resolvehost_end() + await trace.send_dns_resolvehost_end(host) self._cached_hosts.add(key, addrs) self._throttle_dns_events[key].set() diff --git a/aiohttp/tracing.py b/aiohttp/tracing.py index cbfd1597352..507eeb37995 100644 --- a/aiohttp/tracing.py +++ b/aiohttp/tracing.py @@ -1,9 +1,22 @@ from types import SimpleNamespace +import attr +from multidict import CIMultiDict +from yarl import URL + +from .client_reqrep import ClientResponse from .signals import Signal -__all__ = ('TraceConfig',) +__all__ = ( + 'TraceConfig', 'TraceRequestStartParams', 'TraceRequestEndParams', + 'TraceRequestExceptionParams', 'TraceConnectionQueuedStartParams', + 'TraceConnectionQueuedEndParams', 'TraceConnectionCreateStartParams', + 'TraceConnectionCreateEndParams', 'TraceConnectionReuseconnParams', + 'TraceDnsResolveHostStartParams', 'TraceDnsResolveHostEndParams', + 'TraceDnsCacheHitParams', 'TraceDnsCacheMissParams', + 'TraceRequestRedirectParams' +) class TraceConfig: @@ -100,6 +113,90 @@ def on_dns_cache_miss(self): return self._on_dns_cache_miss +@attr.s(frozen=True, slots=True) +class TraceRequestStartParams: + """ Parameters sent by the `on_request_start` signal""" + method = attr.ib(type=str) + url = attr.ib(type=URL) + headers = attr.ib(type=CIMultiDict) + + +@attr.s(frozen=True, slots=True) +class TraceRequestEndParams: + """ Parameters sent by the `on_request_end` signal""" + method = attr.ib(type=str) + url = attr.ib(type=URL) + headers = attr.ib(type=CIMultiDict) + resp = attr.ib(type=ClientResponse) + + +@attr.s(frozen=True, slots=True) +class TraceRequestExceptionParams: + """ Parameters sent by the `on_request_exception` signal""" + method = attr.ib(type=str) + url = attr.ib(type=URL) + headers = attr.ib(type=CIMultiDict) + exception = attr.ib(type=Exception) + + +@attr.s(frozen=True, slots=True) +class TraceRequestRedirectParams: + """ Parameters sent by the `on_request_redirect` signal""" + method = attr.ib(type=str) + url = attr.ib(type=URL) + headers = attr.ib(type=CIMultiDict) + resp = attr.ib(type=ClientResponse) + + +@attr.s(frozen=True, slots=True) +class TraceConnectionQueuedStartParams: + """ Parameters sent by the `on_connection_queued_start` signal""" + + +@attr.s(frozen=True, slots=True) +class TraceConnectionQueuedEndParams: + """ Parameters sent by the `on_connection_queued_end` signal""" + + +@attr.s(frozen=True, slots=True) +class TraceConnectionCreateStartParams: + """ Parameters sent by the `on_connection_create_start` signal""" + + +@attr.s(frozen=True, slots=True) +class TraceConnectionCreateEndParams: + """ Parameters sent by the `on_connection_create_end` signal""" + + +@attr.s(frozen=True, slots=True) +class TraceConnectionReuseconnParams: + """ Parameters sent by the `on_connection_reuseconn` signal""" + + +@attr.s(frozen=True, slots=True) +class TraceDnsResolveHostStartParams: + """ Parameters sent by the `on_dns_resolvehost_start` signal""" + host = attr.ib(type=str) + + +@attr.s(frozen=True, slots=True) +class TraceDnsResolveHostEndParams: + """ Parameters sent by the `on_dns_resolvehost_end` signal""" + host = attr.ib(type=str) + + +@attr.s(frozen=True, slots=True) +class TraceDnsCacheHitParams: + """ Parameters sent by the `on_dns_cache_hit` signal""" + host = attr.ib(type=str) + + +@attr.s(frozen=True, slots=True) +class TraceDnsCacheMissParams: + """ Parameters sent by the `on_dns_cache_miss` signal""" + host = attr.ib(type=str) + + class Trace: """ Internal class used to keep together the main dependencies used at the moment of send a signal.""" @@ -109,106 +206,93 @@ def __init__(self, session, trace_config, trace_config_ctx): self._trace_config_ctx = trace_config_ctx self._session = session - async def send_request_start(self, *args, **kwargs): + async def send_request_start(self, method, url, headers): return await self._trace_config.on_request_start.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceRequestStartParams(method, url, headers) ) - async def send_request_end(self, *args, **kwargs): + async def send_request_end(self, method, url, headers, response): return await self._trace_config.on_request_end.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceRequestEndParams(method, url, headers, response) ) - async def send_request_exception(self, *args, **kwargs): + async def send_request_exception(self, method, url, headers, exception): return await self._trace_config.on_request_exception.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceRequestExceptionParams(method, url, headers, exception) ) - async def send_request_redirect(self, *args, **kwargs): + async def send_request_redirect(self, method, url, headers, response): return await self._trace_config._on_request_redirect.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceRequestRedirectParams(method, url, headers, response) ) - async def send_connection_queued_start(self, *args, **kwargs): + async def send_connection_queued_start(self): return await self._trace_config.on_connection_queued_start.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceConnectionQueuedStartParams() ) - async def send_connection_queued_end(self, *args, **kwargs): + async def send_connection_queued_end(self): return await self._trace_config.on_connection_queued_end.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceConnectionQueuedEndParams() ) - async def send_connection_create_start(self, *args, **kwargs): + async def send_connection_create_start(self): return await self._trace_config.on_connection_create_start.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceConnectionCreateStartParams() ) - async def send_connection_create_end(self, *args, **kwargs): + async def send_connection_create_end(self): return await self._trace_config.on_connection_create_end.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceConnectionCreateEndParams() ) - async def send_connection_reuseconn(self, *args, **kwargs): + async def send_connection_reuseconn(self): return await self._trace_config.on_connection_reuseconn.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceConnectionReuseconnParams() ) - async def send_dns_resolvehost_start(self, *args, **kwargs): + async def send_dns_resolvehost_start(self, host): return await self._trace_config.on_dns_resolvehost_start.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceDnsResolveHostStartParams(host) ) - async def send_dns_resolvehost_end(self, *args, **kwargs): + async def send_dns_resolvehost_end(self, host): return await self._trace_config.on_dns_resolvehost_end.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceDnsResolveHostEndParams(host) ) - async def send_dns_cache_hit(self, *args, **kwargs): + async def send_dns_cache_hit(self, host): return await self._trace_config.on_dns_cache_hit.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceDnsCacheHitParams(host) ) - async def send_dns_cache_miss(self, *args, **kwargs): + async def send_dns_cache_miss(self, host): return await self._trace_config.on_dns_cache_miss.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceDnsCacheMissParams(host) ) diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 2c4b19624ee..135877252a8 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -225,10 +225,10 @@ disabled. The following snippet shows how the start and the end signals of a request flow can be followed:: async def on_request_start( - session, trace_config_ctx, method, host, port, headers): + session, trace_config_ctx, params): print("Starting request") - async def on_request_end(session, trace_config_ctx, resp): + async def on_request_end(session, trace_config_ctx, params): print("Ending request") trace_config = aiohttp.TraceConfig() @@ -259,10 +259,10 @@ share the state through to the different signals that belong to the same request and to the same :class:`TraceConfig` class, perhaps:: async def on_request_start( - session, trace_config_ctx, method, host, port, headers): + session, trace_config_ctx, params): trace_config_ctx.start = session.loop.time() - async def on_request_end(session, trace_config_ctx, resp): + async def on_request_end(session, trace_config_ctx, params): elapsed = session.loop.time() - trace_config_ctx.start print("Request took {}".format(elapsed)) @@ -280,7 +280,7 @@ factory. This param is useful to pass data that is only available at request time, perhaps:: async def on_request_start( - session, trace_config_ctx, method, host, port, headers): + session, trace_config_ctx, params): print(trace_config_ctx.trace_request_ctx) diff --git a/docs/tracing_reference.rst b/docs/tracing_reference.rst index 71689bde02f..e370b9dd1e7 100644 --- a/docs/tracing_reference.rst +++ b/docs/tracing_reference.rst @@ -29,10 +29,10 @@ the request flow. .. attribute:: on_request_start Property that gives access to the signals that will be executed when a - request starts, based on the :class:`~signals.Signal` implementation. + request starts, based on the :class:`aiohttp.signals.Signal` implementation. - The coroutines listening will receive as a param the ``session``, - ``trace_config_ctx``, ``method``, ``url`` and ``headers``. + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceRequestStartParams` instance .. versionadded:: 3.0 @@ -41,8 +41,8 @@ the request flow. Property that gives access to the signals that will be executed when a redirect happens during a request flow. - The coroutines that are listening will receive the ``session``, - ``trace_config_ctx``, ``method``, ``url``, ``headers`` and ``resp`` params. + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceRequestRedirectParams` instance .. versionadded:: 3.0 @@ -51,8 +51,8 @@ the request flow. Property that gives access to the signals that will be executed when a request ends. - The coroutines that are listening will receive the ``session``, - ``trace_config_ctx``, ``method``, ``url``, ``headers`` and ``resp`` params + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceRequestEndParams` instance .. versionadded:: 3.0 @@ -61,8 +61,8 @@ the request flow. Property that gives access to the signals that will be executed when a request finishes with an exception. - The coroutines listening will receive the ``session``, - ``trace_config_ctx``, ``method``, ``url``, ``headers`` and ``exception`` params. + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceRequestExceptionParams` instance .. versionadded:: 3.0 @@ -71,8 +71,8 @@ the request flow. Property that gives access to the signals that will be executed when a request has been queued waiting for an available connection. - The coroutines that are listening will receive the ``session`` and - ``trace_config_ctx`` params. + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceConnectionQueuedStartParams` instance .. versionadded:: 3.0 @@ -81,8 +81,8 @@ the request flow. Property that gives access to the signals that will be executed when a request that was queued already has an available connection. - The coroutines that are listening will receive the ``session`` and - ``trace_config_ctx`` params. + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceConnectionQueuedEndParams` instance .. versionadded:: 3.0 @@ -91,8 +91,8 @@ the request flow. Property that gives access to the signals that will be executed when a request creates a new connection. - The coroutines listening will receive the ``session`` and - ``trace_config_ctx`` params. + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceConnectionCreateStartParams` instance .. versionadded:: 3.0 @@ -101,8 +101,8 @@ the request flow. Property that gives access to the signals that will be executed when a request that created a new connection finishes its creation. - The coroutines listening will receive the ``session`` and - ``trace_config_ctx`` params. + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceConnectionCreateEndParams` instance .. versionadded:: 3.0 @@ -111,8 +111,8 @@ the request flow. Property that gives access to the signals that will be executed when a request reuses a connection. - The coroutines listening will receive the ``session`` and - ``trace_config_ctx`` params. + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceConnectionReuseconnParams` instance .. versionadded:: 3.0 @@ -121,8 +121,8 @@ the request flow. Property that gives access to the signals that will be executed when a request starts to resolve the domain related with the request. - The coroutines listening will receive the ``session`` and - ``trace_config_ctx`` params. + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceDnsResolveHostStartParams` instance .. versionadded:: 3.0 @@ -131,8 +131,8 @@ the request flow. Property that gives access to the signals that will be executed when a request finishes to resolve the domain related with the request. - The coroutines listening will receive the ``session`` and ``trace_config_ctx`` - params. + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceDnsResolveHostEndParams` instance .. versionadded:: 3.0 @@ -142,8 +142,8 @@ the request flow. request was able to use a cached DNS resolution for the domain related with the request. - The coroutines listening will receive the ``session`` and - ``trace_config_ctx`` params. + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceDnsCacheHitParams` instance .. versionadded:: 3.0 @@ -153,7 +153,120 @@ the request flow. request was not able to use a cached DNS resolution for the domain related with the request. - The coroutines listening will receive the ``session`` and - ``trace_config_ctx`` params. + The signal handler signature is ``async def on_request_start(session, context, params): ...`` + where ``params`` is :class:`aiohttp.TraceDnsCacheMissParams` instance .. versionadded:: 3.0 + +.. class:: TraceRequestStartParams + + .. attribute:: method + + Method that will be used to make the request. + + .. attribute:: url + + URL that will be used for the request. + + .. attribute:: headers + + Headers that will be used for the request, can be mutated. + +.. class:: TraceRequestEndParams + + .. attribute:: method + + Method used to make the request. + + .. attribute:: url + + URL used for the request. + + .. attribute:: headers + + Headers used for the request. + + .. attribute:: resp + + Response :class:`ClientReponse`. + + +.. class:: TraceRequestExceptionParams + + .. attribute:: method + + Method used to make the request. + + .. attribute:: url + + URL used for the request. + + .. attribute:: headers + + Headers used for the request. + + .. attribute:: exception + + Exception raised during the request. + +.. class:: TraceRequestRedirectParams + + .. attribute:: method + + Method used to get this redirect request. + + .. attribute:: url + + URL used for this redirect request. + + .. attribute:: headers + + Headers used for this redirect. + + .. attribute:: resp + + Response :class:`ClientReponse` got from the redirect. + +.. class:: TraceConnectionQueuedStartParams + + There are no attributes right now. + +.. class:: TraceConnectionQueuedEndParams + + There are no attributes right now. + +.. class:: TraceConnectionCreateStartParams + + There are no attributes right now. + +.. class:: TraceConnectionCreateEndParams + + There are no attributes right now. + +.. class:: TraceConnectionReuseconnParams + + There are no attributes right now. + +.. class:: TraceDnsResolveHostStartParams + + .. attribute:: Host + + Host that will be resolved. + +.. class:: TraceDnsResolveHostEndParams + + .. attribute:: Host + + Host that has been resolved. + +.. class:: TraceDnsCacheHitParams + + .. attribute:: Host + + Host found in the cache. + +.. class:: TraceDnsCacheMissParams + + .. attribute:: Host + + Host didn't find the cache. diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 3ed641e3d7a..972106779bc 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -479,18 +479,22 @@ async def test_request_tracing(loop): on_request_start.assert_called_once_with( session, trace_config_ctx, - hdrs.METH_GET, - URL("http://example.com"), - CIMultiDict() + aiohttp.TraceRequestStartParams( + hdrs.METH_GET, + URL("http://example.com"), + CIMultiDict() + ) ) on_request_end.assert_called_once_with( session, trace_config_ctx, - hdrs.METH_GET, - URL("http://example.com"), - CIMultiDict(), - resp + aiohttp.TraceRequestEndParams( + hdrs.METH_GET, + URL("http://example.com"), + CIMultiDict(), + resp + ) ) assert not on_request_redirect.called @@ -524,10 +528,12 @@ async def test_request_tracing_exception(loop): on_request_exception.assert_called_once_with( session, mock.ANY, - hdrs.METH_GET, - URL("http://example.com"), - CIMultiDict(), - error + aiohttp.TraceRequestExceptionParams( + hdrs.METH_GET, + URL("http://example.com"), + CIMultiDict(), + error + ) ) assert not on_request_end.called @@ -544,11 +550,8 @@ def __init__(self, *args, **kwargs): async def new_headers( session, trace_config_ctx, - method, - url, - headers, - trace_request_ctx=None): - headers['foo'] = 'bar' + data): + data.headers['foo'] = 'bar' trace_config = aiohttp.TraceConfig() trace_config.on_request_start.append(new_headers) diff --git a/tests/test_connector.py b/tests/test_connector.py index ddfe4871985..f8536de9fce 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -666,14 +666,17 @@ async def test_tcp_connector_dns_tracing(loop, dns_response): on_dns_resolvehost_start.assert_called_once_with( session, trace_config_ctx, + aiohttp.TraceDnsResolveHostStartParams('localhost') ) - on_dns_resolvehost_start.assert_called_once_with( + on_dns_resolvehost_end.assert_called_once_with( session, trace_config_ctx, + aiohttp.TraceDnsResolveHostEndParams('localhost') ) on_dns_cache_miss.assert_called_once_with( session, trace_config_ctx, + aiohttp.TraceDnsCacheMissParams('localhost') ) assert not on_dns_cache_hit.called @@ -685,6 +688,7 @@ async def test_tcp_connector_dns_tracing(loop, dns_response): on_dns_cache_hit.assert_called_once_with( session, trace_config_ctx, + aiohttp.TraceDnsCacheHitParams('localhost') ) @@ -738,21 +742,25 @@ async def test_tcp_connector_dns_tracing_cache_disabled(loop, dns_response): on_dns_resolvehost_start.assert_has_calls([ mock.call( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceDnsResolveHostStartParams('localhost') ), mock.call( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceDnsResolveHostStartParams('localhost') ) ]) on_dns_resolvehost_end.assert_has_calls([ mock.call( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceDnsResolveHostEndParams('localhost') ), mock.call( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceDnsResolveHostEndParams('localhost') ) ]) @@ -793,11 +801,13 @@ async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response): await asyncio.sleep(0, loop=loop) on_dns_cache_hit.assert_called_once_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceDnsCacheHitParams('localhost') ) on_dns_cache_miss.assert_called_once_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceDnsCacheMissParams('localhost') ) @@ -933,11 +943,13 @@ async def test_connect_tracing(loop): await conn.connect(req, traces=traces) on_connection_create_start.assert_called_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceConnectionCreateStartParams() ) on_connection_create_end.assert_called_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceConnectionCreateEndParams() ) @@ -1333,11 +1345,13 @@ async def f(): ) on_connection_queued_start.assert_called_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceConnectionQueuedStartParams() ) on_connection_queued_end.assert_called_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceConnectionQueuedEndParams() ) connection2.release() @@ -1381,7 +1395,8 @@ async def test_connect_reuseconn_tracing(loop, key): on_connection_reuseconn.assert_called_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceConnectionReuseconnParams() ) conn.close() diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 10ea6bbcf11..bc837bf3af1 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -4,7 +4,19 @@ import pytest -from aiohttp.tracing import Trace, TraceConfig +from aiohttp.tracing import (Trace, TraceConfig, + TraceConnectionCreateEndParams, + TraceConnectionCreateStartParams, + TraceConnectionQueuedEndParams, + TraceConnectionQueuedStartParams, + TraceConnectionReuseconnParams, + TraceDnsCacheHitParams, TraceDnsCacheMissParams, + TraceDnsResolveHostEndParams, + TraceDnsResolveHostStartParams, + TraceRequestEndParams, + TraceRequestExceptionParams, + TraceRequestRedirectParams, + TraceRequestStartParams) class TestTraceConfig: @@ -45,23 +57,74 @@ def test_freeze(self): class TestTrace: - @pytest.mark.parametrize('signal', [ - 'request_start', - 'request_end', - 'request_exception', - 'request_redirect', - 'connection_queued_start', - 'connection_queued_end', - 'connection_create_start', - 'connection_create_end', - 'connection_reuseconn', - 'dns_resolvehost_start', - 'dns_resolvehost_end', - 'dns_cache_hit', - 'dns_cache_miss' + @pytest.mark.parametrize('signal,params,param_obj', [ + ( + 'request_start', + (Mock(), Mock(), Mock()), + TraceRequestStartParams + ), + ( + 'request_end', + (Mock(), Mock(), Mock(), Mock()), + TraceRequestEndParams + ), + ( + 'request_exception', + (Mock(), Mock(), Mock(), Mock()), + TraceRequestExceptionParams + ), + ( + 'request_redirect', + (Mock(), Mock(), Mock(), Mock()), + TraceRequestRedirectParams + ), + ( + 'connection_queued_start', + (), + TraceConnectionQueuedStartParams + ), + ( + 'connection_queued_end', + (), + TraceConnectionQueuedEndParams + ), + ( + 'connection_create_start', + (), + TraceConnectionCreateStartParams + ), + ( + 'connection_create_end', + (), + TraceConnectionCreateEndParams + ), + ( + 'connection_reuseconn', + (), + TraceConnectionReuseconnParams + ), + ( + 'dns_resolvehost_start', + (Mock(),), + TraceDnsResolveHostStartParams + ), + ( + 'dns_resolvehost_end', + (Mock(),), + TraceDnsResolveHostEndParams + ), + ( + 'dns_cache_hit', + (Mock(),), + TraceDnsCacheHitParams + ), + ( + 'dns_cache_miss', + (Mock(),), + TraceDnsCacheMissParams + ) ]) - async def test_send(self, loop, signal): - param = Mock() + async def test_send(self, loop, signal, params, param_obj): session = Mock() trace_request_ctx = Mock() callback = Mock(side_effect=asyncio.coroutine(Mock())) @@ -74,10 +137,10 @@ async def test_send(self, loop, signal): trace_config, trace_config.trace_config_ctx(trace_request_ctx=trace_request_ctx) ) - await getattr(trace, "send_%s" % signal)(param) + await getattr(trace, "send_%s" % signal)(*params) callback.assert_called_once_with( session, SimpleNamespace(trace_request_ctx=trace_request_ctx), - param, + param_obj(*params) )