Skip to content

Commit

Permalink
feat: support keyword arguments in marker expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
lovetheguitar committed Jun 21, 2024
1 parent f426c0b commit 08c95e3
Show file tree
Hide file tree
Showing 4 changed files with 308 additions and 17 deletions.
35 changes: 28 additions & 7 deletions src/_pytest/mark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Generic mechanism for marking and selecting python functions."""

import collections
import dataclasses
from typing import AbstractSet
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -181,7 +183,9 @@ def from_item(cls, item: "Item") -> "KeywordMatcher":

return cls(mapped_names)

def __call__(self, subname: str) -> bool:
def __call__(self, subname: str, /, **kwargs: object) -> bool:
if kwargs:
raise UsageError("Keyword expressions do not support call parameters.")
subname = subname.lower()
names = (name.lower() for name in self._names)

Expand Down Expand Up @@ -211,24 +215,41 @@ def deselect_by_keyword(items: "List[Item]", config: Config) -> None:
items[:] = remaining


NOT_NONE_SENTINEL = object()


@dataclasses.dataclass
class MarkMatcher:
"""A matcher for markers which are present.
Tries to match on any marker names, attached to the given colitem.
"""

__slots__ = ("own_mark_names",)
__slots__ = ("own_mark_name_mapping",)

own_mark_names: AbstractSet[str]
own_mark_name_mapping: Dict[str, List["Mark"]]

@classmethod
def from_item(cls, item: "Item") -> "MarkMatcher":
mark_names = {mark.name for mark in item.iter_markers()}
return cls(mark_names)
mark_name_mapping = collections.defaultdict(list)
for mark in item.iter_markers():
mark_name_mapping[mark.name].append(mark)
return cls(mark_name_mapping)

def __call__(self, name: str, /, **kwargs: object) -> bool:
if not (matches := self.own_mark_name_mapping.get(name, [])):
return False

if not kwargs:
return True

def __call__(self, name: str) -> bool:
return name in self.own_mark_names
for mark in matches:
if all(
mark.kwargs.get(k, NOT_NONE_SENTINEL) == v for k, v in kwargs.items()
):
return True

return False


def deselect_by_mark(items: "List[Item]", config: Config) -> None:
Expand Down
112 changes: 103 additions & 9 deletions src/_pytest/mark/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
expression: expr? EOF
expr: and_expr ('or' and_expr)*
and_expr: not_expr ('and' not_expr)*
not_expr: 'not' not_expr | '(' expr ')' | ident
not_expr: 'not' not_expr | '(' expr ')' | ident ( '(' name '=' value ( ', ' name '=' value )* ')')*
ident: (\w|:|\+|-|\.|\[|\]|\\|/)+
The semantics are:
Expand All @@ -18,14 +19,17 @@
import ast
import dataclasses
import enum
import keyword
import re
import types
from typing import Callable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import NoReturn
from typing import Optional
from typing import Protocol
from typing import Sequence
from typing import Union


__all__ = [
Expand All @@ -42,6 +46,9 @@ class TokenType(enum.Enum):
NOT = "not"
IDENT = "identifier"
EOF = "end of input"
EQUAL = "="
STRING = "str"
COMMA = ","


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -85,6 +92,27 @@ def lex(self, input: str) -> Iterator[Token]:
elif input[pos] == ")":
yield Token(TokenType.RPAREN, ")", pos)
pos += 1
elif input[pos] == "=":
yield Token(TokenType.EQUAL, "=", pos)
pos += 1
elif input[pos] == ",":
yield Token(TokenType.COMMA, ",", pos)
pos += 1
elif (quote_char := input[pos]) == "'" or input[pos] == '"':
quote_position = input[pos + 1 :].find(quote_char)
if quote_position == -1:
raise ParseError(
pos + 1,
f'closing quote "{quote_char}" is missing',
)
value = input[pos : pos + 2 + quote_position]
if "\\" in value:
raise ParseError(
pos + 1,
"escaping not supported in marker expression",
)
yield Token(TokenType.STRING, value, pos)
pos += len(value)
else:
match = re.match(r"(:?\w|:|\+|-|\.|\[|\]|\\|/)+", input[pos:])
if match:
Expand Down Expand Up @@ -165,18 +193,84 @@ def not_expr(s: Scanner) -> ast.expr:
return ret
ident = s.accept(TokenType.IDENT)
if ident:
return ast.Name(IDENT_PREFIX + ident.value, ast.Load())
name = ast.Name(IDENT_PREFIX + ident.value, ast.Load())
if s.accept(TokenType.LPAREN):
ret = ast.Call(func=name, args=[], keywords=all_kwargs(s))
s.accept(TokenType.RPAREN, reject=True)
else:
ret = name
return ret

s.reject((TokenType.NOT, TokenType.LPAREN, TokenType.IDENT))


class MatcherAdapter(Mapping[str, bool]):
BUILTIN_MATCHERS = {"True": True, "False": False, "None": None}


def single_kwarg(s: Scanner) -> ast.keyword:
keyword_name = s.accept(TokenType.IDENT, reject=True)
assert keyword_name is not None # for mypy
if not keyword_name.value.isidentifier() or keyword.iskeyword(keyword_name.value):
raise ParseError(
keyword_name.pos + 1,
f'unexpected character/s "{keyword_name.value}"',
)
s.accept(TokenType.EQUAL, reject=True)

if value_token := s.accept(TokenType.STRING):
value: Union[str, int, bool, None] = value_token.value[1:-1] # strip quotes
else:
value_token = s.accept(TokenType.IDENT, reject=True)
assert value_token is not None # for mypy
if (
(number := value_token.value).isdigit()
or number.startswith("-")
and number[1:].isdigit()
):
value = int(number)
elif value_token.value in BUILTIN_MATCHERS:
value = BUILTIN_MATCHERS[value_token.value]
else:
raise ParseError(
value_token.pos + 1,
f'unexpected character/s "{value_token.value}"',
)

ret = ast.keyword(keyword_name.value, ast.Constant(value))
return ret


def all_kwargs(s: Scanner) -> List[ast.keyword]:
ret = [single_kwarg(s)]
while s.accept(TokenType.COMMA):
ret.append(single_kwarg(s))
return ret


class MatcherCall(Protocol):
def __call__(self, name: str, /, **kwargs: object) -> bool: ...


@dataclasses.dataclass
class MatcherNameAdapter:
matcher: MatcherCall
name: str

def __bool__(self) -> bool:
return self.matcher(self.name)

def __call__(self, **kwargs: object) -> bool:
return self.matcher(self.name, **kwargs)


class MatcherAdapter(Mapping[str, MatcherNameAdapter]):
"""Adapts a matcher function to a locals mapping as required by eval()."""

def __init__(self, matcher: Callable[[str], bool]) -> None:
def __init__(self, matcher: MatcherCall) -> None:
self.matcher = matcher

def __getitem__(self, key: str) -> bool:
return self.matcher(key[len(IDENT_PREFIX) :])
def __getitem__(self, key: str) -> MatcherNameAdapter:
return MatcherNameAdapter(matcher=self.matcher, name=key[len(IDENT_PREFIX) :])

def __iter__(self) -> Iterator[str]:
raise NotImplementedError()
Expand Down Expand Up @@ -210,7 +304,7 @@ def compile(self, input: str) -> "Expression":
)
return Expression(code)

def evaluate(self, matcher: Callable[[str], bool]) -> bool:
def evaluate(self, matcher: MatcherCall) -> bool:
"""Evaluate the match expression.
:param matcher:
Expand All @@ -219,5 +313,5 @@ def evaluate(self, matcher: Callable[[str], bool]) -> bool:
:returns: Whether the expression matches or not.
"""
ret: bool = eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher))
ret: bool = bool(eval(self.code, {"__builtins__": {}}, MatcherAdapter(matcher)))
return ret
48 changes: 48 additions & 0 deletions testing/test_mark.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,54 @@ def test_two():
assert passed_str == expected_passed


@pytest.mark.parametrize(
("expr", "expected_passed"),
[ # TODO: improve/sort out
("car(color='red')", ["test_one"]),
("car(color='red') or car(color='blue')", ["test_one", "test_two"]),
("car and not car(temp=5)", ["test_one", "test_three"]),
("car(temp=4)", ["test_one"]),
("car(temp=4) or car(temp=5)", ["test_one", "test_two"]),
("car(temp=4) and car(temp=5)", []),
("car(temp=-5)", ["test_three"]),
("car(ac=True)", ["test_one"]),
("car(ac=False)", ["test_two"]),
("car(ac=None)", ["test_three"]), # test NOT_NONE_SENTINEL
],
ids=str,
)
def test_mark_option_with_kwargs(
expr: str, expected_passed: List[Optional[str]], pytester: Pytester
) -> None:
pytester.makepyfile(
"""
import pytest
@pytest.mark.car
@pytest.mark.car(ac=True)
@pytest.mark.car(temp=4)
@pytest.mark.car(color="red")
def test_one():
pass
@pytest.mark.car
@pytest.mark.car(ac=False)
@pytest.mark.car(temp=5)
@pytest.mark.car(color="blue")
def test_two():
pass
@pytest.mark.car
@pytest.mark.car(ac=None)
@pytest.mark.car(temp=-5)
def test_three():
pass
"""
)
rec = pytester.inline_run("-m", expr)
passed, skipped, fail = rec.listoutcomes()
passed_str = [x.nodeid.split("::")[-1] for x in passed]
assert passed_str == expected_passed


@pytest.mark.parametrize(
("expr", "expected_passed"),
[("interface", ["test_interface"]), ("not interface", ["test_nointer"])],
Expand Down
Loading

0 comments on commit 08c95e3

Please sign in to comment.