diff --git a/src/typelib/binding.py b/src/typelib/binding.py new file mode 100644 index 0000000..f2ef192 --- /dev/null +++ b/src/typelib/binding.py @@ -0,0 +1,681 @@ +"""Utilities for automatic unmarshalling of inputs for a callable object's signature.""" + +from __future__ import annotations + +import abc +import dataclasses +import functools +import inspect +import typing as tp + +from typelib import classes, compat, inspection, unmarshal + +P = compat.ParamSpec("P") +R = tp.TypeVar("R") +BindingT = compat.TypeAliasType( + "BindingT", + "tp.MutableMapping[str | int, unmarshal.AbstractUnmarshaller]", +) + + +def bind(obj: tp.Callable[P, R]) -> BoundRoutine[P, R]: + """Create a type-enforced, bound routine for a callable object. + + Notes: + In contrast to :py:func:`~typelib.binding.wrap`, this function creates a new, + type-enforced :py:class:`~typelib.binding.BoundRoutine` instance. Rather than + masquerading as the given :py:param:`obj`, we encapsulate it in the routine + instance, which is more obvious and provides developers with the ability to + side-step type enforcement when it is deemed unnecessary, which should be + most of the time if your code is strongly typed and statically analyzed. + + Args: + obj: The callable object to bind. + """ + binding: AbstractBinding[P] = _get_binding(obj) + routine: BoundRoutine[P, R] = BoundRoutine( + call=obj, + binding=binding, + ) + return routine + + +def wrap(obj: tp.Callable[P, R]) -> tp.Callable[..., R]: + """Wrap a callable object for runtime type coercion of inputs. + + Notes: + If a class is given, we will attempt to wrap the init method. + + Warnings: + This is a useful feature. It is also very *surprising*. By wrapping a callable + in this decorator, we end up with *implicit* behavior that's not obvious to the + caller or a fellow developer. + + You're encouraged to prefer :py:func:`~typelib.binding.bind` for similar + functionality, less the implicit nature, especially when a class is given. + + Args: + obj: The callable object to wrap. + Maybe be a function, a callable class instance, or a class. + """ + + binding: AbstractBinding[P] = _get_binding(obj) + + if inspect.isclass(obj): + obj.__init__ = wrap(obj.__init__) + return obj + + @functools.wraps(obj) # type: ignore[arg-type] + def binding_wrapper(*args: tp.Any, __binding=binding, **kwargs: tp.Any) -> R: + bargs, bkwargs = __binding(args, kwargs) + return obj(*bargs, **bkwargs) + + return binding_wrapper + + +@classes.slotted(dict=False, weakref=True) +@dataclasses.dataclass +class BoundRoutine(tp.Generic[P, R]): + call: tp.Callable[P, R] + binding: AbstractBinding[P] + + def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> R: + bargs, bkwargs = self.binding(args=args, kwargs=kwargs) + return self.call(*bargs, **bkwargs) + + +@compat.cache +def _get_binding(obj: tp.Callable) -> AbstractBinding: + sig = inspection.cached_signature(obj) + params = sig.parameters + binding: BindingT = {} + has_pos_only = False + has_kwd_only = False + has_args = False + has_kwargs = False + has_pos_or_kwd = False + max_pos: int | None = None + varkwd: unmarshal.AbstractUnmarshaller | None = None + varpos: unmarshal.AbstractUnmarshaller | None = None + for i, (name, param) in enumerate(params.items()): + unmarshaller: unmarshal.AbstractUnmarshaller = unmarshal.unmarshaller( + param.annotation + ) + binding[name] = binding[i] = unmarshaller + has_kwd_only = has_kwd_only or param.kind == inspect.Parameter.KEYWORD_ONLY + has_pos_or_kwd = ( + has_pos_or_kwd or param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + ) + if param.kind == param.POSITIONAL_ONLY: + has_pos_only = True + max_pos = i + continue + if param.kind == param.VAR_KEYWORD: + has_kwargs = True + varkwd = unmarshaller + continue + if param.kind == param.VAR_POSITIONAL: + has_args = True + max_pos = i - 1 + varpos = unmarshaller + + truth = _Truth( + has_pos_only=has_pos_only, + has_kwd_only=has_kwd_only, + has_args=has_args, + has_kwargs=has_kwargs, + has_pos_or_kwd=has_pos_or_kwd, + ) + binding_cls = _BINDING_CLS_MATRIX[truth] + return binding_cls( + binding=binding, + signature=sig, + varkwd=varkwd, + varpos=varpos, + startpos=max_pos + 1 if max_pos is not None else None, + ) + + +class AbstractBinding(abc.ABC, tp.Generic[P]): + """The abstract base class for all type-enforced bindings. + + Notes: + "Bindings" are callables which leverage the type annotations in a signature to + unmarshal inputs. + + We differentiate each subclass based upon the possible combinations of parameter + kinds: + - Positional-only arguments + - Keyword-only arguments + - Positional-or-Keyword arguments + - Variable-positional arguments (`*args`) + - Variable-keyword arguments (`**kwargs`) + + This allows us to micro-optimize the call for each subclasses to exactly what is + necessary for the that combination, which can lead to a significant speedup at + runtime. + """ + + __slots__ = ("binding", "signature", "varkwd", "varpos", "startpos") + + def __init__( + self, + *, + signature: inspect.Signature, + binding: BindingT, + varkwd: unmarshal.AbstractUnmarshaller | None = None, + varpos: unmarshal.AbstractUnmarshaller | None = None, + startpos: int | None = None, + ): + """Constructor. + + Args: + signature: The signature for the binding. + binding: A mapping of parameter names and positions to unmarshallers. + This accounts for positional, keyword, or positional-or-keyword arguments. + varkwd: The unmarshaller for var-keyword arguments (`**kwargs`). + varpos: The unmarshaller for var-positional arguments (`*args`). + startpos: The start position of var-positional arguments (`*args`). + This accounts for the fact that var-positional comes after positional-only. + """ + self.signature = signature + self.binding = binding + self.varkwd = varkwd + self.varpos = varpos + self.startpos = startpos + + def __repr__(self): + return f"<{self.__class__.__name__}(signature={self.signature})>" + + @abc.abstractmethod + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + """Inspect the given :py:param:`args` and :py:param:`kwargs` and unmarshal them. + + Args: + args: The positional arguments. + kwargs: The keyword arguments. + """ + + +class AnyParamKindBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize key attributes + binding = self.binding + varpos: unmarshal.AbstractUnmarshaller = self.varpos + varkwd: unmarshal.AbstractUnmarshaller = self.varkwd + # Split the supplied args into the positional and the var-args + # Implementation note: if there are positional args and var-args, + # positional args will always be first. + posargs = args[: self.startpos] + varargs = args[self.startpos :] + # Unmarshal the positional arguments. + umargs = ( + *(binding[i](v) if i in binding else v for i, v in enumerate(posargs)), + *(varpos(v) for v in varargs), + ) + # Unmarshal the keyword arguments. + umkwargs = {k: binding.get(k, varkwd)(v) for k, v in kwargs.items()} + return umargs, umkwargs + + +class PosArgsKwargsBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize key attributes + binding = self.binding + varpos: unmarshal.AbstractUnmarshaller = self.varpos + varkwd: unmarshal.AbstractUnmarshaller = self.varkwd + # Split the supplied args into the positional and the var-args + # Implementation note: if there are positional args and var-args, + # positional args will always be first. + posargs = args[: self.startpos] + varargs = args[self.startpos :] + umargs = ( + *(binding[i](v) if i in binding else v for i, v in enumerate(posargs)), + *(varpos(v) for v in varargs), + ) + umkwargs = {k: varkwd(v) for k, v in kwargs.items()} + return umargs, umkwargs + + +class PosKwdKwargsBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize key attributes + binding = self.binding + varkwd: unmarshal.AbstractUnmarshaller = self.varkwd + # Unmarshal the args + umargs = (*(binding[i](v) if i in binding else v for i, v in enumerate(args)),) + # Unmarshal the keyword arguments. + umkwargs = {k: binding.get(k, varkwd)(v) for k, v in kwargs.items()} + return umargs, umkwargs + + +class PosKwdArgsBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize key attributes + binding = self.binding + varpos: unmarshal.AbstractUnmarshaller = self.varpos + # Split the supplied args into the positional and the var-args + # Implementation note: if there are positional args and var-args, + # positional args will always be first. + posargs = args[: self.startpos] + varargs = args[self.startpos :] + # Unmarshal the positional arguments. + umargs = ( + *(binding[i](v) if i in binding else v for i, v in enumerate(posargs)), + *(varpos(v) for v in varargs), + ) + # Unmarshal the keyword arguments. + umkwargs = {k: binding[k](v) if k in binding else v for k, v in kwargs.items()} + return umargs, umkwargs + + +class PosKwargsBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize key attributes + binding = self.binding + varkwd: unmarshal.AbstractUnmarshaller = self.varkwd + # Unmarshal the args + umargs = (*(binding[i](v) if i in binding else v for i, v in enumerate(args)),) + # Unmarshal the keyword arguments. + umkwargs = {k: varkwd(v) for k, v in kwargs.items()} + return umargs, umkwargs + + +class PosKwdBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize key attributes + binding = self.binding + # Unmarshal the args + umargs = (*(binding[i](v) if i in binding else v for i, v in enumerate(args)),) + # Unmarshal the keyword arguments. + umkwargs = {k: binding[k](v) if k in binding else k for k, v in kwargs.items()} + return umargs, umkwargs + + +class PosArgsBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize key attributes + binding = self.binding + varpos: unmarshal.AbstractUnmarshaller = self.varpos + # Split the supplied args into the positional and the var-args + # Implementation note: if there are positional args and var-args, + # positional args will always be first. + posargs = args[: self.startpos] + varargs = args[self.startpos :] + umargs = ( + *(binding[i](v) if i in binding else v for i, v in enumerate(posargs)), + *(varpos(v) for v in varargs), + ) + return umargs, kwargs + + +class PosBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize key attributes + binding = self.binding + # Unmarshal the args + umargs = (*(binding[i](v) if i in binding else v for i, v in enumerate(args)),) + return umargs, kwargs + + +class KwdArgsKwargsBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize the key attributes + binding = self.binding + varpos: unmarshal.AbstractUnmarshaller = self.varpos + varkwd: unmarshal.AbstractUnmarshaller = self.varkwd + # Unmarshal the positional args. + umargs = (*(varpos(v) for v in args),) + # Unmarshal the keyword args. + umkwargs = {k: binding.get(k, varkwd)(v) for k, v in kwargs.items()} + return umargs, umkwargs + + +class KwdArgsBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize the key attributes + binding = self.binding + varpos: unmarshal.AbstractUnmarshaller = self.varpos + # Unmarshal the positional arguments + umargs = (*(varpos(v) for v in args),) + # Unmarshal the keyword arguments. + umkwargs = {k: binding[k](v) if k in binding else k for k, v in kwargs.items()} + return umargs, umkwargs + + +class KwdKwargsBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + binding = self.binding + varkwd: unmarshal.AbstractUnmarshaller = self.varkwd + # Unmarshal the keyword arguments. + umkwargs = {k: binding.get(k, varkwd)(v) for k, v in kwargs.items()} + return args, umkwargs + + +class KwdBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + binding = self.binding + # Unmarshal the keyword arguments. + umkwargs = {k: binding[k](v) if k in binding else k for k, v in kwargs.items()} + return args, umkwargs + + +class ArgsKwargsBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize the key attributes + varpos: unmarshal.AbstractUnmarshaller = self.varpos + varkwd: unmarshal.AbstractUnmarshaller = self.varkwd + # Unmarshal the positional arguments. + umargs = (*(varpos(v) for v in args),) + # Unmarshal the keyword arguments. + umkwargs = {k: varkwd(v) for k, v in kwargs.items()} + return umargs, umkwargs + + +class KwargsBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize the key attributes + varkwd: unmarshal.AbstractUnmarshaller = self.varkwd + # Unmarshal the keyword arguments. + umkwargs = {k: varkwd(v) for k, v in kwargs.items()} + return args, umkwargs + + +class ArgsBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize the key attributes + varpos: unmarshal.AbstractUnmarshaller = self.varpos + # Unmarshal the positional arguments. + umargs = (*(varpos(v) for v in args),) + return umargs, kwargs + + +class PosOrKwdBinding(AbstractBinding[P], tp.Generic[P]): + def __call__( + self, args: tuple[tp.Any], kwargs: dict[str, tp.Any] + ) -> tuple[P.args, P.kwargs]: + # Localize the key attributes + binding = self.binding + # Unmarshal the positional args. + umargs = (*(binding[i](v) if i in binding else v for i, v in enumerate(args)),) + # Unmarshal the keyword arguments. + umkwargs = {k: binding[k](v) if k in binding else k for k, v in kwargs.items()} + return umargs, umkwargs + + +class _Truth(tp.NamedTuple): + """Internal truth entry for determining the binding algorithm. + + Attributes: + has_pos_only: Whether the signature contains positional-only arguments. + has_kwd_only: Whether the signature contains keyword-only arguments. + has_args: Whether the signature contains variadic-positional arguments. + has_kwargs: Whether the signature contains variadic-keyword arguments. + has_pos_or_kwd: Whether the signature contains positional-or-keyword arguments. + """ + + has_pos_only: bool = False + has_kwd_only: bool = False + has_args: bool = False + has_kwargs: bool = False + has_pos_or_kwd: bool = False + + +_BINDING_CLS_MATRIX: dict[_Truth, type[AbstractBinding]] = { + _Truth( + has_pos_only=False, + has_kwd_only=False, + has_args=False, + has_kwargs=False, + has_pos_or_kwd=False, + ): PosOrKwdBinding, + _Truth( + has_pos_only=False, + has_kwd_only=False, + has_args=False, + has_kwargs=False, + has_pos_or_kwd=True, + ): PosOrKwdBinding, + _Truth( + has_pos_only=False, + has_kwd_only=False, + has_args=False, + has_kwargs=True, + has_pos_or_kwd=False, + ): KwargsBinding, + _Truth( + has_pos_only=False, + has_kwd_only=False, + has_args=False, + has_kwargs=True, + has_pos_or_kwd=True, + ): PosKwdKwargsBinding, + _Truth( + has_pos_only=False, + has_kwd_only=False, + has_args=True, + has_kwargs=False, + has_pos_or_kwd=False, + ): ArgsBinding, + _Truth( + has_pos_only=False, + has_kwd_only=False, + has_args=True, + has_kwargs=False, + has_pos_or_kwd=True, + ): PosArgsBinding, + _Truth( + has_pos_only=False, + has_kwd_only=False, + has_args=True, + has_kwargs=True, + has_pos_or_kwd=False, + ): ArgsKwargsBinding, + _Truth( + has_pos_only=False, + has_kwd_only=False, + has_args=True, + has_kwargs=True, + has_pos_or_kwd=True, + ): ArgsKwargsBinding, + _Truth( + has_pos_only=False, + has_kwd_only=True, + has_args=False, + has_kwargs=False, + has_pos_or_kwd=False, + ): KwdBinding, + _Truth( + has_pos_only=False, + has_kwd_only=True, + has_args=False, + has_kwargs=False, + has_pos_or_kwd=True, + ): PosOrKwdBinding, + _Truth( + has_pos_only=False, + has_kwd_only=True, + has_args=False, + has_kwargs=True, + has_pos_or_kwd=False, + ): KwdKwargsBinding, + _Truth( + has_pos_only=False, + has_kwd_only=True, + has_args=False, + has_kwargs=True, + has_pos_or_kwd=True, + ): PosKwdKwargsBinding, + _Truth( + has_pos_only=False, + has_kwd_only=True, + has_args=True, + has_kwargs=False, + has_pos_or_kwd=False, + ): KwdArgsBinding, + _Truth( + has_pos_only=False, + has_kwd_only=True, + has_args=True, + has_kwargs=False, + has_pos_or_kwd=True, + ): PosOrKwdBinding, + _Truth( + has_pos_only=False, + has_kwd_only=True, + has_args=True, + has_kwargs=True, + has_pos_or_kwd=False, + ): KwdArgsKwargsBinding, + _Truth( + has_pos_only=False, + has_kwd_only=True, + has_args=True, + has_kwargs=True, + has_pos_or_kwd=True, + ): KwdArgsKwargsBinding, + _Truth( + has_pos_only=True, + has_kwd_only=False, + has_args=False, + has_kwargs=False, + has_pos_or_kwd=False, + ): PosBinding, + _Truth( + has_pos_only=True, + has_kwd_only=False, + has_args=False, + has_kwargs=False, + has_pos_or_kwd=True, + ): PosOrKwdBinding, + _Truth( + has_pos_only=True, + has_kwd_only=False, + has_args=False, + has_kwargs=True, + has_pos_or_kwd=False, + ): PosKwargsBinding, + _Truth( + has_pos_only=True, + has_kwd_only=False, + has_args=False, + has_kwargs=True, + has_pos_or_kwd=True, + ): PosKwargsBinding, + _Truth( + has_pos_only=True, + has_kwd_only=False, + has_args=True, + has_kwargs=False, + has_pos_or_kwd=False, + ): PosArgsBinding, + _Truth( + has_pos_only=True, + has_kwd_only=False, + has_args=True, + has_kwargs=False, + has_pos_or_kwd=True, + ): PosArgsBinding, + _Truth( + has_pos_only=True, + has_kwd_only=False, + has_args=True, + has_kwargs=True, + has_pos_or_kwd=False, + ): PosArgsKwargsBinding, + _Truth( + has_pos_only=True, + has_kwd_only=False, + has_args=True, + has_kwargs=True, + has_pos_or_kwd=True, + ): AnyParamKindBinding, + _Truth( + has_pos_only=True, + has_kwd_only=True, + has_args=False, + has_kwargs=False, + has_pos_or_kwd=False, + ): PosKwdBinding, + _Truth( + has_pos_only=True, + has_kwd_only=True, + has_args=False, + has_kwargs=False, + has_pos_or_kwd=True, + ): PosKwdKwargsBinding, + _Truth( + has_pos_only=True, + has_kwd_only=True, + has_args=False, + has_kwargs=True, + has_pos_or_kwd=False, + ): PosArgsKwargsBinding, + _Truth( + has_pos_only=True, + has_kwd_only=True, + has_args=False, + has_kwargs=True, + has_pos_or_kwd=True, + ): PosKwdKwargsBinding, + _Truth( + has_pos_only=True, + has_kwd_only=True, + has_args=True, + has_kwargs=False, + has_pos_or_kwd=False, + ): PosKwdArgsBinding, + _Truth( + has_pos_only=True, + has_kwd_only=True, + has_args=True, + has_kwargs=False, + has_pos_or_kwd=True, + ): PosKwdArgsBinding, + _Truth( + has_pos_only=True, + has_kwd_only=True, + has_args=True, + has_kwargs=True, + has_pos_or_kwd=False, + ): AnyParamKindBinding, + _Truth( + has_pos_only=True, + has_kwd_only=True, + has_args=True, + has_kwargs=True, + has_pos_or_kwd=True, + ): AnyParamKindBinding, +} diff --git a/src/typelib/compat.py b/src/typelib/compat.py index 3445c5f..65a6bde 100644 --- a/src/typelib/compat.py +++ b/src/typelib/compat.py @@ -20,6 +20,7 @@ "DATACLASS_NATIVE_SLOTS", "KW_ONLY", "lru_cache", + "cache", ) if TYPE_CHECKING: diff --git a/tests/unit/test_binding.py b/tests/unit/test_binding.py new file mode 100644 index 0000000..5b7aabe --- /dev/null +++ b/tests/unit/test_binding.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import collections +import inspect +from unittest import mock + +import pytest +from typelib import binding + + +@pytest.fixture( + params=[ + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.VAR_KEYWORD, + ] +) +def one(request) -> inspect.Parameter: + return inspect.Parameter( + "one", + request.param, + annotation=int, + ) + + +@pytest.fixture( + params=[ + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.KEYWORD_ONLY, + ] +) +def two(request) -> inspect.Parameter: + return inspect.Parameter( + "two", + request.param, + annotation=int, + ) + + +@pytest.fixture( + params=[ + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ] +) +def three(request) -> inspect.Parameter: + return inspect.Parameter( + "three", + request.param, + annotation=int, + ) + + +@pytest.fixture( + params=[ + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ] +) +def four(request) -> inspect.Parameter: + return inspect.Parameter( + "four", + request.param, + annotation=int, + ) + + +@pytest.fixture( + params=[ + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ] +) +def five(request) -> inspect.Parameter: + return inspect.Parameter( + "five", + request.param, + annotation=int, + ) + + +@pytest.fixture() +def given_signature(one, two, three, four, five): + params = sorted([one, two, three, four, five], key=lambda p: p.kind) + counts = collections.Counter(p.kind for p in params) + if ( + counts[inspect.Parameter.VAR_KEYWORD] > 1 + or counts[inspect.Parameter.VAR_POSITIONAL] > 1 + ): + pytest.skip("Impossible param combination.") + + return inspect.Signature(params) + + +@pytest.fixture() +def given_input(one, two, three, four, five): + inp = [] + kw_inp = {} + params = sorted([one, two, three, four, five], key=lambda p: p.kind) + for p in params: + if p.kind == inspect.Parameter.POSITIONAL_ONLY: + inp.append("1") + elif p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: + inp.append("2") + elif p.kind == inspect.Parameter.VAR_POSITIONAL: + inp.append("3") + elif p.kind == inspect.Parameter.KEYWORD_ONLY: + kw_inp.update({p.name: "4"}) + elif p.kind == inspect.Parameter.VAR_KEYWORD: + kw_inp.update({p.name: "5"}) + return tuple(inp), kw_inp + + +@pytest.fixture() +def expected_output(one, two, three, four, five): + params = sorted([one, two, three, four, five], key=lambda p: p.kind) + out = [] + kw_out = {} + for p in params: + if p.kind == inspect.Parameter.POSITIONAL_ONLY: + out.append(1) + elif p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: + out.append(2) + elif p.kind == inspect.Parameter.VAR_POSITIONAL: + out.append(3) + elif p.kind == inspect.Parameter.KEYWORD_ONLY: + kw_out.update({p.name: 4}) + elif p.kind == inspect.Parameter.VAR_KEYWORD: + kw_out.update({p.name: 5}) + + return tuple(out), kw_out + + +def test_bind(given_signature, given_input, expected_output): + # Given + given_callable = mock.Mock( + __signature__=given_signature, side_effect=lambda *a, **kw: (a, kw) + ) + given_binding = binding.bind(given_callable) + given_args, given_kwargs = given_input + # When + output = given_binding(*given_args, **given_kwargs) + # Then + assert output == expected_output + + +def test_wrap(given_signature, given_input, expected_output): + # Given + given_callable = mock.Mock( + __signature__=given_signature, side_effect=lambda *a, **kw: (a, kw) + ) + given_binding = binding.wrap(given_callable) + given_args, given_kwargs = given_input + # When + output = given_binding(*given_args, **given_kwargs) + # Then + assert output == expected_output + + +def test_wrap_class(): + # Given + @binding.wrap + class GivenClass: + def __init__(self, attr: int): + self.attr = attr + + given_attr = "1" + expected_attr = 1 + + # When + instance = GivenClass(given_attr) + # Then + assert instance.attr == expected_attr