Skip to content

Synchronize fixtures on demand #1147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 79 additions & 93 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,38 +21,38 @@
Iterator,
Sequence,
)
from types import AsyncGeneratorType, CoroutineType
from typing import (
Any,
Callable,
Literal,
TypeVar,
Union,
cast,
overload,
)

import pluggy
import pytest
from _pytest.scope import Scope
from pytest import (
Collector,
Config,
FixtureDef,
FixtureRequest,
Function,
Item,
Mark,
Metafunc,
MonkeyPatch,
Parser,
PytestCollectionWarning,
PytestDeprecationWarning,
PytestPluginManager,
)

if sys.version_info >= (3, 10):
from typing import ParamSpec
from typing import Concatenate, ParamSpec
else:
from typing_extensions import ParamSpec
from typing_extensions import Concatenate, ParamSpec


_ScopeName = Literal["session", "package", "module", "class", "function"]
Expand Down Expand Up @@ -230,45 +230,16 @@ def pytest_report_header(config: Config) -> list[str]:
]


def _preprocess_async_fixtures(
collector: Collector,
processed_fixturedefs: set[FixtureDef],
) -> None:
config = collector.config
default_loop_scope = config.getini("asyncio_default_fixture_loop_scope")
asyncio_mode = _get_asyncio_mode(config)
fixturemanager = config.pluginmanager.get_plugin("funcmanage")
assert fixturemanager is not None
for fixtures in fixturemanager._arg2fixturedefs.values():
for fixturedef in fixtures:
func = fixturedef.func
if fixturedef in processed_fixturedefs or not _is_coroutine_or_asyncgen(
func
):
continue
if asyncio_mode == Mode.STRICT and not _is_asyncio_fixture_function(func):
# Ignore async fixtures without explicit asyncio mark in strict mode
# This applies to pytest_trio fixtures, for example
continue
loop_scope = (
getattr(func, "_loop_scope", None)
or default_loop_scope
or fixturedef.scope
)
_make_asyncio_fixture_function(func, loop_scope)
if "request" not in fixturedef.argnames:
fixturedef.argnames += ("request",)
_synchronize_async_fixture(fixturedef)
assert _is_asyncio_fixture_function(fixturedef.func)
processed_fixturedefs.add(fixturedef)


def _synchronize_async_fixture(fixturedef: FixtureDef) -> None:
"""Wraps the fixture function of an async fixture in a synchronous function."""
def _fixture_synchronizer(
fixturedef: FixtureDef, event_loop: AbstractEventLoop
) -> Callable:
"""Returns a synchronous function evaluating the specified fixture."""
if inspect.isasyncgenfunction(fixturedef.func):
_wrap_asyncgen_fixture(fixturedef)
return _wrap_asyncgen_fixture(fixturedef.func, event_loop)
elif inspect.iscoroutinefunction(fixturedef.func):
_wrap_async_fixture(fixturedef)
return _wrap_async_fixture(fixturedef.func, event_loop)
else:
return fixturedef.func


def _add_kwargs(
Expand Down Expand Up @@ -299,18 +270,26 @@ def _perhaps_rebind_fixture_func(func: _T, instance: Any | None) -> _T:
return func


def _wrap_asyncgen_fixture(fixturedef: FixtureDef) -> None:
fixture = fixturedef.func
AsyncGenFixtureParams = ParamSpec("AsyncGenFixtureParams")
AsyncGenFixtureYieldType = TypeVar("AsyncGenFixtureYieldType")

@functools.wraps(fixture)
def _asyncgen_fixture_wrapper(request: FixtureRequest, **kwargs: Any):
func = _perhaps_rebind_fixture_func(fixture, request.instance)
event_loop_fixture_id = _get_event_loop_fixture_id_for_async_fixture(
request, func
)
event_loop = request.getfixturevalue(event_loop_fixture_id)
kwargs.pop(event_loop_fixture_id, None)
gen_obj = func(**_add_kwargs(func, kwargs, request))

def _wrap_asyncgen_fixture(
fixture_function: Callable[
AsyncGenFixtureParams, AsyncGeneratorType[AsyncGenFixtureYieldType, Any]
],
event_loop: AbstractEventLoop,
) -> Callable[
Concatenate[FixtureRequest, AsyncGenFixtureParams], AsyncGenFixtureYieldType
]:
@functools.wraps(fixture_function)
def _asyncgen_fixture_wrapper(
request: FixtureRequest,
*args: AsyncGenFixtureParams.args,
**kwargs: AsyncGenFixtureParams.kwargs,
):
func = _perhaps_rebind_fixture_func(fixture_function, request.instance)
gen_obj = func(*args, **_add_kwargs(func, kwargs, request))

async def setup():
res = await gen_obj.__anext__() # type: ignore[union-attr]
Expand Down Expand Up @@ -343,23 +322,30 @@ async def async_finalizer() -> None:
request.addfinalizer(finalizer)
return result

fixturedef.func = _asyncgen_fixture_wrapper # type: ignore[misc]
return _asyncgen_fixture_wrapper


def _wrap_async_fixture(fixturedef: FixtureDef) -> None:
fixture = fixturedef.func
AsyncFixtureParams = ParamSpec("AsyncFixtureParams")
AsyncFixtureReturnType = TypeVar("AsyncFixtureReturnType")

@functools.wraps(fixture)
def _async_fixture_wrapper(request: FixtureRequest, **kwargs: Any):
func = _perhaps_rebind_fixture_func(fixture, request.instance)
event_loop_fixture_id = _get_event_loop_fixture_id_for_async_fixture(
request, func
)
event_loop = request.getfixturevalue(event_loop_fixture_id)
kwargs.pop(event_loop_fixture_id, None)

def _wrap_async_fixture(
fixture_function: Callable[
AsyncFixtureParams, CoroutineType[Any, Any, AsyncFixtureReturnType]
],
event_loop: AbstractEventLoop,
) -> Callable[Concatenate[FixtureRequest, AsyncFixtureParams], AsyncFixtureReturnType]:

@functools.wraps(fixture_function) # type: ignore[arg-type]
def _async_fixture_wrapper(
request: FixtureRequest,
*args: AsyncFixtureParams.args,
**kwargs: AsyncFixtureParams.kwargs,
):
func = _perhaps_rebind_fixture_func(fixture_function, request.instance)

async def setup():
res = await func(**_add_kwargs(func, kwargs, request))
res = await func(*args, **_add_kwargs(func, kwargs, request))
return res

context = contextvars.copy_context()
Expand All @@ -380,19 +366,7 @@ async def setup():

return result

fixturedef.func = _async_fixture_wrapper # type: ignore[misc]


def _get_event_loop_fixture_id_for_async_fixture(
request: FixtureRequest, func: Any
) -> str:
default_loop_scope = cast(
_ScopeName, request.config.getini("asyncio_default_fixture_loop_scope")
)
loop_scope = (
getattr(func, "_loop_scope", None) or default_loop_scope or request.scope
)
return f"_{loop_scope}_event_loop"
return _async_fixture_wrapper


def _create_task_in_context(
Expand Down Expand Up @@ -573,22 +547,6 @@ def runtest(self) -> None:
super().runtest()


_HOLDER: set[FixtureDef] = set()


# The function name needs to start with "pytest_"
# see https://github.com/pytest-dev/pytest/issues/11307
@pytest.hookimpl(specname="pytest_pycollect_makeitem", tryfirst=True)
def pytest_pycollect_makeitem_preprocess_async_fixtures(
collector: pytest.Module | pytest.Class, name: str, obj: object
) -> pytest.Item | pytest.Collector | list[pytest.Item | pytest.Collector] | None:
"""A pytest hook to collect asyncio coroutines."""
if not collector.funcnamefilter(name):
return None
_preprocess_async_fixtures(collector, _HOLDER)
return None


# The function name needs to start with "pytest_"
# see https://github.com/pytest-dev/pytest/issues/11307
@pytest.hookimpl(specname="pytest_pycollect_makeitem", hookwrapper=True)
Expand Down Expand Up @@ -803,6 +761,34 @@ def pytest_runtest_setup(item: pytest.Item) -> None:
)


@pytest.hookimpl(wrapper=True)
def pytest_fixture_setup(fixturedef: FixtureDef, request) -> object | None:
asyncio_mode = _get_asyncio_mode(request.config)
if not _is_asyncio_fixture_function(fixturedef.func):
if asyncio_mode == Mode.STRICT:
# Ignore async fixtures without explicit asyncio mark in strict mode
# This applies to pytest_trio fixtures, for example
return (yield)
if not _is_coroutine_or_asyncgen(fixturedef.func):
return (yield)
default_loop_scope = request.config.getini("asyncio_default_fixture_loop_scope")
loop_scope = (
getattr(fixturedef.func, "_loop_scope", None)
or default_loop_scope
or fixturedef.scope
)
event_loop_fixture_id = f"_{loop_scope}_event_loop"
event_loop = request.getfixturevalue(event_loop_fixture_id)
synchronizer = _fixture_synchronizer(fixturedef, event_loop)
_make_asyncio_fixture_function(synchronizer, loop_scope)
with MonkeyPatch.context() as c:
if "request" not in fixturedef.argnames:
c.setattr(fixturedef, "argnames", (*fixturedef.argnames, "request"))
c.setattr(fixturedef, "func", synchronizer)
hook_result = yield
return hook_result


_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR = """\
An asyncio pytest marker defines both "scope" and "loop_scope", \
but it should only use "loop_scope".
Expand Down
Loading