diff --git a/poetry.lock b/poetry.lock index f9e4bbb..311dc7a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -114,6 +114,17 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "more-itertools" +version = "10.2.0" +description = "More routines for operating on iterables, beyond itertools" +optional = false +python-versions = ">=3.8" +files = [ + {file = "more-itertools-10.2.0.tar.gz", hash = "sha256:8fccb480c43d3e99a00087634c06dd02b0d50fbf088b380de5a41a015ec239e1"}, + {file = "more_itertools-10.2.0-py3-none-any.whl", hash = "sha256:686b06abe565edfab151cb8fd385a05651e1fdf8f0a14191e4439283421f8684"}, +] + [[package]] name = "mypy" version = "1.10.0" @@ -554,4 +565,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "6f9ab291daa4648c5a1e9a34d7c242c661238382e1b41040fe1aa2660b439bfe" +content-hash = "f5fd146496775e0fed5b8cb675ab63aca3054fe29e83c2df2af5d5a9b1b80900" diff --git a/pyproject.toml b/pyproject.toml index 1f770a8..4bd96ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ inflection = "^0.5" pendulum = "^3" orjson = {version = "^3", optional = true} typing-extensions = {version = "^4.10"} +more-itertools = "^10.2.0" [tool.poetry.group.test.dependencies] pytest = "^8" @@ -112,8 +113,6 @@ extend-select = [ # Future annotation "FA" ] -[tool.ruff.lint.mccabe] -max-complexity = 15 [tool.ruff.lint.per-file-ignores] # Ignore `E402` (import violations) in all `__init__.py` files diff --git a/src/typelib/future.py b/src/typelib/future.py index 2b0df1b..a35c1bf 100644 --- a/src/typelib/future.py +++ b/src/typelib/future.py @@ -5,7 +5,6 @@ import functools import sys import typing -from ast import unparse __all__ = ("transform_annotation",) @@ -26,7 +25,7 @@ def transform(annotation: str, *, union: str = "Union") -> str: """ parsed = ast.parse(annotation, mode="eval") transformed = TransformUnion().generic_visit(parsed) - unparsed = unparse(transformed).strip() + unparsed = ast.unparse(transformed).strip() return unparsed diff --git a/src/typelib/inspection.py b/src/typelib/inspection.py index a545ae5..e0cab8b 100644 --- a/src/typelib/inspection.py +++ b/src/typelib/inspection.py @@ -9,11 +9,13 @@ import datetime import decimal import enum +import fractions import functools import inspect import ipaddress import numbers import pathlib +import re import sqlite3 import types import typing @@ -39,7 +41,7 @@ overload, ) -from typelib import compat, contrib, refs +from typelib import compat, constants, contrib, refs __all__ = ( "BUILTIN_TYPES", @@ -47,6 +49,7 @@ "isbuiltininstance", "isbuiltintype", "isbuiltinsubtype", + "isbytestype", "isclassvartype", "iscollectiontype", "isdatetype", @@ -70,6 +73,7 @@ "isstdlibinstance", "isstdlibtype", "isstdlibsubtype", + "isstringtype", "istexttype", "istimetype", "istimedeltatype", @@ -77,6 +81,7 @@ "istypeddict", "istypedtuple", "isuniontype", + "isunresolvable", "isuuidtype", "should_unwrap", ) @@ -101,7 +106,7 @@ def origin(annotation: Any) -> Any: >>> class Foo: ... ... >>> origin(Foo) - + """ # Resolve custom NewTypes. actual = resolve_supertype(annotation) @@ -164,14 +169,14 @@ def get_args(annotation: Any) -> Tuple[Any, ...]: Examples: >>> from typelib import inspection - >>> from typing import Dict, TypeVar + >>> from typing import Dict, TypeVar, Any >>> T = TypeVar("T") >>> get_args(Dict) () >>> get_args(Dict[str, int]) (, ) >>> get_args(Dict[str, T]) - (,) + (, typing.Any) """ args = typing.get_args(annotation) if not args: @@ -681,6 +686,22 @@ def isdecimaltype(obj: type) -> compat.TypeIs[type[decimal.Decimal]]: return builtins.issubclass(origin(obj), decimal.Decimal) +@compat.cache +def isfractiontype(obj: type) -> compat.TypeIs[type[fractions.Fraction]]: + """Test whether this annotation is a Decimal object. + + Examples: + + >>> import fractions + >>> from typing import NewType + >>> isdecimaltype(fractions.Fraction) + True + >>> isdecimaltype(NewType("Foo", fractions.Fraction)) + True + """ + return builtins.issubclass(origin(obj), fractions.Fraction) + + @compat.cache def isuuidtype(obj: type) -> compat.TypeIs[type[uuid.UUID]]: """Test whether this annotation is a a date/datetime object. @@ -1167,7 +1188,35 @@ def istexttype(t: type[Any]) -> compat.TypeIs[type[str | bytes | bytearray]]: >>> istexttype(MyStr) True """ - return issubclass(t, (str, bytes, bytearray)) + return issubclass(t, (str, bytes, bytearray, memoryview)) + + +@compat.cache +def isstringtype(t: type[Any]) -> compat.TypeIs[type[str | bytes | bytearray]]: + """Test whether the given type is a subclass of text or bytes. + + Examples: + + >>> class MyStr(str): ... + ... + >>> istexttype(MyStr) + True + """ + return issubclass(t, str) + + +@compat.cache +def isbytestype(t: type[Any]) -> compat.TypeIs[type[str | bytes | bytearray]]: + """Test whether the given type is a subclass of text or bytes. + + Examples: + + >>> class MyStr(str): ... + ... + >>> istexttype(MyStr) + True + """ + return issubclass(t, (bytes, bytearray)) @compat.cache @@ -1274,9 +1323,63 @@ def issubscriptedgeneric(t: Any) -> bool: @compat.cache # type: ignore[arg-type] def iscallable(t: Any) -> compat.TypeIs[Callable]: + """Test whether the given type is a callable. + + Examples: + >>> import typing + >>> import collections.abc + >>> iscallable(lambda: None) + True + >>> iscallable(typing.Callable) + True + >>> iscallable(collections.abc.Callable) + True + >>> iscallable(1) + False + """ return inspect.isroutine(t) or t is Callable or _safe_issubclass(t, abc_Callable) # type: ignore[arg-type] +@compat.cache +def isunresolvable(t: Any) -> bool: + """Test whether the given type is unresolvable. + + Examples: + >>> import typing + >>> isunresolvable(int) + False + >>> isunresolvable(typing.Any) + True + >>> isunresolvable(...) + True + """ + return t in _UNRESOLVABLE + + +_UNRESOLVABLE = ( + Any, + re.Match, + type(None), + None, + constants.empty, + Callable, + abc_Callable, + inspect.Parameter.empty, + type(Ellipsis), + Ellipsis, +) + + +@compat.cache +def ispatterntype(t: Any) -> compat.TypeIs[re.Pattern]: + return issubclass(t, re.Pattern) + + +@compat.cache +def ispathtype(t: Any) -> compat.TypeIs[pathlib.Path]: + return issubclass(t, pathlib.PurePath) + + def _safe_issubclass(__cls: type, __class_or_tuple: type | tuple[type, ...]) -> bool: try: return issubclass(__cls, __class_or_tuple) diff --git a/src/typelib/interchange.py b/src/typelib/interchange.py index 034848c..d20e9dc 100644 --- a/src/typelib/interchange.py +++ b/src/typelib/interchange.py @@ -1,12 +1,17 @@ from __future__ import annotations +import ast +import contextlib import datetime +import functools +import operator import time import typing as t import pendulum +from more_itertools import peekable -from typelib import compat +from typelib import compat, inspection @t.overload @@ -99,43 +104,144 @@ def unixtime(t: datetime.date | datetime.time) -> float: @compat.lru_cache(maxsize=100_000) -def dateparse(val: str, t: type[DateTimeT]) -> DateTimeT: +def dateparse(val: str, td: type[DateTimeT]) -> DateTimeT: """Parse a date string into a datetime object. Args: val: The date string to parse. - t: The target datetime type. + td: The target datetime type. Returns: The parsed datetime object. """ try: + # When `exact=False`, the only two possibilities are DateTime and Duration. parsed: pendulum.DateTime | pendulum.Duration = pendulum.parse(val) # type: ignore[assignment] - if isinstance(parsed, pendulum.DateTime): - if issubclass(t, datetime.time): - return parsed.time().replace(tzinfo=parsed.tzinfo) - if issubclass(t, datetime.datetime): - return parsed - if issubclass(t, datetime.date): - return parsed.date() - if not isinstance(parsed, t): - raise ValueError(f"Cannot parse {val!r} as {t.__qualname__!r}") - return parsed + normalized = _nomalize_dt(val=val, parsed=parsed, td=td) + return normalized except ValueError: if val.isdigit() or val.isdecimal(): - numval = float(val) - # Assume the number value is seconds - same logic as time-since-epoch - if issubclass(t, datetime.timedelta): - return datetime.timedelta(seconds=numval) - # Parse a datetime from the time-since-epoch as indicated by the value. - dt = datetime.datetime.fromtimestamp(numval, tz=datetime.timezone.utc) - # Return the datetime if the target type is a datetime - if issubclass(t, datetime.datetime): - return dt - # If the target type is a time object, just return the time. - if issubclass(t, datetime.time): - return dt.time().replace(tzinfo=dt.tzinfo) - # If the target type is a date object, just return the date. - return dt.date() - + return _normalize_number(numval=float(val), td=td) raise + + +def _nomalize_dt( + *, val: str, parsed: pendulum.DateTime | pendulum.Duration, td: type[DateTimeT] +) -> DateTimeT: + if isinstance(parsed, pendulum.DateTime): + if issubclass(td, datetime.time): + return parsed.time().replace(tzinfo=parsed.tzinfo) + if issubclass(td, datetime.datetime): + return parsed + if issubclass(td, datetime.date): + return parsed.date() + if not isinstance(parsed, td): + raise ValueError(f"Cannot parse {val!r} as {td.__qualname__!r}") + return parsed + + +def _normalize_number(*, numval: float, td: type[DateTimeT]) -> DateTimeT: + # Assume the number value is seconds - same logic as time-since-epoch + if issubclass(td, datetime.timedelta): + return datetime.timedelta(seconds=numval) + # Parse a datetime from the time-since-epoch as indicated by the value. + dt = datetime.datetime.fromtimestamp(numval, tz=datetime.timezone.utc) + # Return the datetime if the target type is a datetime + if issubclass(td, datetime.datetime): + return dt + # If the target type is a time object, just return the time. + if issubclass(td, datetime.time): + return dt.time().replace(tzinfo=dt.tzinfo) + # If the target type is a date object, just return the date. + return dt.date() + + +def iteritems(val: t.Any) -> t.Iterable[tuple[t.Any, t.Any]]: + if _is_iterable_of_pairs(val): + return iter(val) + + iterate = get_items_iter(val.__class__) + return iterate(val) + + +def _is_iterable_of_pairs(val: t.Any) -> bool: + if not inspection.isiterabletype(val.__class__): + return False + peek = peekable(val).peek() + return inspection.iscollectiontype(peek.__class__) and len(peek) == 2 + + +def itervalues(val: t.Any) -> t.Iterator[t.Any]: + iterate = get_items_iter(val.__class__) + return (v for k, v in iterate(val)) + + +@functools.cache +def get_items_iter(tp: type) -> t.Callable[[t.Any], t.Iterable[tuple[t.Any, t.Any]]]: + ismapping, isnamedtuple, isiterable, isstructured = ( + inspection.ismappingtype(tp), + inspection.isnamedtuple(tp), + inspection.isiterabletype(tp), + inspection.isstructuredtype(tp), + ) + if ismapping: + return _itemscaller + if isnamedtuple: + return _namedtupleitems + if isiterable: + return enumerate + if isstructured: + return _make_fields_iterator(tp) + raise TypeError(f"Cannot iterate items of type {tp.__qualname__!r}") + + +def _namedtupleitems(val: t.NamedTuple) -> t.Iterable[tuple[str, t.Any]]: + return val._asdict().items() + + +def _make_fields_iterator( + tp: type, +) -> t.Callable[[t.Any], t.Iterator[tuple[t.Any, t.Any]]]: + attribs = inspection.get_type_hints(tp) + public_attribs = [k for k in attribs if not k.startswith("_")] + if not public_attribs and hasattr(tp, "__slots__"): + public_attribs = [s for s in tp.__slots__ if not s.startswith("_")] + + if public_attribs: + + def _iterfields(val: t.Any) -> t.Iterator[tuple[str, t.Any]]: + return (getattr(val, a) for a in public_attribs) + + return _iterfields + + def _itervars(val: t.Any) -> t.Iterator[tuple[str, t.Any]]: + return ((k, v) for k, v in val.__dict__.items() if not k.startswith("_")) + + return _itervars + + +@compat.lru_cache(maxsize=100_000) +def strload(val: str | bytes | bytearray | memoryview) -> PythonValueT: + """Attempt to load""" + with contextlib.suppress(ValueError): + return compat.json.loads(val) + + decoded = decode(val) + with contextlib.suppress(ValueError, TypeError, SyntaxError): + return ast.literal_eval(decoded) + + return decoded + + +PythonPrimitiveT: t.TypeAlias = bool | int | float | str | None +PythonValueT: t.TypeAlias = ( + "PythonPrimitiveT | " + "dict[PythonPrimitiveT, PythonValueT] | " + "list[PythonValueT] | " + "tuple[PythonValueT, ...] | " + "set[PythonValueT]" +) + + +_itemscaller = operator.methodcaller("items") +_valuescaller = operator.methodcaller("values") diff --git a/src/typelib/unmarshal/factory.py b/src/typelib/unmarshal/factory.py index 883d421..9f13b78 100644 --- a/src/typelib/unmarshal/factory.py +++ b/src/typelib/unmarshal/factory.py @@ -1,24 +1,27 @@ from __future__ import annotations -import collections.abc as cabc -import inspect -import re -import typing as t +import typing as tp -from typelib import compat, constants, graph +from typelib import compat, graph, inspection, refs from typelib.unmarshal import routines -T = t.TypeVar("T") +T = tp.TypeVar("T") @compat.cache def unmarshaller(typ: type[T]) -> routines.AbstractUnmarshaller[T]: nodes = graph.static_order(typ) context: dict[type, routines.AbstractUnmarshaller] = {} + if not nodes: + return NoOpUnmarshaller(t=typ, context=context) + + root = typ for node in nodes: context[node.type] = _get_unmarshaller(node.type, context=context) + # root will be the last seen node + root = node.type - return context[typ] + return context[root] def _get_unmarshaller( # type: ignore[return] @@ -31,56 +34,80 @@ def _get_unmarshaller( # type: ignore[return] if check(typ): return unmarshaller_cls(typ, context=context) - # TODO: fields unmarshaller + return routines.StructuredTypeUnmarshaller(typ, context=context) -_T = t.TypeVar("_T") +class DelayedUnmarshaller(routines.AbstractUnmarshaller[T]): + def __init__(self, t: type[T], context: routines.ContextT): + super().__init__(t, context) + self._resolved: routines.AbstractUnmarshaller[T] | None = None + @property + def resolved(self) -> routines.AbstractUnmarshaller[T]: + if self._resolved is None: + self._resolved = self._resolve_unmarshaller() + self._resolved.__class__.__init__( + self, # type: ignore[arg-type] + self._resolved.t, # type: ignore[arg-type] + self._resolved.context, # type: ignore[arg-type] + ) + return self._resolved + + def _resolve_unmarshaller(self) -> routines.AbstractUnmarshaller[T]: + typ = refs.evaluate(self.t) # type: ignore[arg-type] + um = unmarshaller(typ) + return um + + def __call__(self, val: tp.Any) -> T: + return self.resolved(val) -_UNRESOLVABLE = frozenset( - ( - t.Any, - re.Match, - type(None), - constants.empty, - t.Callable, - cabc.Callable, - inspect.Parameter.empty, - ) -) + +class NoOpUnmarshaller(routines.AbstractUnmarshaller[T]): + def __call__(self, val: tp.Any) -> T: + return tp.cast(T, val) # Order is IMPORTANT! This is a FIFO queue. -_HANDLERS: t.Mapping[ - t.Callable[[type[T]], bool], type[routines.AbstractUnmarshaller] +_HANDLERS: tp.Mapping[ + tp.Callable[[type[T]], bool], type[routines.AbstractUnmarshaller] ] = { - # # Short-circuit forward refs - # inspection.isforwardref: ..., - # # Special handler for Literals + # Short-circuit forward refs + inspection.isforwardref: DelayedUnmarshaller, + inspection.isunresolvable: NoOpUnmarshaller, + # Special handler for Literals # inspection.isliteral: ..., - # # Special handler for Unions... + # Special handler for Unions... # inspection.isuniontype: ..., - # # Non-intersecting types (order doesn't matter here. - # inspection.isdatetimetype: ..., - # inspection.isdatetype: ..., - # inspection.istimetype: ..., - # inspection.istimedeltatype: ..., - # inspection.isuuidtype: ..., - # inspection.ispatterntype: ..., - # inspection.ispathtype: ..., - # inspection.isdecimaltype: ..., - # inspection.istexttype: ..., - # # MUST come before subtype check. - # inspection.isbuiltintype: ..., - # # Psuedo-structured containers, should check before generics. - # inspection.istypeddict: ..., - # inspection.istypedtuple: ..., - # inspection.isnamedtuple: ..., - # inspection.isfixedtupletype: ..., - # # A mapping is a collection so must come before that check. - # inspection.ismappingtype: ..., - # # A tuple is a collection so must come before that check. - # inspection.istupletype: ..., - # # Generic collection handler - # inspection.iscollectiontype: ..., + # Non-intersecting types (order doesn't matter here. + inspection.isdatetimetype: routines.DateTimeUnmarshaller, + inspection.isdatetype: routines.DateUnmarshaller, + inspection.istimetype: routines.TimeUnmarshaller, + inspection.istimedeltatype: routines.TimeDeltaUnmarshaller, + inspection.isuuidtype: routines.UUIDUnmarshaller, + inspection.ispatterntype: routines.PatternUnmarshaller, + inspection.ispathtype: routines.PathUnmarshaller, + inspection.isdecimaltype: routines.DecimalUnmarshaller, + inspection.isfractiontype: routines.FractionUnmarshaller, + inspection.isstringtype: routines.StrUnmarshaller, + inspection.isbytestype: routines.BytesUnmarshaller, + # Psuedo-structured containers, should check before generics. + inspection.istypeddict: routines.StructuredTypeUnmarshaller, + inspection.istypedtuple: routines.StructuredTypeUnmarshaller, + inspection.isnamedtuple: routines.StructuredTypeUnmarshaller, + inspection.isfixedtupletype: routines.FixedTupleUnmarshaller, + ( + lambda t: inspection.issubscriptedgeneric(t) and inspection.ismappingtype(t) + ): routines.SubscriptedMappingUnmarshaller, + ( + lambda t: inspection.issubscriptedgeneric(t) and inspection.isiteratortype(t) + ): routines.SubscriptedIteratorUnmarshaller, + ( + lambda t: inspection.issubscriptedgeneric(t) and inspection.isiterabletype(t) + ): routines.SubscriptedIterableUnmarshaller, + # A mapping is a collection so must come before that check. + inspection.ismappingtype: routines.MappingUnmarshaller, + # Generic iterator handler + inspection.isiteratortype: NoOpUnmarshaller[tp.Iterator], + # Generic Iterable handler + inspection.isiterabletype: routines.IterableUnmarshaller, } diff --git a/src/typelib/unmarshal/routines.py b/src/typelib/unmarshal/routines.py index 795786b..9079817 100644 --- a/src/typelib/unmarshal/routines.py +++ b/src/typelib/unmarshal/routines.py @@ -1,31 +1,65 @@ from __future__ import annotations import abc +import contextlib import datetime import decimal import fractions import numbers +import pathlib +import re import typing as tp +import uuid from typelib import inspection, interchange T = tp.TypeVar("T") +__all__ = ( + "AbstractUnmarshaller", + "ContextT", + "BytesUnmarshaller", + "StrUnmarshaller", + "NumberUnmarshaller", + "DecimalUnmarshaller", + "FractionUnmarshaller", + "DateUnmarshaller", + "DateTimeUnmarshaller", + "TimeUnmarshaller", + "TimeDeltaUnmarshaller", + "UUIDUnmarshaller", + "PathUnmarshaller", + "CastUnmarshaller", + "PatternUnmarshaller", + "MappingUnmarshaller", + "IterableUnmarshaller", + "LiteralUnmarshaller", + "UnionUnmarshaller", + "SubscriptedIteratorUnmarshaller", + "SubscriptedIterableUnmarshaller", + "SubscriptedMappingUnmarshaller", + "FixedTupleUnmarshaller", + "StructuredTypeUnmarshaller", +) + class AbstractUnmarshaller(abc.ABC, tp.Generic[T]): - context: tp.Mapping[type, AbstractUnmarshaller] + context: ContextT t: type[T] - __slots__ = ("t", "context") + __slots__ = ("t", "origin", "context", "var") - def __init__(self, t: type[T], context: tp.Mapping[type, AbstractUnmarshaller]): + def __init__(self, t: type[T], context: ContextT, *, var: str | None = None): self.t = t + self.origin = inspection.origin(self.t) self.context = context + self.var = var @abc.abstractmethod def __call__(self, val: tp.Any) -> T: ... +ContextT: tp.TypeAlias = tp.Mapping[type, AbstractUnmarshaller] BytesT = tp.TypeVar("BytesT", bound=bytes) @@ -208,3 +242,180 @@ def __call__(self, val: tp.Any) -> datetime.timedelta: return td return self.t(seconds=td.total_seconds()) + + +class UUIDUnmarshaller(AbstractUnmarshaller[uuid.UUID]): + def __call__(self, val: tp.Any) -> uuid.UUID: + decoded = ( + interchange.strload(val) if inspection.istexttype(val.__class__) else val + ) + if isinstance(decoded, int): + return self.t(int=decoded) + return self.t(decoded) + + +class PatternUnmarshaller(AbstractUnmarshaller[re.Pattern]): + def __call__(self, val: tp.Any) -> re.Pattern: + decoded = interchange.decode(val) + return re.compile(decoded) + + +class CastUnmarshaller(AbstractUnmarshaller[T]): + __slots__ = ("caster",) + + def __init__(self, t: type[T], context: ContextT): + super().__init__(t, context) + self.caster: tp.Callable[[tp.Any], T] = self.origin # type: ignore[assignment] + + def __call__(self, val: tp.Any) -> T: + # Try to load the string, if this is JSON or a literal expression. + decoded = ( + interchange.strload(val) if inspection.istexttype(val.__class__) else val + ) + # Short-circuit cast if we have the type we want. + if decoded.__class__ is self.origin: + return decoded + # Cast the decoded value to the type. + return self.caster(decoded) + + +PathUnmarshaller = CastUnmarshaller[pathlib.Path] +MappingUnmarshaller = CastUnmarshaller[tp.Mapping] +IterableUnmarshaller = CastUnmarshaller[tp.Iterable] + +_LTVT = tp.TypeVarTuple("_LTVT") +_LT = tp.TypeVar("_LT") + + +class LiteralUnmarshaller(AbstractUnmarshaller, tp.Generic[*_LTVT]): + __slots__ = ("values",) + + def __init__(self, t: type[T], context: ContextT): + super().__init__(t, context) + self.values = inspection.get_args(t) + + def __call__(self, val: tp.Any) -> tp.Literal[*_LTVT]: # type: ignore[valid-type] + if val in self.values: + return val + decoded = ( + interchange.strload(val) if inspection.istexttype(val.__class__) else val + ) + if decoded in self.values: + return decoded + + raise ValueError(f"{decoded!r} is not one of {self.values!r}") + + +class UnionUnmarshaller(AbstractUnmarshaller, tp.Generic[*_LTVT]): + __slots__ = ("stack", "ordered_routines") + + def __init__(self, t: type[T], context: ContextT): + super().__init__(t, context) + self.stack = inspection.get_args(t) + self.ordered_routines = [self.context[typ] for typ in self.stack] + + def __call__(self, val: tp.Any) -> tp.Union[*_LTVT]: # type: ignore[valid-type] + for routine in self.ordered_routines: + with contextlib.suppress(ValueError, TypeError, SyntaxError): + unmarshalled = routine(val) + return unmarshalled + + raise ValueError(f"{val!r} is not one of types {self.stack!r}") + + +_KT = tp.TypeVar("_KT") +_VT = tp.TypeVar("_VT") + + +class SubscriptedMappingUnmarshaller(AbstractUnmarshaller[tp.Mapping[_KT, _VT]]): + __slots__ = ( + "keys", + "values", + ) + + def __init__(self, t: type[tp.Mapping[_KT, _VT]], context: ContextT) -> None: + super().__init__(t=t, context=context) + key_t, value_t = inspection.get_args(t) + self.keys = context[key_t] + self.values = context[value_t] + + def __call__(self, val: tp.Any) -> tp.Mapping[_KT, _VT]: + # Always decode bytes. + decoded = interchange.strload(val) + keys = self.keys + values = self.values + return self.origin( + ((keys(k), values(v)) for k, v in interchange.iteritems(decoded)) + ) + + +class SubscriptedIterableUnmarshaller(AbstractUnmarshaller[tp.Iterable[_VT]]): + __slots__ = ("values",) + + def __init__(self, t: type[tp.Iterable[_VT]], context: ContextT) -> None: + super().__init__(t=t, context=context) + (value_t,) = inspection.get_args(t) + self.values = context[value_t] + + def __call__(self, val: tp.Any) -> tp.Iterable[_VT]: + # Always decode bytes. + decoded = interchange.strload(val) + values = self.values + return self.origin((values(v) for v in interchange.itervalues(decoded))) + + +class SubscriptedIteratorUnmarshaller(AbstractUnmarshaller[tp.Iterator[_VT]]): + __slots__ = ("values",) + + def __init__(self, t: type[tp.Iterator[_VT]], context: ContextT) -> None: + super().__init__(t=t, context=context) + (value_t,) = inspection.get_args(t) + self.values = context[value_t] + + def __call__(self, val: tp.Any) -> tp.Iterator[_VT]: + # Always decode bytes. + decoded = interchange.strload(val) + values = self.values + return (values(v) for v in interchange.itervalues(decoded)) + + +_TVT = tp.TypeVarTuple("_TVT") + + +class FixedTupleUnmarshaller(AbstractUnmarshaller[tuple[*_TVT]]): + __slots__ = ("ordered_routines", "stack") + + def __init__(self, t: type[tuple[*_TVT]], context: ContextT) -> None: + super().__init__(t, context) + self.stack = inspection.get_args(t) + self.ordered_routines = [self.context[vt] for vt in self.stack] + + def __call__(self, val: tp.Any) -> tuple[*_TVT]: + decoded = interchange.strload(val) + return self.origin( + routine(v) + for routine, v in zip( + self.ordered_routines, interchange.itervalues(decoded) + ) + ) + + +_ST = tp.TypeVar("_ST") + + +class StructuredTypeUnmarshaller(AbstractUnmarshaller[_ST]): + __slots__ = ("fields_by_var",) + + def __init__(self, t: type[_ST], context: ContextT) -> None: + super().__init__(t, context) + self.fields_by_var = {m.var: m for m in self.context.values()} + + def __call__(self, val: tp.Any) -> _ST: + decoded = ( + interchange.strload(val) if inspection.istexttype(val.__class__) else val + ) + fields = self.fields_by_var + kwargs = { + f: fields[f](v) for f, v in interchange.iteritems(decoded) if f in fields + } + return self.t(**kwargs)