Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate CallSpec2 #6724

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 43 additions & 16 deletions src/_pytest/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
import inspect
import os
import sys
import typing
import warnings
from collections import Counter
from collections import defaultdict
from collections.abc import Sequence
from functools import partial
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import Union
Expand All @@ -35,19 +38,24 @@
from _pytest.compat import safe_getattr
from _pytest.compat import safe_isclass
from _pytest.compat import STRING_TYPES
from _pytest.compat import TYPE_CHECKING
from _pytest.config import hookimpl
from _pytest.deprecated import FUNCARGNAMES
from _pytest.mark import MARK_GEN
from _pytest.mark import ParameterSet
from _pytest.mark.structures import get_unpacked_marks
from _pytest.mark.structures import Mark
from _pytest.mark.structures import MarkDecorator
from _pytest.mark.structures import normalize_mark_list
from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.pathlib import parts
from _pytest.warning_types import PytestCollectionWarning
from _pytest.warning_types import PytestUnhandledCoroutineWarning

if TYPE_CHECKING:
from typing_extensions import Literal


def pyobj_property(name):
def get(self):
Expand Down Expand Up @@ -784,16 +792,17 @@ def hasnew(obj):


class CallSpec2:
def __init__(self, metafunc):
def __init__(self, metafunc: "Metafunc") -> None:
self.metafunc = metafunc
self.funcargs = {}
self._idlist = []
self.params = {}
self._arg2scopenum = {} # used for sorting parametrized resources
self.marks = []
self.indices = {}

def copy(self):
self.funcargs = {} # type: Dict[str, object]
self._idlist = [] # type: List[str]
self.params = {} # type: Dict[str, object]
# Used for sorting parametrized resources.
self._arg2scopenum = {} # type: Dict[str, int]
self.marks = [] # type: List[Mark]
self.indices = {} # type: Dict[str, int]

def copy(self) -> "CallSpec2":
cs = CallSpec2(self.metafunc)
cs.funcargs.update(self.funcargs)
cs.params.update(self.params)
Expand All @@ -803,25 +812,39 @@ def copy(self):
cs._idlist = list(self._idlist)
return cs

def _checkargnotcontained(self, arg):
def _checkargnotcontained(self, arg: str) -> None:
if arg in self.params or arg in self.funcargs:
raise ValueError("duplicate {!r}".format(arg))

def getparam(self, name):
def getparam(self, name: str) -> object:
try:
return self.params[name]
except KeyError:
raise ValueError(name)

@property
def id(self):
def id(self) -> str:
return "-".join(map(str, self._idlist))

def setmulti2(self, valtypes, argnames, valset, id, marks, scopenum, param_index):
def setmulti2(
self,
valtypes: "Mapping[str, Literal['params', 'funcargs']]",
argnames: typing.Sequence[str],
valset: Iterable[object],
id: str,
marks: Iterable[Union[Mark, MarkDecorator]],
scopenum: int,
param_index: int,
) -> None:
for arg, val in zip(argnames, valset):
self._checkargnotcontained(arg)
valtype_for_arg = valtypes[arg]
getattr(self, valtype_for_arg)[arg] = val
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to tweak the code here a bit, mypy can't understand getattr. I also think it's better to avoid it in general.

To show that it is safe, I used Literal here and in the source below Metafunc._resolve_arg_value_types.

if valtype_for_arg == "params":
self.params[arg] = val
elif valtype_for_arg == "funcargs":
self.funcargs[arg] = val
else: # pragma: no cover
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Small tip, no need to change)

When I encounter these situations where we have a fixed number of if/elifs and an else clause as an error, I opt to have an explicit else: clause for the last choice:

            if valtype_for_arg == "params":
                self.params[arg] = val
            elif valtype_for_arg == "funcargs":
                self.funcargs[arg] = val
            else:  # pragma: no cover
                assert False, "Unhandled valtype for arg: {}".format(valtype_for_arg)

Becomes:

            if valtype_for_arg == "params":
                self.params[arg] = val
            else:
                assert valtype_for_arg == "funcargs", "Unhandled valtype for arg: {}".format(valtype_for_arg)
                self.funcargs[arg] = val

It is shorter and avoids having to write a pragma: nocover 😁

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion!

There is actually a reason why I prefer the form I wrote -- it allows to perform exhaustiveness checking. Consider for example that a new value is added to the valtype_for_arg Literal type. Then ideally, mypy would raise an error if I didn't remember to add a case for it in the conditional.

For enums, this is already possible, see python/mypy#5818. In my own code I rely on this greatly.

For Literals, this is not yet supported by mypy, but hopefully will be in the future: python/mypy#6366

Regarding to no cover, WDYT about adding assert False to the coverage exclude_lines list?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise NotimplementedError()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think an assert is more appropriate than NotimplementedError for this. assert means "internal assumption broken", while NotimplementedError might imply "not implemented intentionally" which is not the case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see it more like "might be implemented later", like you've decribed it yourself ("a new value is added").
But then maybe lets add ^\s*raise AssertionError\b to the exclude list? I think it is good to be explicit here, to not ignore any assert 0, but your assert False might work for this also already.

assert False, "Unhandled valtype for arg: {}".format(valtype_for_arg)
self.indices[arg] = param_index
self._arg2scopenum[arg] = scopenum
self._idlist.append(id)
Expand Down Expand Up @@ -1042,7 +1065,9 @@ def _validate_ids(self, ids, parameters, func_name):
)
return new_ids

def _resolve_arg_value_types(self, argnames: List[str], indirect) -> Dict[str, str]:
def _resolve_arg_value_types(
self, argnames: List[str], indirect
) -> Dict[str, "Literal['params', 'funcargs']"]:
"""Resolves if each parametrized argument must be considered a parameter to a fixture or a "funcarg"
to the function, based on the ``indirect`` parameter of the parametrized() call.

Expand All @@ -1054,7 +1079,9 @@ def _resolve_arg_value_types(self, argnames: List[str], indirect) -> Dict[str, s
* "funcargs" if the argname should be a parameter to the parametrized test function.
"""
if isinstance(indirect, bool):
valtypes = dict.fromkeys(argnames, "params" if indirect else "funcargs")
valtypes = dict.fromkeys(
argnames, "params" if indirect else "funcargs"
) # type: Dict[str, Literal["params", "funcargs"]]
elif isinstance(indirect, Sequence):
valtypes = dict.fromkeys(argnames, "funcargs")
for arg in indirect:
Expand Down