Skip to content

Draft proposal for ExpectedExceptionGroup #11656

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 222 additions & 13 deletions src/_pytest/python_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import math
import pprint
import re
import sys
from collections.abc import Collection
from collections.abc import Sized
from decimal import Decimal
Expand All @@ -10,6 +12,8 @@
from typing import cast
from typing import ContextManager
from typing import final
from typing import Generic
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
Expand All @@ -28,6 +32,10 @@

if TYPE_CHECKING:
from numpy import ndarray
from typing_extensions import TypeAlias, TypeGuard

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup


def _non_numeric_type_error(value, at: Optional[str]) -> TypeError:
Expand Down Expand Up @@ -780,6 +788,149 @@ def _as_numpy_array(obj: object) -> Optional["ndarray"]:
# builtin pytest.raises helper

E = TypeVar("E", bound=BaseException)
E2 = TypeVar("E2", bound=BaseException)


class Matcher(Generic[E]):
def __init__(
self,
exception_type: Optional[Type[E]] = None,
match: Optional[Union[str, Pattern[str]]] = None,
check: Optional[Callable[[E], bool]] = None,
):
if exception_type is None and match is None and check is None:
raise ValueError("You must specify at least one parameter to match on.")
self.exception_type = exception_type
self.match = match
self.check = check

def matches(self, exception: E) -> "TypeGuard[E]":
if self.exception_type is not None and not isinstance(
exception, self.exception_type
):
return False
if self.match is not None and not re.search(self.match, str(exception)):
return False
if self.check is not None and not self.check(exception):
return False
return True


# TODO: rename if kept, EEE[E] looks like gibberish
EEE: "TypeAlias" = Union[Matcher[E], Type[E], "ExpectedExceptionGroup[E]"]

if TYPE_CHECKING:
SuperClass = BaseExceptionGroup
else:
SuperClass = Generic


# it's unclear if
# `ExpectedExceptionGroup(ValueError, strict=False).matches(ValueError())`
# should return True. It matches the behaviour of expect*, but is maybe better handled
# by the end user doing pytest.raises((ValueError, ExpectedExceptionGroup(ValueError)))
@final
class ExpectedExceptionGroup(SuperClass[E]):
# TODO: overload to disallow nested exceptiongroup with strict=False
# @overload
# def __init__(self, exceptions: Union[Matcher[E], Type[E]], *args: Union[Matcher[E],
# Type[E]], strict: Literal[False]): ...

# @overload
# def __init__(self, exceptions: EEE[E], *args: EEE[E], strict: bool = True): ...

def __init__(
self,
exceptions: Union[Type[E], E, Matcher[E]],
*args: Union[Type[E], E, Matcher[E]],
strict: bool = True,
):
# could add parameter `notes: Optional[Tuple[str, Pattern[str]]] = None`
self.expected_exceptions = (exceptions, *args)
self.strict = strict

for exc in self.expected_exceptions:
if not isinstance(exc, (Matcher, ExpectedExceptionGroup)) and not (
isinstance(exc, type) and issubclass(exc, BaseException)
):
raise ValueError(
"Invalid argument {exc} must be exception type, Matcher, or ExpectedExceptionGroup."
)
if isinstance(exc, ExpectedExceptionGroup) and not strict:
raise ValueError(
"You cannot specify a nested structure inside an ExpectedExceptionGroup with strict=False"
)

def _unroll_exceptions(
self, exceptions: Iterable[BaseException]
) -> Iterable[BaseException]:
res: list[BaseException] = []
for exc in exceptions:
if isinstance(exc, BaseExceptionGroup):
res.extend(self._unroll_exceptions(exc.exceptions))

else:
res.append(exc)
return res

def matches(
self,
exc_val: Optional[BaseException],
) -> "TypeGuard[BaseExceptionGroup[E]]":
if exc_val is None:
return False
if not isinstance(exc_val, BaseExceptionGroup):
return False
if not len(exc_val.exceptions) == len(self.expected_exceptions):
return False
remaining_exceptions = list(self.expected_exceptions)
actual_exceptions: Iterable[BaseException] = exc_val.exceptions
if not self.strict:
actual_exceptions = self._unroll_exceptions(actual_exceptions)

# it should be possible to get ExpectedExceptionGroup.matches typed so as not to
# need these type: ignores, but I'm not sure that's possible while also having it
# transparent for the end user.
for e in actual_exceptions:
for rem_e in remaining_exceptions:
# TODO: how to print string diff on mismatch?
# Probably accumulate them, and then if fail, print them
# Further QoL would be to print how the exception structure differs on non-match
if (
(isinstance(rem_e, type) and isinstance(e, rem_e))
or (
isinstance(e, BaseExceptionGroup)
and isinstance(rem_e, ExpectedExceptionGroup)
and rem_e.matches(e)
)
or (
isinstance(rem_e, Matcher)
and rem_e.matches(e) # type: ignore[arg-type]
)
):
remaining_exceptions.remove(rem_e) # type: ignore[arg-type]
break
else:
return False
return True

# def __str__(self) -> str:
# return f"ExceptionGroup{self.expected_exceptions}"
# str(tuple(...)) seems to call repr
def __repr__(self) -> str:
# TODO: [Base]ExceptionGroup
return f"ExceptionGroup{self.expected_exceptions}"


@overload
def raises(
expected_exception: Union[
ExpectedExceptionGroup[E], Tuple[ExpectedExceptionGroup[E], ...]
],
*,
match: Optional[Union[str, Pattern[str]]] = ...,
) -> "RaisesContext[ExpectedExceptionGroup[E]]":
...


@overload
Expand All @@ -791,6 +942,17 @@ def raises(
...


#
#
# @overload
# def raises(
# expected_exception: Tuple[Union[Type[E], ExpectedExceptionGroup[E2]], ...],
# *,
# match: Optional[Union[str, Pattern[str]]] = ...,
# ) -> "RaisesContext[Union[E, BaseExceptionGroup[E2]]]":
# ...


@overload
def raises( # noqa: F811
expected_exception: Union[Type[E], Tuple[Type[E], ...]],
Expand All @@ -801,9 +963,20 @@ def raises( # noqa: F811
...


def raises( # noqa: F811
expected_exception: Union[Type[E], Tuple[Type[E], ...]], *args: Any, **kwargs: Any
) -> Union["RaisesContext[E]", _pytest._code.ExceptionInfo[E]]:
def raises(
expected_exception: Union[
Type[E],
ExpectedExceptionGroup[E2],
Tuple[Union[Type[E], ExpectedExceptionGroup[E2]], ...],
],
*args: Any,
**kwargs: Any,
) -> Union[
"RaisesContext[E]",
"RaisesContext[BaseExceptionGroup[E2]]",
"RaisesContext[Union[E, BaseExceptionGroup[E2]]]",
_pytest._code.ExceptionInfo[E],
]:
r"""Assert that a code block/function call raises an exception type, or one of its subclasses.

:param typing.Type[E] | typing.Tuple[typing.Type[E], ...] expected_exception:
Expand Down Expand Up @@ -952,13 +1125,20 @@ def raises( # noqa: F811
f"Raising exceptions is already understood as failing the test, so you don't need "
f"any special code to say 'this should never raise an exception'."
)
if isinstance(expected_exception, type):
expected_exceptions: Tuple[Type[E], ...] = (expected_exception,)
if isinstance(expected_exception, (type, ExpectedExceptionGroup)):
expected_exception_tuple: Tuple[
Union[Type[E], ExpectedExceptionGroup[E2]], ...
] = (expected_exception,)
else:
expected_exceptions = expected_exception
for exc in expected_exceptions:
if not isinstance(exc, type) or not issubclass(exc, BaseException):
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
expected_exception_tuple = expected_exception
for exc in expected_exception_tuple:
if (
not isinstance(exc, type) or not issubclass(exc, BaseException)
) and not isinstance(exc, ExpectedExceptionGroup):
msg = ( # type: ignore[unreachable]
"expected exception must be a BaseException "
"type or ExpectedExceptionGroup instance, not {}"
)
not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__
raise TypeError(msg.format(not_a))

Expand All @@ -971,14 +1151,22 @@ def raises( # noqa: F811
msg += ", ".join(sorted(kwargs))
msg += "\nUse context-manager form instead?"
raise TypeError(msg)
return RaisesContext(expected_exception, message, match)
# the ExpectedExceptionGroup -> BaseExceptionGroup swap necessitates an ignore
return RaisesContext(expected_exception, message, match) # type: ignore[misc]
else:
func = args[0]

for exc in expected_exception_tuple:
if isinstance(exc, ExpectedExceptionGroup):
raise TypeError(
"Only contextmanager form is supported for ExpectedExceptionGroup"
)

if not callable(func):
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
try:
func(*args[1:], **kwargs)
except expected_exception as e:
except expected_exception as e: # type: ignore[misc] # TypeError raised for any ExpectedExceptionGroup
return _pytest._code.ExceptionInfo.from_exception(e)
fail(message)

Expand All @@ -987,11 +1175,14 @@ def raises( # noqa: F811
raises.Exception = fail.Exception # type: ignore


EE: "TypeAlias" = Union[Type[E], "ExpectedExceptionGroup[E]"]


@final
class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __init__(
self,
expected_exception: Union[Type[E], Tuple[Type[E], ...]],
expected_exception: Union[EE[E], Tuple[EE[E], ...]],
message: str,
match_expr: Optional[Union[str, Pattern[str]]] = None,
) -> None:
Expand All @@ -1014,8 +1205,26 @@ def __exit__(
if exc_type is None:
fail(self.message)
assert self.excinfo is not None
if not issubclass(exc_type, self.expected_exception):

if isinstance(self.expected_exception, ExpectedExceptionGroup):
if not self.expected_exception.matches(exc_val):
return False
elif isinstance(self.expected_exception, tuple):
for expected_exc in self.expected_exception:
if (
isinstance(expected_exc, ExpectedExceptionGroup)
and expected_exc.matches(exc_val)
) or (
isinstance(expected_exc, type)
and issubclass(exc_type, expected_exc)
):
break
else: # pragma: no cover
# this would've been caught on initialization of pytest.raises()
return False
elif not issubclass(exc_type, self.expected_exception):
return False

# Cast to narrow the exception type now that it's verified.
exc_info = cast(Tuple[Type[E], E, TracebackType], (exc_type, exc_val, exc_tb))
self.excinfo.fill_unfilled(exc_info)
Expand Down
Loading