66import inspect
77import socket
88import warnings
9+ from typing import (
10+ Any ,
11+ AsyncIterator ,
12+ Awaitable ,
13+ Callable ,
14+ Dict ,
15+ Iterable ,
16+ Iterator ,
17+ List ,
18+ Optional ,
19+ Set ,
20+ TypeVar ,
21+ Union ,
22+ cast ,
23+ overload ,
24+ )
925
1026import pytest
27+ from typing_extensions import Literal
28+
29+ _R = TypeVar ("_R" )
30+
31+ _ScopeName = Literal ["session" , "package" , "module" , "class" , "function" ]
32+ _T = TypeVar ("_T" )
33+
34+ SimpleFixtureFunction = TypeVar (
35+ "SimpleFixtureFunction" , bound = Callable [..., Awaitable [_R ]]
36+ )
37+ FactoryFixtureFunction = TypeVar (
38+ "FactoryFixtureFunction" , bound = Callable [..., AsyncIterator [_R ]]
39+ )
40+ FixtureFunction = Union [SimpleFixtureFunction , FactoryFixtureFunction ]
41+ FixtureFunctionMarker = Callable [[FixtureFunction ], FixtureFunction ]
42+
43+ Config = Any # pytest < 7.0
44+ PytestPluginManager = Any # pytest < 7.0
45+ FixtureDef = Any # pytest < 7.0
46+ Parser = Any # pytest < 7.0
47+ SubRequest = Any # pytest < 7.0
1148
1249
1350class Mode (str , enum .Enum ):
@@ -41,7 +78,7 @@ class Mode(str, enum.Enum):
4178"""
4279
4380
44- def pytest_addoption (parser , pluginmanager ) :
81+ def pytest_addoption (parser : Parser , pluginmanager : PytestPluginManager ) -> None :
4582 group = parser .getgroup ("asyncio" )
4683 group .addoption (
4784 "--asyncio-mode" ,
@@ -58,49 +95,87 @@ def pytest_addoption(parser, pluginmanager):
5895 )
5996
6097
61- def fixture (fixture_function = None , ** kwargs ):
98+ @overload
99+ def fixture (
100+ fixture_function : FixtureFunction ,
101+ * ,
102+ scope : "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ...,
103+ params : Optional [Iterable [object ]] = ...,
104+ autouse : bool = ...,
105+ ids : Optional [
106+ Union [
107+ Iterable [Union [None , str , float , int , bool ]],
108+ Callable [[Any ], Optional [object ]],
109+ ]
110+ ] = ...,
111+ name : Optional [str ] = ...,
112+ ) -> FixtureFunction :
113+ ...
114+
115+
116+ @overload
117+ def fixture (
118+ fixture_function : None = ...,
119+ * ,
120+ scope : "Union[_ScopeName, Callable[[str, Config], _ScopeName]]" = ...,
121+ params : Optional [Iterable [object ]] = ...,
122+ autouse : bool = ...,
123+ ids : Optional [
124+ Union [
125+ Iterable [Union [None , str , float , int , bool ]],
126+ Callable [[Any ], Optional [object ]],
127+ ]
128+ ] = ...,
129+ name : Optional [str ] = None ,
130+ ) -> FixtureFunctionMarker :
131+ ...
132+
133+
134+ def fixture (
135+ fixture_function : Optional [FixtureFunction ] = None , ** kwargs : Any
136+ ) -> Union [FixtureFunction , FixtureFunctionMarker ]:
62137 if fixture_function is not None :
63138 _set_explicit_asyncio_mark (fixture_function )
64139 return pytest .fixture (fixture_function , ** kwargs )
65140
66141 else :
67142
68143 @functools .wraps (fixture )
69- def inner (fixture_function ) :
144+ def inner (fixture_function : FixtureFunction ) -> FixtureFunction :
70145 return fixture (fixture_function , ** kwargs )
71146
72147 return inner
73148
74149
75- def _has_explicit_asyncio_mark (obj ) :
150+ def _has_explicit_asyncio_mark (obj : Any ) -> bool :
76151 obj = getattr (obj , "__func__" , obj ) # instance method maybe?
77152 return getattr (obj , "_force_asyncio_fixture" , False )
78153
79154
80- def _set_explicit_asyncio_mark (obj ) :
155+ def _set_explicit_asyncio_mark (obj : Any ) -> None :
81156 if hasattr (obj , "__func__" ):
82157 # instance method, check the function object
83158 obj = obj .__func__
84159 obj ._force_asyncio_fixture = True
85160
86161
87- def _is_coroutine (obj ) :
162+ def _is_coroutine (obj : Any ) -> bool :
88163 """Check to see if an object is really an asyncio coroutine."""
89164 return asyncio .iscoroutinefunction (obj ) or inspect .isgeneratorfunction (obj )
90165
91166
92- def _is_coroutine_or_asyncgen (obj ) :
167+ def _is_coroutine_or_asyncgen (obj : Any ) -> bool :
93168 return _is_coroutine (obj ) or inspect .isasyncgenfunction (obj )
94169
95170
96- def _get_asyncio_mode (config ) :
171+ def _get_asyncio_mode (config : Config ) -> Mode :
97172 val = config .getoption ("asyncio_mode" )
98173 if val is None :
99174 val = config .getini ("asyncio_mode" )
100175 return Mode (val )
101176
102177
103- def pytest_configure (config ) :
178+ def pytest_configure (config : Config ) -> None :
104179 """Inject documentation."""
105180 config .addinivalue_line (
106181 "markers" ,
@@ -113,10 +188,14 @@ def pytest_configure(config):
113188
114189
115190@pytest .mark .tryfirst
116- def pytest_pycollect_makeitem (collector , name , obj ):
191+ def pytest_pycollect_makeitem (
192+ collector : Union [pytest .Module , pytest .Class ], name : str , obj : object
193+ ) -> Union [
194+ None , pytest .Item , pytest .Collector , List [Union [pytest .Item , pytest .Collector ]]
195+ ]:
117196 """A pytest hook to collect asyncio coroutines."""
118197 if not collector .funcnamefilter (name ):
119- return
198+ return None
120199 if (
121200 _is_coroutine (obj )
122201 or _is_hypothesis_test (obj )
@@ -131,10 +210,11 @@ def pytest_pycollect_makeitem(collector, name, obj):
131210 ret = list (collector ._genfunctions (name , obj ))
132211 for elem in ret :
133212 elem .add_marker ("asyncio" )
134- return ret
213+ return ret # type: ignore[return-value]
214+ return None
135215
136216
137- def _hypothesis_test_wraps_coroutine (function ) :
217+ def _hypothesis_test_wraps_coroutine (function : Any ) -> bool :
138218 return _is_coroutine (function .hypothesis .inner_test )
139219
140220
@@ -144,19 +224,19 @@ class FixtureStripper:
144224 REQUEST = "request"
145225 EVENT_LOOP = "event_loop"
146226
147- def __init__ (self , fixturedef ) :
227+ def __init__ (self , fixturedef : FixtureDef ) -> None :
148228 self .fixturedef = fixturedef
149- self .to_strip = set ()
229+ self .to_strip : Set [ str ] = set ()
150230
151- def add (self , name ) :
231+ def add (self , name : str ) -> None :
152232 """Add fixture name to fixturedef
153233 and record in to_strip list (If not previously included)"""
154234 if name in self .fixturedef .argnames :
155235 return
156236 self .fixturedef .argnames += (name ,)
157237 self .to_strip .add (name )
158238
159- def get_and_strip_from (self , name , data_dict ) :
239+ def get_and_strip_from (self , name : str , data_dict : Dict [ str , _T ]) -> _T :
160240 """Strip name from data, and return value"""
161241 result = data_dict [name ]
162242 if name in self .to_strip :
@@ -165,7 +245,7 @@ def get_and_strip_from(self, name, data_dict):
165245
166246
167247@pytest .hookimpl (trylast = True )
168- def pytest_fixture_post_finalizer (fixturedef , request ) :
248+ def pytest_fixture_post_finalizer (fixturedef : FixtureDef , request : SubRequest ) -> None :
169249 """Called after fixture teardown"""
170250 if fixturedef .argname == "event_loop" :
171251 policy = asyncio .get_event_loop_policy ()
@@ -182,7 +262,9 @@ def pytest_fixture_post_finalizer(fixturedef, request):
182262
183263
184264@pytest .hookimpl (hookwrapper = True )
185- def pytest_fixture_setup (fixturedef , request ):
265+ def pytest_fixture_setup (
266+ fixturedef : FixtureDef , request : SubRequest
267+ ) -> Optional [object ]:
186268 """Adjust the event loop policy when an event loop is produced."""
187269 if fixturedef .argname == "event_loop" :
188270 outcome = yield
@@ -295,39 +377,43 @@ async def setup():
295377
296378
297379@pytest .hookimpl (tryfirst = True , hookwrapper = True )
298- def pytest_pyfunc_call (pyfuncitem ) :
380+ def pytest_pyfunc_call (pyfuncitem : pytest . Function ) -> Optional [ object ] :
299381 """
300382 Pytest hook called before a test case is run.
301383
302384 Wraps marked tests in a synchronous function
303385 where the wrapped test coroutine is executed in an event loop.
304386 """
305387 if "asyncio" in pyfuncitem .keywords :
388+ funcargs : Dict [str , object ] = pyfuncitem .funcargs # type: ignore[name-defined]
389+ loop = cast (asyncio .AbstractEventLoop , funcargs ["event_loop" ])
306390 if _is_hypothesis_test (pyfuncitem .obj ):
307391 pyfuncitem .obj .hypothesis .inner_test = wrap_in_sync (
308392 pyfuncitem .obj .hypothesis .inner_test ,
309- _loop = pyfuncitem . funcargs [ "event_loop" ] ,
393+ _loop = loop ,
310394 )
311395 else :
312396 pyfuncitem .obj = wrap_in_sync (
313- pyfuncitem .obj , _loop = pyfuncitem .funcargs ["event_loop" ]
397+ pyfuncitem .obj ,
398+ _loop = loop ,
314399 )
315400 yield
316401
317402
318- def _is_hypothesis_test (function ) -> bool :
403+ def _is_hypothesis_test (function : Any ) -> bool :
319404 return getattr (function , "is_hypothesis_test" , False )
320405
321406
322- def wrap_in_sync (func , _loop ):
407+ def wrap_in_sync (func : Callable [..., Awaitable [ Any ]], _loop : asyncio . AbstractEventLoop ):
323408 """Return a sync wrapper around an async function executing it in the
324409 current event loop."""
325410
326411 # if the function is already wrapped, we rewrap using the original one
327412 # not using __wrapped__ because the original function may already be
328413 # a wrapped one
329- if hasattr (func , "_raw_test_func" ):
330- func = func ._raw_test_func
414+ raw_func = getattr (func , "_raw_test_func" , None )
415+ if raw_func is not None :
416+ func = raw_func
331417
332418 @functools .wraps (func )
333419 def inner (** kwargs ):
@@ -344,20 +430,22 @@ def inner(**kwargs):
344430 task .exception ()
345431 raise
346432
347- inner ._raw_test_func = func
433+ inner ._raw_test_func = func # type: ignore[attr-defined]
348434 return inner
349435
350436
351- def pytest_runtest_setup (item ) :
437+ def pytest_runtest_setup (item : pytest . Item ) -> None :
352438 if "asyncio" in item .keywords :
439+ fixturenames = item .fixturenames # type: ignore[attr-defined]
353440 # inject an event loop fixture for all async tests
354- if "event_loop" in item .fixturenames :
355- item .fixturenames .remove ("event_loop" )
356- item .fixturenames .insert (0 , "event_loop" )
441+ if "event_loop" in fixturenames :
442+ fixturenames .remove ("event_loop" )
443+ fixturenames .insert (0 , "event_loop" )
444+ obj = item .obj # type: ignore[attr-defined]
357445 if (
358446 item .get_closest_marker ("asyncio" ) is not None
359- and not getattr (item . obj , "hypothesis" , False )
360- and getattr (item . obj , "is_hypothesis_test" , False )
447+ and not getattr (obj , "hypothesis" , False )
448+ and getattr (obj , "is_hypothesis_test" , False )
361449 ):
362450 pytest .fail (
363451 "test function `%r` is using Hypothesis, but pytest-asyncio "
@@ -366,32 +454,32 @@ def pytest_runtest_setup(item):
366454
367455
368456@pytest .fixture
369- def event_loop (request ) :
457+ def event_loop (request : pytest . FixtureRequest ) -> Iterator [ asyncio . AbstractEventLoop ] :
370458 """Create an instance of the default event loop for each test case."""
371459 loop = asyncio .get_event_loop_policy ().new_event_loop ()
372460 yield loop
373461 loop .close ()
374462
375463
376- def _unused_port (socket_type ) :
464+ def _unused_port (socket_type : int ) -> int :
377465 """Find an unused localhost port from 1024-65535 and return it."""
378466 with contextlib .closing (socket .socket (type = socket_type )) as sock :
379467 sock .bind (("127.0.0.1" , 0 ))
380468 return sock .getsockname ()[1 ]
381469
382470
383471@pytest .fixture
384- def unused_tcp_port ():
472+ def unused_tcp_port () -> int :
385473 return _unused_port (socket .SOCK_STREAM )
386474
387475
388476@pytest .fixture
389- def unused_udp_port ():
477+ def unused_udp_port () -> int :
390478 return _unused_port (socket .SOCK_DGRAM )
391479
392480
393481@pytest .fixture (scope = "session" )
394- def unused_tcp_port_factory ():
482+ def unused_tcp_port_factory () -> Callable [[], int ] :
395483 """A factory function, producing different unused TCP ports."""
396484 produced = set ()
397485
@@ -410,7 +498,7 @@ def factory():
410498
411499
412500@pytest .fixture (scope = "session" )
413- def unused_udp_port_factory ():
501+ def unused_udp_port_factory () -> Callable [[], int ] :
414502 """A factory function, producing different unused UDP ports."""
415503 produced = set ()
416504
0 commit comments