Skip to content

Commit

Permalink
Refactor pytest_pycollect_makeitems (#421)
Browse files Browse the repository at this point in the history
* refactor: Extracted logic for marking tests in auto mode into pytest_collection_modifyitems.

pytest_pycollect_makeitem currently calls `Collector._genfunctions`, in order to delegate further collection of test items to the current pytest collector. It does so only to add the asyncio mark to async tests after the items have been collected.

Rather than relying on a call to the protected `Collector._genfunctions` method the marking logic was moved to the pytest_collection_modifyitems hook, which is called at the end of the collection phase. This change removes the call to protected functions and makes the code easier to understand.

Signed-off-by: Michael Seifert <m.seifert@digitalernachschub.de>

* refactor: Hoist up check for asyncio mode before trying to modify function items.

pytest_collection_modifyitems has no effect when asyncio mode is not set to AUTO. Moving the mode check out of the loop prevents unnecessary work.

Signed-off-by: Michael Seifert <m.seifert@digitalernachschub.de>

* refactor: Renamed _set_explicit_asyncio_mark and _has_explicit_asyncio_mark to _make_asyncio_fixture_function and _is_asyncio_fixture_function, respectively.

The new names reflect the purpose of the functions, instead of what they do. The new names also avoid confusion with pytest markers by not using "mark" in their names.

Signed-off-by: Michael Seifert <m.seifert@digitalernachschub.de>

* refactor: Removed obsolete elif clause.

Legacy mode has been removed, so we don't need an elif to check if we're in AUTO mode.

Signed-off-by: Michael Seifert <m.seifert@digitalernachschub.de>

* refactor: Renamed the "holder" argument to _preprocess_async_fixtures to "processed_fixturedefs" to better reflect the purpose of the variable.

Signed-off-by: Michael Seifert <m.seifert@digitalernachschub.de>

* refactor: Simplified branching structure of _preprocess_async_fixtures.

It is safe to call _make_asyncio_fixture_function without checking whether the fixture function has been converted to an asyncio fixture function, because each fixture is only processed once in the loop.

Signed-off-by: Michael Seifert <m.seifert@digitalernachschub.de>

* refactor: Simplified logic in _preprocess_async_fixtures.

Merged two if-clauses both of which cause the current fixturedef to be skipped.

Signed-off-by: Michael Seifert <m.seifert@digitalernachschub.de>

* refactor: Extracted _inject_fixture_argnames from _preprocess_async_fixtures in order to improve readability.

Signed-off-by: Michael Seifert <m.seifert@digitalernachschub.de>

* refactor: Extracted _synchronize_async_fixture from _preprocess_async_fixtures in order to improve readability.

Signed-off-by: Michael Seifert <m.seifert@digitalernachschub.de>
  • Loading branch information
seifertm authored Oct 11, 2022
1 parent d45ab21 commit 907c461
Showing 1 changed file with 69 additions and 48 deletions.
117 changes: 69 additions & 48 deletions pytest_asyncio/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)

import pytest
from pytest import Function, Session, Item

if sys.version_info >= (3, 8):
from typing import Literal
Expand Down Expand Up @@ -121,7 +122,7 @@ def fixture(
fixture_function: Optional[FixtureFunction] = None, **kwargs: Any
) -> Union[FixtureFunction, FixtureFunctionMarker]:
if fixture_function is not None:
_set_explicit_asyncio_mark(fixture_function)
_make_asyncio_fixture_function(fixture_function)
return pytest.fixture(fixture_function, **kwargs)

else:
Expand All @@ -133,12 +134,12 @@ def inner(fixture_function: FixtureFunction) -> FixtureFunction:
return inner


def _has_explicit_asyncio_mark(obj: Any) -> bool:
def _is_asyncio_fixture_function(obj: Any) -> bool:
obj = getattr(obj, "__func__", obj) # instance method maybe?
return getattr(obj, "_force_asyncio_fixture", False)


def _set_explicit_asyncio_mark(obj: Any) -> None:
def _make_asyncio_fixture_function(obj: Any) -> None:
if hasattr(obj, "__func__"):
# instance method, check the function object
obj = obj.__func__
Expand Down Expand Up @@ -186,41 +187,51 @@ def pytest_report_header(config: Config) -> List[str]:
return [f"asyncio: mode={mode}"]


def _preprocess_async_fixtures(config: Config, holder: Set[FixtureDef]) -> None:
def _preprocess_async_fixtures(
config: Config,
processed_fixturedefs: Set[FixtureDef],
) -> None:
asyncio_mode = _get_asyncio_mode(config)
fixturemanager = config.pluginmanager.get_plugin("funcmanage")
for fixtures in fixturemanager._arg2fixturedefs.values():
for fixturedef in fixtures:
if fixturedef in holder:
continue
func = fixturedef.func
if not _is_coroutine_or_asyncgen(func):
# Nothing to do with a regular fixture function
if fixturedef in processed_fixturedefs or not _is_coroutine_or_asyncgen(
func
):
continue
if not _is_asyncio_fixture_function(func) and asyncio_mode == Mode.STRICT:
# Ignore async fixtures without explicit asyncio mark in strict mode
# This applies to pytest_trio fixtures, for example
continue
if not _has_explicit_asyncio_mark(func):
if asyncio_mode == Mode.STRICT:
# Ignore async fixtures without explicit asyncio mark in strict mode
# This applies to pytest_trio fixtures, for example
continue
elif asyncio_mode == Mode.AUTO:
# Enforce asyncio mode if 'auto'
_set_explicit_asyncio_mark(func)
_make_asyncio_fixture_function(func)
_inject_fixture_argnames(fixturedef)
_synchronize_async_fixture(fixturedef)
assert _is_asyncio_fixture_function(fixturedef.func)
processed_fixturedefs.add(fixturedef)

to_add = []
for name in ("request", "event_loop"):
if name not in fixturedef.argnames:
to_add.append(name)

if to_add:
fixturedef.argnames += tuple(to_add)
def _inject_fixture_argnames(fixturedef: FixtureDef) -> None:
"""
Ensures that `request` and `event_loop` are arguments of the specified fixture.
"""
to_add = []
for name in ("request", "event_loop"):
if name not in fixturedef.argnames:
to_add.append(name)
if to_add:
fixturedef.argnames += tuple(to_add)

if inspect.isasyncgenfunction(func):
fixturedef.func = _wrap_asyncgen(func)
elif inspect.iscoroutinefunction(func):
fixturedef.func = _wrap_async(func)

assert _has_explicit_asyncio_mark(fixturedef.func)
holder.add(fixturedef)
def _synchronize_async_fixture(fixturedef: FixtureDef) -> None:
"""
Wraps the fixture function of an async fixture in a synchronous function.
"""
func = fixturedef.func
if inspect.isasyncgenfunction(func):
fixturedef.func = _wrap_asyncgen(func)
elif inspect.iscoroutinefunction(func):
fixturedef.func = _wrap_async(func)


def _add_kwargs(
Expand Down Expand Up @@ -290,36 +301,46 @@ async def setup() -> _R:

@pytest.mark.tryfirst
def pytest_pycollect_makeitem(
collector: Union[pytest.Module, pytest.Class], name: str, obj: object
collector: Union[pytest.Module, pytest.Class], name: str, obj: object
) -> Union[
None, pytest.Item, pytest.Collector, List[Union[pytest.Item, pytest.Collector]]
]:
"""A pytest hook to collect asyncio coroutines."""
if not collector.funcnamefilter(name):
return None
_preprocess_async_fixtures(collector.config, _HOLDER)
if isinstance(obj, staticmethod):
# staticmethods need to be unwrapped.
obj = obj.__func__
if (
_is_coroutine(obj)
or _is_hypothesis_test(obj)
and _hypothesis_test_wraps_coroutine(obj)
):
item = pytest.Function.from_parent(collector, name=name)
marker = item.get_closest_marker("asyncio")
if marker is not None:
return list(collector._genfunctions(name, obj))
else:
if _get_asyncio_mode(item.config) == Mode.AUTO:
# implicitly add asyncio marker if asyncio mode is on
ret = list(collector._genfunctions(name, obj))
for elem in ret:
elem.add_marker("asyncio")
return ret # type: ignore[return-value]
return None


def pytest_collection_modifyitems(
session: Session, config: Config, items: List[Item]
) -> None:
"""
Marks collected async test items as `asyncio` tests.
The mark is only applied in `AUTO` mode. It is applied to:
- coroutines
- staticmethods wrapping coroutines
- Hypothesis tests wrapping coroutines
"""
if _get_asyncio_mode(config) != Mode.AUTO:
return
function_items = (item for item in items if isinstance(item, Function))
for function_item in function_items:
function = function_item.obj
if isinstance(function, staticmethod):
# staticmethods need to be unwrapped.
function = function.__func__
if (
_is_coroutine(function)
or _is_hypothesis_test(function)
and _hypothesis_test_wraps_coroutine(function)
):
function_item.add_marker("asyncio")


def _hypothesis_test_wraps_coroutine(function: Any) -> bool:
return _is_coroutine(function.hypothesis.inner_test)

Expand Down

0 comments on commit 907c461

Please sign in to comment.