Skip to content

Type annotate ParameterSet & param() #6726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/_pytest/compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
python version compatibility code
"""
import enum
import functools
import inspect
import io
Expand Down Expand Up @@ -35,13 +36,20 @@

if TYPE_CHECKING:
from typing import Type # noqa: F401 (used in type string)
from typing_extensions import Final


_T = TypeVar("_T")
_S = TypeVar("_S")


NOTSET = object()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

object() doesn't produce a singleton on the type level, so I had to switch it to an idiom that does.

# fmt: off
# Singleton type for NOTSET, as described in:
# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
class NotSetType(enum.Enum):
token = 0
NOTSET = NotSetType.token # type: Final # noqa: E305
# fmt: on

MODULE_NOT_FOUND_ERROR = (
"ModuleNotFoundError" if sys.version_info[:2] >= (3, 6) else "ImportError"
Expand Down
12 changes: 10 additions & 2 deletions src/_pytest/mark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
""" generic mechanism for marking and selecting python functions. """
import typing
from typing import Optional
from typing import Union

from .legacy import matchkeyword
from .legacy import matchmark
from .structures import EMPTY_PARAMETERSET_OPTION
Expand All @@ -14,7 +18,11 @@
__all__ = ["Mark", "MarkDecorator", "MarkGenerator", "get_empty_parameterset_mark"]


def param(*values, **kw):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expanded the kwargs so I can annotate them and so that the types show up in the docs.

def param(
*values: object,
marks: "Union[MarkDecorator, typing.Collection[Union[MarkDecorator, Mark]]]" = (),
id: Optional[str] = None
) -> ParameterSet:
"""Specify a parameter in `pytest.mark.parametrize`_ calls or
:ref:`parametrized fixtures <fixture-parametrize-marks>`.

Expand All @@ -31,7 +39,7 @@ def test_eval(test_input, expected):
:keyword marks: a single mark or a list of marks to be applied to this parameter set.
:keyword str id: the id to attribute to this parameter set.
"""
return ParameterSet.param(*values, **kw)
return ParameterSet.param(*values, marks=marks, id=id)


def pytest_addoption(parser):
Expand Down
85 changes: 69 additions & 16 deletions src/_pytest/mark/structures.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,44 @@
import collections.abc
import inspect
import typing
import warnings
from collections import namedtuple
from collections.abc import MutableMapping
from typing import Iterable
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union

import attr

from .._code.source import getfslineno
from ..compat import ascii_escaped
from ..compat import NOTSET
from ..compat import NotSetType
from ..compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.outcomes import fail
from _pytest.warning_types import PytestUnknownMarkWarning

if TYPE_CHECKING:
from _pytest.python import FunctionDefinition


EMPTY_PARAMETERSET_OPTION = "empty_parameter_set_mark"


def istestfunc(func):
def istestfunc(func) -> bool:
return (
hasattr(func, "__call__")
and getattr(func, "__name__", "<lambda>") != "<lambda>"
)


def get_empty_parameterset_mark(config, argnames, func):
def get_empty_parameterset_mark(
config: Config, argnames: Sequence[str], func
) -> "MarkDecorator":
from ..nodes import Collector

requested_mark = config.getini(EMPTY_PARAMETERSET_OPTION)
Expand All @@ -49,16 +61,33 @@ def get_empty_parameterset_mark(config, argnames, func):
fs,
lineno,
)
return mark(reason=reason)


class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
# Type ignored because MarkDecorator.__call__() is a bit tough to
# annotate ATM.
return mark(reason=reason) # type: ignore[no-any-return] # noqa: F723


class ParameterSet(
NamedTuple(
"ParameterSet",
[
("values", Sequence[Union[object, NotSetType]]),
("marks", "typing.Collection[Union[MarkDecorator, Mark]]"),
("id", Optional[str]),
],
)
):
@classmethod
def param(cls, *values, marks=(), id=None):
def param(
cls,
*values: object,
marks: "Union[MarkDecorator, typing.Collection[Union[MarkDecorator, Mark]]]" = (),
id: Optional[str] = None
) -> "ParameterSet":
if isinstance(marks, MarkDecorator):
marks = (marks,)
else:
assert isinstance(marks, (tuple, list, set))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generalized this a bit, otherwise type annotation is very ugly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But turns out py3.5 doesn't have collections.abc.Collection, so I changed the runtime check to collections.abc.Sequence | set, but left the type as Collection. So there will be a slight mismatch until we can drop py3.5.

# TODO(py36): Change to collections.abc.Collection.
assert isinstance(marks, (collections.abc.Sequence, set))

if id is not None:
if not isinstance(id, str):
Expand All @@ -69,7 +98,11 @@ def param(cls, *values, marks=(), id=None):
return cls(values, marks, id)

@classmethod
def extract_from(cls, parameterset, force_tuple=False):
def extract_from(
cls,
parameterset: Union["ParameterSet", Sequence[object], object],
force_tuple: bool = False,
) -> "ParameterSet":
"""
:param parameterset:
a legacy style parameterset that may or may not be a tuple,
Expand All @@ -85,10 +118,20 @@ def extract_from(cls, parameterset, force_tuple=False):
if force_tuple:
return cls.param(parameterset)
else:
return cls(parameterset, marks=[], id=None)
# TODO: Refactor to fix this type-ignore. Currently the following
# type-checks but crashes:
#
# @pytest.mark.parametrize(('x', 'y'), [1, 2])
# def test_foo(x, y): pass
return cls(parameterset, marks=[], id=None) # type: ignore[arg-type] # noqa: F821

@staticmethod
def _parse_parametrize_args(argnames, argvalues, *args, **kwargs):
def _parse_parametrize_args(
argnames: Union[str, List[str], Tuple[str, ...]],
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
*args,
**kwargs
) -> Tuple[Union[List[str], Tuple[str, ...]], bool]:
if not isinstance(argnames, (tuple, list)):
argnames = [x.strip() for x in argnames.split(",") if x.strip()]
force_tuple = len(argnames) == 1
Expand All @@ -97,13 +140,23 @@ def _parse_parametrize_args(argnames, argvalues, *args, **kwargs):
return argnames, force_tuple

@staticmethod
def _parse_parametrize_parameters(argvalues, force_tuple):
def _parse_parametrize_parameters(
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
force_tuple: bool,
) -> List["ParameterSet"]:
return [
ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues
]

@classmethod
def _for_parametrize(cls, argnames, argvalues, func, config, function_definition):
def _for_parametrize(
cls,
argnames: Union[str, List[str], Tuple[str, ...]],
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
func,
config: Config,
function_definition: "FunctionDefinition",
) -> Tuple[Union[List[str], Tuple[str, ...]], List["ParameterSet"]]:
argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues)
parameters = cls._parse_parametrize_parameters(argvalues, force_tuple)
del argvalues
Expand Down Expand Up @@ -357,7 +410,7 @@ def __getattr__(self, name: str) -> MarkDecorator:
MARK_GEN = MarkGenerator()


class NodeKeywords(MutableMapping):
class NodeKeywords(collections.abc.MutableMapping):
def __init__(self, node):
self.node = node
self.parent = node.parent
Expand Down