@@ -165,7 +165,7 @@ def _set_explicit_asyncio_mark(obj: Any) -> None:
165165
166166def _is_coroutine (obj : Any ) -> bool :
167167 """Check to see if an object is really an asyncio coroutine."""
168- return asyncio .iscoroutinefunction (obj ) or inspect . isgeneratorfunction ( obj )
168+ return asyncio .iscoroutinefunction (obj )
169169
170170
171171def _is_coroutine_or_asyncgen (obj : Any ) -> bool :
@@ -198,6 +198,118 @@ def pytest_report_header(config: Config) -> List[str]:
198198 return [f"asyncio: mode={ mode } " ]
199199
200200
201+ def _preprocess_async_fixtures (config : Config , holder : Set [FixtureDef ]) -> None :
202+ asyncio_mode = _get_asyncio_mode (config )
203+ fixturemanager = config .pluginmanager .get_plugin ("funcmanage" )
204+ for fixtures in fixturemanager ._arg2fixturedefs .values ():
205+ for fixturedef in fixtures :
206+ if fixturedef is holder :
207+ continue
208+ func = fixturedef .func
209+ if not _is_coroutine_or_asyncgen (func ):
210+ # Nothing to do with a regular fixture function
211+ continue
212+ if not _has_explicit_asyncio_mark (func ):
213+ if asyncio_mode == Mode .AUTO :
214+ # Enforce asyncio mode if 'auto'
215+ _set_explicit_asyncio_mark (func )
216+ elif asyncio_mode == Mode .LEGACY :
217+ _set_explicit_asyncio_mark (func )
218+ try :
219+ code = func .__code__
220+ except AttributeError :
221+ code = func .__func__ .__code__
222+ name = (
223+ f"<fixture { func .__qualname__ } , file={ code .co_filename } , "
224+ f"line={ code .co_firstlineno } >"
225+ )
226+ warnings .warn (
227+ LEGACY_ASYNCIO_FIXTURE .format (name = name ),
228+ DeprecationWarning ,
229+ )
230+
231+ to_add = []
232+ for name in ("request" , "event_loop" ):
233+ if name not in fixturedef .argnames :
234+ to_add .append (name )
235+
236+ if to_add :
237+ fixturedef .argnames += tuple (to_add )
238+
239+ if inspect .isasyncgenfunction (func ):
240+ fixturedef .func = _wrap_asyncgen (func )
241+ elif inspect .iscoroutinefunction (func ):
242+ fixturedef .func = _wrap_async (func )
243+
244+ assert _has_explicit_asyncio_mark (fixturedef .func )
245+ holder .add (fixturedef )
246+
247+
248+ def _add_kwargs (
249+ func : Callable [..., Any ],
250+ kwargs : Dict [str , Any ],
251+ event_loop : asyncio .AbstractEventLoop ,
252+ request : SubRequest ,
253+ ) -> Dict [str , Any ]:
254+ sig = inspect .signature (func )
255+ ret = kwargs .copy ()
256+ if "request" in sig .parameters :
257+ ret ["request" ] = request
258+ if "event_loop" in sig .parameters :
259+ ret ["event_loop" ] = event_loop
260+ return ret
261+
262+
263+ def _wrap_asyncgen (func : Callable [..., AsyncIterator [_R ]]) -> Callable [..., _R ]:
264+ @functools .wraps (func )
265+ def _asyncgen_fixture_wrapper (
266+ event_loop : asyncio .AbstractEventLoop , request : SubRequest , ** kwargs : Any
267+ ) -> _R :
268+ gen_obj = func (** _add_kwargs (func , kwargs , event_loop , request ))
269+
270+ async def setup () -> _R :
271+ res = await gen_obj .__anext__ ()
272+ return res
273+
274+ def finalizer () -> None :
275+ """Yield again, to finalize."""
276+
277+ async def async_finalizer () -> None :
278+ try :
279+ await gen_obj .__anext__ ()
280+ except StopAsyncIteration :
281+ pass
282+ else :
283+ msg = "Async generator fixture didn't stop."
284+ msg += "Yield only once."
285+ raise ValueError (msg )
286+
287+ event_loop .run_until_complete (async_finalizer ())
288+
289+ result = event_loop .run_until_complete (setup ())
290+ request .addfinalizer (finalizer )
291+ return result
292+
293+ return _asyncgen_fixture_wrapper
294+
295+
296+ def _wrap_async (func : Callable [..., Awaitable [_R ]]) -> Callable [..., _R ]:
297+ @functools .wraps (func )
298+ def _async_fixture_wrapper (
299+ event_loop : asyncio .AbstractEventLoop , request : SubRequest , ** kwargs : Any
300+ ) -> _R :
301+ async def setup () -> _R :
302+ res = await func (** _add_kwargs (func , kwargs , event_loop , request ))
303+ return res
304+
305+ return event_loop .run_until_complete (setup ())
306+
307+ return _async_fixture_wrapper
308+
309+
310+ _HOLDER : Set [FixtureDef ] = set ()
311+
312+
201313@pytest .mark .tryfirst
202314def pytest_pycollect_makeitem (
203315 collector : Union [pytest .Module , pytest .Class ], name : str , obj : object
@@ -212,6 +324,7 @@ def pytest_pycollect_makeitem(
212324 or _is_hypothesis_test (obj )
213325 and _hypothesis_test_wraps_coroutine (obj )
214326 ):
327+ _preprocess_async_fixtures (collector .config , _HOLDER )
215328 item = pytest .Function .from_parent (collector , name = name )
216329 marker = item .get_closest_marker ("asyncio" )
217330 if marker is not None :
@@ -230,31 +343,6 @@ def _hypothesis_test_wraps_coroutine(function: Any) -> bool:
230343 return _is_coroutine (function .hypothesis .inner_test )
231344
232345
233- class FixtureStripper :
234- """Include additional Fixture, and then strip them"""
235-
236- EVENT_LOOP = "event_loop"
237-
238- def __init__ (self , fixturedef : FixtureDef ) -> None :
239- self .fixturedef = fixturedef
240- self .to_strip : Set [str ] = set ()
241-
242- def add (self , name : str ) -> None :
243- """Add fixture name to fixturedef
244- and record in to_strip list (If not previously included)"""
245- if name in self .fixturedef .argnames :
246- return
247- self .fixturedef .argnames += (name ,)
248- self .to_strip .add (name )
249-
250- def get_and_strip_from (self , name : str , data_dict : Dict [str , _T ]) -> _T :
251- """Strip name from data, and return value"""
252- result = data_dict [name ]
253- if name in self .to_strip :
254- del data_dict [name ]
255- return result
256-
257-
258346@pytest .hookimpl (trylast = True )
259347def pytest_fixture_post_finalizer (fixturedef : FixtureDef , request : SubRequest ) -> None :
260348 """Called after fixture teardown"""
@@ -291,95 +379,6 @@ def pytest_fixture_setup(
291379 policy .set_event_loop (loop )
292380 return
293381
294- func = fixturedef .func
295- if not _is_coroutine_or_asyncgen (func ):
296- # Nothing to do with a regular fixture function
297- yield
298- return
299-
300- config = request .node .config
301- asyncio_mode = _get_asyncio_mode (config )
302-
303- if not _has_explicit_asyncio_mark (func ):
304- if asyncio_mode == Mode .AUTO :
305- # Enforce asyncio mode if 'auto'
306- _set_explicit_asyncio_mark (func )
307- elif asyncio_mode == Mode .LEGACY :
308- _set_explicit_asyncio_mark (func )
309- try :
310- code = func .__code__
311- except AttributeError :
312- code = func .__func__ .__code__
313- name = (
314- f"<fixture { func .__qualname__ } , file={ code .co_filename } , "
315- f"line={ code .co_firstlineno } >"
316- )
317- warnings .warn (
318- LEGACY_ASYNCIO_FIXTURE .format (name = name ),
319- DeprecationWarning ,
320- )
321- else :
322- # asyncio_mode is STRICT,
323- # don't handle fixtures that are not explicitly marked
324- yield
325- return
326-
327- if inspect .isasyncgenfunction (func ):
328- # This is an async generator function. Wrap it accordingly.
329- generator = func
330-
331- fixture_stripper = FixtureStripper (fixturedef )
332- fixture_stripper .add (FixtureStripper .EVENT_LOOP )
333-
334- def wrapper (* args , ** kwargs ):
335- loop = fixture_stripper .get_and_strip_from (
336- FixtureStripper .EVENT_LOOP , kwargs
337- )
338-
339- gen_obj = generator (* args , ** kwargs )
340-
341- async def setup ():
342- res = await gen_obj .__anext__ ()
343- return res
344-
345- def finalizer ():
346- """Yield again, to finalize."""
347-
348- async def async_finalizer ():
349- try :
350- await gen_obj .__anext__ ()
351- except StopAsyncIteration :
352- pass
353- else :
354- msg = "Async generator fixture didn't stop."
355- msg += "Yield only once."
356- raise ValueError (msg )
357-
358- loop .run_until_complete (async_finalizer ())
359-
360- result = loop .run_until_complete (setup ())
361- request .addfinalizer (finalizer )
362- return result
363-
364- fixturedef .func = wrapper
365- elif inspect .iscoroutinefunction (func ):
366- coro = func
367-
368- fixture_stripper = FixtureStripper (fixturedef )
369- fixture_stripper .add (FixtureStripper .EVENT_LOOP )
370-
371- def wrapper (* args , ** kwargs ):
372- loop = fixture_stripper .get_and_strip_from (
373- FixtureStripper .EVENT_LOOP , kwargs
374- )
375-
376- async def setup ():
377- res = await coro (* args , ** kwargs )
378- return res
379-
380- return loop .run_until_complete (setup ())
381-
382- fixturedef .func = wrapper
383382 yield
384383
385384
0 commit comments