Skip to content

Improved fixture reuse by new param keys that can be derived from API ids #9420

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ Thomas Grainger
Thomas Hisch
Tim Hoffmann
Tim Strazny
Tobias Deiminger
Tom Dalton
Tom Viner
Tomáš Gavenčiak
Expand Down
4 changes: 4 additions & 0 deletions changelog/8914.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
If fixtures had been indirectly parameterized via test function, e.g. using the
``@pytest.mark.parametrize(indirect=True)`` marker, reordering of tests for the least possible fixture setup/teardown
cycles did not work. Optimized test groups can now be determined either explicitly by passing parameter ids, or
implicitly if the parameter value is hashable.
27 changes: 16 additions & 11 deletions src/_pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings
from collections import defaultdict
from collections import deque
from collections.abc import Hashable
from contextlib import suppress
from pathlib import Path
from types import TracebackType
Expand Down Expand Up @@ -248,21 +249,21 @@ def get_parametrized_fixture_keys(item: nodes.Item, scope: Scope) -> Iterator[_K
pass
else:
cs: CallSpec2 = callspec
# cs.indices.items() is random order of argnames. Need to
# cs.param_keys.items() is random order of argnames. Need to
# sort this so that different calls to
# get_parametrized_fixture_keys will be deterministic.
for argname, param_index in sorted(cs.indices.items()):
for argname, param_key in sorted(cs.param_keys.items()):
if cs._arg2scope[argname] != scope:
continue
if scope is Scope.Session:
key: _Key = (argname, param_index)
key: _Key = (argname, param_key)
elif scope is Scope.Package:
key = (argname, param_index, item.path.parent)
key = (argname, param_key, item.path.parent)
elif scope is Scope.Module:
key = (argname, param_index, item.path)
key = (argname, param_key, item.path)
elif scope is Scope.Class:
item_cls = item.cls # type: ignore[attr-defined]
key = (argname, param_index, item.path, item_cls)
key = (argname, param_key, item.path, item_cls)
else:
assert_never(scope)
yield key
Expand Down Expand Up @@ -601,6 +602,7 @@ def _compute_fixture_value(self, fixturedef: "FixtureDef[object]") -> None:
except (AttributeError, ValueError):
param = NOTSET
param_index = 0
param_key = ""
has_params = fixturedef.params is not None
fixtures_not_supported = getattr(funcitem, "nofuncargs", False)
if has_params and fixtures_not_supported:
Expand Down Expand Up @@ -640,13 +642,14 @@ def _compute_fixture_value(self, fixturedef: "FixtureDef[object]") -> None:
fail(msg, pytrace=False)
else:
param_index = funcitem.callspec.indices[argname]
param_key = funcitem.callspec.param_keys[argname]
# If a parametrize invocation set a scope it will override
# the static scope defined with the fixture function.
with suppress(KeyError):
scope = funcitem.callspec._arg2scope[argname]

subrequest = SubRequest(
self, scope, param, param_index, fixturedef, _ispytest=True
self, scope, param, param_index, param_key, fixturedef, _ispytest=True
)

# Check if a higher-level scoped fixture accesses a lower level one.
Expand Down Expand Up @@ -731,6 +734,7 @@ def __init__(
scope: Scope,
param: Any,
param_index: int,
param_key: Hashable,
fixturedef: "FixtureDef[object]",
*,
_ispytest: bool = False,
Expand All @@ -741,6 +745,7 @@ def __init__(
if param is not NOTSET:
self.param = param
self.param_index = param_index
self.param_key = param_key
self._scope = scope
self._fixturedef = fixturedef
self._pyfuncitem = request._pyfuncitem
Expand Down Expand Up @@ -1012,10 +1017,10 @@ def execute(self, request: SubRequest) -> FixtureValue:

my_cache_key = self.cache_key(request)
if self.cached_result is not None:
# note: comparison with `==` can fail (or be expensive) for e.g.
# numpy arrays (#6497).
cache_key = self.cached_result[1]
if my_cache_key is cache_key:
# Note: Comparison with `==` may be implemented as (possibly expensive)
# deep by-value comparison. See _pytest.python.SafeHashWrapper for details.
if my_cache_key == cache_key:
if self.cached_result[2] is not None:
_, val, tb = self.cached_result[2]
raise val.with_traceback(tb)
Expand All @@ -1032,7 +1037,7 @@ def execute(self, request: SubRequest) -> FixtureValue:
return result

def cache_key(self, request: SubRequest) -> object:
return request.param_index if not hasattr(request, "param") else request.param
return request.param_key

def __repr__(self) -> str:
return "<FixtureDef argname={!r} scope={!r} baseid={!r}>".format(
Expand Down
97 changes: 88 additions & 9 deletions src/_pytest/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from collections import Counter
from collections import defaultdict
from collections.abc import Hashable
from functools import partial
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -929,6 +930,38 @@ def hasnew(obj: object) -> bool:
return False


@attr.s(auto_attribs=True, eq=False, slots=True)
class SafeHashWrapper:
"""Wrap an arbitrary type so that it becomes comparable with guaranteed constraints.

Constraints:
- SafeHashWrapper(a) == SafeHashWrapper(b) will never raise an exception
- SafeHashWrapper(a) == SafeHashWrapper(b) will always return bool
(oddly some inner types wouldn't, e.g. numpy.array([0]) == numpy.array([0]) returns List)
- SafeHashWrapper(a) is always hashable
- if SafeHashWrapper(a) == SafeHashWrapper(b),
then hash(SafeHashWrapper(a)) == hash(SafeHashWrapper(b))

It works by falling back to identity compare in case constraints couldn't be met otherwise.
"""

obj: Any

def __eq__(self, other: object) -> bool:
if isinstance(self.obj, Hashable) and isinstance(other, Hashable):
try:
res = self.obj == other
return bool(res)
except Exception:
pass
return self.obj is other

def __hash__(self) -> int:
if isinstance(self.obj, Hashable):
return hash(self.obj)
return hash(id(self.obj))


@final
@attr.s(frozen=True, auto_attribs=True, slots=True)
class IdMaker:
Expand Down Expand Up @@ -976,6 +1009,27 @@ def make_unique_parameterset_ids(self) -> List[str]:
id_suffixes[id] += 1
return resolved_ids

def make_parameter_keys(self) -> Iterable[Dict[str, Hashable]]:
"""Make hashable parameter keys for each ParameterSet.

For each ParameterSet, generates a dict mapping each parameter to its key.

This key will be considered (along with the arguments name) to determine
if parameters are the same in the sense of reorder_items() and the
FixtureDef cache. The key is guaranteed to be hashable and comparable.
It's not intended for printing and therefore not ASCII escaped.
"""
for idx, parameterset in enumerate(self.parametersets):
if parameterset.id is not None:
# ID provided directly - pytest.param(..., id="...")
yield {argname: parameterset.id for argname in self.argnames}
elif self.ids and idx < len(self.ids) and self.ids[idx] is not None:
# ID provided in the IDs list - parametrize(..., ids=[...]).
yield {argname: self.ids[idx] for argname in self.argnames}
else:
# ID not provided - generate it.
yield self._parameter_keys_from_parameterset(parameterset, idx)

def _resolve_ids(self) -> Iterable[str]:
"""Resolve IDs for all ParameterSets (may contain duplicates)."""
for idx, parameterset in enumerate(self.parametersets):
Expand All @@ -994,6 +1048,20 @@ def _resolve_ids(self) -> Iterable[str]:
for val, argname in zip(parameterset.values, self.argnames)
)

def _parameter_keys_from_parameterset(
self, parameterset: ParameterSet, idx: int
) -> Dict[str, Hashable]:
"""Make parameter keys for all parameters in a ParameterSet."""
param_keys: Dict[str, Hashable] = {}
for val, argname in zip(parameterset.values, self.argnames):
evaluated_id = self._idval_from_function(val, argname, idx)
if evaluated_id is not None:
param_keys[argname] = evaluated_id
else:
# Wrapping ensures val becomes comparable and hashable.
param_keys[argname] = SafeHashWrapper(val)
return param_keys

def _idval(self, val: object, argname: str, idx: int) -> str:
"""Make an ID for a parameter in a ParameterSet."""
idval = self._idval_from_function(val, argname, idx)
Expand Down Expand Up @@ -1078,6 +1146,8 @@ class CallSpec2:
# arg name -> arg value which will be passed to a fixture of the same name
# (indirect parametrization).
params: Dict[str, object] = attr.Factory(dict)
# arg name -> parameter key.
param_keys: Dict[str, Hashable] = attr.Factory(dict)
# arg name -> arg index.
indices: Dict[str, int] = attr.Factory(dict)
# Used for sorting parametrized resources.
Expand All @@ -1097,9 +1167,12 @@ def setmulti(
marks: Iterable[Union[Mark, MarkDecorator]],
scope: Scope,
param_index: int,
param_set_keys: Dict[str, Hashable],
) -> "CallSpec2":
"""Extend an existing callspec with new parameters during multiple invocation of Metafunc.parametrize."""
funcargs = self.funcargs.copy()
params = self.params.copy()
param_keys = self.param_keys.copy()
indices = self.indices.copy()
arg2scope = self._arg2scope.copy()
for arg, val in zip(argnames, valset):
Expand All @@ -1113,10 +1186,12 @@ def setmulti(
else:
assert_never(valtype_for_arg)
indices[arg] = param_index
param_keys[arg] = param_set_keys[arg]
arg2scope[arg] = scope
return CallSpec2(
funcargs=funcargs,
params=params,
param_keys=param_keys,
arg2scope=arg2scope,
indices=indices,
idlist=[*self._idlist, id],
Expand Down Expand Up @@ -1286,7 +1361,7 @@ def parametrize(
if generated_ids is not None:
ids = generated_ids

ids = self._resolve_parameter_set_ids(
ids, parameters_keys = self._resolve_parameter_set_ids(
argnames, ids, parametersets, nodeid=self.definition.nodeid
)

Expand All @@ -1299,17 +1374,18 @@ def parametrize(
# of all calls.
newcalls = []
for callspec in self._calls or [CallSpec2()]:
for param_index, (param_id, param_set) in enumerate(
zip(ids, parametersets)
for param_index, (param_id, parameterset, param_set_keys) in enumerate(
zip(ids, parametersets, parameters_keys)
):
newcallspec = callspec.setmulti(
valtypes=arg_values_types,
argnames=argnames,
valset=param_set.values,
valset=parameterset.values,
id=param_id,
marks=param_set.marks,
marks=parameterset.marks,
scope=scope_,
param_index=param_index,
param_set_keys=param_set_keys,
)
newcalls.append(newcallspec)
self._calls = newcalls
Expand All @@ -1325,9 +1401,8 @@ def _resolve_parameter_set_ids(
],
parametersets: Sequence[ParameterSet],
nodeid: str,
) -> List[str]:
) -> Tuple[List[str], List[Dict[str, Hashable]]]:
"""Resolve the actual ids for the given parameter sets.

:param argnames:
Argument names passed to ``parametrize()``.
:param ids:
Expand All @@ -1339,7 +1414,9 @@ def _resolve_parameter_set_ids(
The nodeid of the definition item that generated this
parametrization.
:returns:
List with ids for each parameter set given.
Tuple, where
1st entry is a list with ids for each parameter set given used to name test invocations, and
2nd entry is a list with keys to support distinction of parameters to support fixture reuse.
"""
if ids is None:
idfn = None
Expand All @@ -1353,7 +1430,9 @@ def _resolve_parameter_set_ids(
id_maker = IdMaker(
argnames, parametersets, idfn, ids_, self.config, nodeid=nodeid
)
return id_maker.make_unique_parameterset_ids()
return id_maker.make_unique_parameterset_ids(), list(
id_maker.make_parameter_keys()
)

def _validate_ids(
self,
Expand Down
4 changes: 2 additions & 2 deletions testing/example_scripts/issue_519.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ def checked_order():
assert order == [
("issue_519.py", "fix1", "arg1v1"),
("test_one[arg1v1-arg2v1]", "fix2", "arg2v1"),
("test_two[arg1v1-arg2v1]", "fix2", "arg2v1"),
("test_one[arg1v1-arg2v2]", "fix2", "arg2v2"),
("test_two[arg1v1-arg2v1]", "fix2", "arg2v1"),
("test_two[arg1v1-arg2v2]", "fix2", "arg2v2"),
("issue_519.py", "fix1", "arg1v2"),
("test_one[arg1v2-arg2v1]", "fix2", "arg2v1"),
("test_two[arg1v2-arg2v1]", "fix2", "arg2v1"),
("test_one[arg1v2-arg2v2]", "fix2", "arg2v2"),
("test_two[arg1v2-arg2v1]", "fix2", "arg2v1"),
("test_two[arg1v2-arg2v2]", "fix2", "arg2v2"),
]

Expand Down
53 changes: 53 additions & 0 deletions testing/python/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,59 @@ def test2(no_eq):
result = pytester.runpytest()
result.stdout.fnmatch_lines(["*4 passed*"])

@pytest.mark.parametrize(
("parametrize1", "parametrize2"),
[
(
'"fix", [1, 2], indirect=True',
'"fix", [2, 1], indirect=True',
),
(
'"fix", [1, pytest.param({"data": 2}, id="2")], indirect=True',
'"fix", [pytest.param({"data": 2}, id="2"), 1], indirect=True',
),
(
'"fix", [{"data": 1}, {"data": 2}], indirect=True, ids=lambda d: MyEnum(d["data"])',
'"fix", [{"data": 2}, {"data": 1}], indirect=True, ids=lambda d: MyEnum(d["data"])',
),
(
'"fix", [{"data": 1}, {"data": 2}], indirect=True, ids=[1, "two"]',
'"fix", [{"data": 2}, {"data": 1}], indirect=True, ids=["two", 1]',
),
],
)
def test_reorder_and_cache(
self, pytester: Pytester, parametrize1, parametrize2
) -> None:
"""Test optimization for minimal setup/teardown with indirectly parametrized fixtures. See #8914, #9420."""
pytester.makepyfile(
f"""
import pytest
from enum import Enum
class MyEnum(Enum):
Id1 = 1
Id2 = 2
@pytest.fixture(scope="session")
def fix(request):
value = request.param["data"] if isinstance(request.param, dict) else request.param
print(f'prepare foo-%s' % value)
yield value
print(f'teardown foo-%s' % value)
@pytest.mark.parametrize({parametrize1})
def test1(fix):
pass
@pytest.mark.parametrize({parametrize2})
def test2(fix):
pass
"""
)
result = pytester.runpytest("-s")
output = result.stdout.str()
assert output.count("prepare foo-1") == 1
assert output.count("prepare foo-2") == 1
assert output.count("teardown foo-1") == 1
assert output.count("teardown foo-2") == 1

def test_funcarg_parametrized_and_used_twice(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
Expand Down