Skip to content

Commit

Permalink
feat: Better generics interface.
Browse files Browse the repository at this point in the history
  • Loading branch information
seandstewart committed Jun 19, 2024
1 parent 82b566c commit 0f96785
Showing 1 changed file with 80 additions and 40 deletions.
120 changes: 80 additions & 40 deletions src/typelib/unmarshal/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __call__(self, val: tp.Any) -> BytesT:
StrT = tp.TypeVar("StrT", bound=str)


class StrUnmarshaller(AbstractUnmarshaller[StrT]):
class StrUnmarshaller(AbstractUnmarshaller[StrT], tp.Generic[StrT]):
def __call__(self, val: tp.Any) -> StrT:
# Always decode bytes.
decoded = interchange.decode(val)
Expand All @@ -91,7 +91,7 @@ def __call__(self, val: tp.Any) -> StrT:
NumberT = tp.TypeVar("NumberT", bound=numbers.Number)


class NumberUnmarshaller(AbstractUnmarshaller[NumberT]):
class NumberUnmarshaller(AbstractUnmarshaller[NumberT], tp.Generic[NumberT]):
def __call__(self, val: tp.Any) -> NumberT:
# Always decode bytes.
decoded = interchange.decode(val)
Expand Down Expand Up @@ -119,8 +119,11 @@ def __call__(self, val: tp.Any) -> NumberT:
FractionUnmarshaller = NumberUnmarshaller[FractionT]


class DateUnmarshaller(AbstractUnmarshaller[datetime.date]):
def __call__(self, val: tp.Any) -> datetime.date:
DateT = tp.TypeVar("DateT", bound=datetime.date)


class DateUnmarshaller(AbstractUnmarshaller[DateT], tp.Generic[DateT]):
def __call__(self, val: tp.Any) -> DateT:
if isinstance(val, self.t) and not isinstance(val, datetime.datetime):
return val

Expand All @@ -140,12 +143,17 @@ def __call__(self, val: tp.Any) -> datetime.date:
return self.t.today()
# Exact class matching - the parser returns subclasses.
if date.__class__ is self.t:
return date
return date # type: ignore[return-value]
# Reconstruct as the exact type.
return self.t(year=date.year, month=date.month, day=date.day)


class DateTimeUnmarshaller(AbstractUnmarshaller[datetime.datetime]):
DateTimeT = tp.TypeVar("DateTimeT", bound=datetime.datetime)


class DateTimeUnmarshaller(
AbstractUnmarshaller[datetime.datetime], tp.Generic[DateTimeT]
):
def __call__(self, val: tp.Any) -> datetime.datetime:
if isinstance(val, self.t):
return val
Expand Down Expand Up @@ -190,18 +198,22 @@ def __call__(self, val: tp.Any) -> datetime.datetime:
return self.t(year=dt.year, month=dt.month, day=dt.day)


class TimeUnmarshaller(AbstractUnmarshaller[datetime.time]):
def __call__(self, val: tp.Any) -> datetime.time:
TimeT = tp.TypeVar("TimeT", bound=datetime.time)


class TimeUnmarshaller(AbstractUnmarshaller[TimeT], tp.Generic[TimeT]):
def __call__(self, val: tp.Any) -> TimeT:
if isinstance(val, self.t):
return val
if isinstance(val, (int, float)):
return (

decoded = interchange.decode(val)
if isinstance(decoded, (int, float)):
decoded = (
datetime.datetime.fromtimestamp(val, tz=datetime.timezone.utc)
.time()
# datetime.time() strips tzinfo...
.replace(tzinfo=datetime.timezone.utc)
)
decoded = interchange.decode(val)
dt: datetime.datetime | datetime.date | datetime.time = (
interchange.dateparse(decoded, self.t)
if isinstance(decoded, str)
Expand All @@ -214,7 +226,7 @@ def __call__(self, val: tp.Any) -> datetime.time:
dt = self.t()

if dt.__class__ is self.t:
return dt
return dt # type: ignore[return-value]

return self.t(
hour=dt.hour,
Expand All @@ -226,8 +238,11 @@ def __call__(self, val: tp.Any) -> datetime.time:
)


class TimeDeltaUnmarshaller(AbstractUnmarshaller[datetime.timedelta]):
def __call__(self, val: tp.Any) -> datetime.timedelta:
TimeDeltaT = tp.TypeVar("TimeDeltaT", bound=datetime.timedelta)


class TimeDeltaUnmarshaller(AbstractUnmarshaller[TimeDeltaT], tp.Generic[TimeDeltaT]):
def __call__(self, val: tp.Any) -> TimeDeltaT:
if isinstance(val, (int, float)):
return self.t(seconds=int(val))

Expand All @@ -239,13 +254,16 @@ def __call__(self, val: tp.Any) -> datetime.timedelta:
)

if td.__class__ is self.t:
return td
return td # type: ignore[return-value]

return self.t(seconds=td.total_seconds())


class UUIDUnmarshaller(AbstractUnmarshaller[uuid.UUID]):
def __call__(self, val: tp.Any) -> uuid.UUID:
UUIDT = tp.TypeVar("UUIDT", bound=uuid.UUID)


class UUIDUnmarshaller(AbstractUnmarshaller[UUIDT], tp.Generic[UUIDT]):
def __call__(self, val: tp.Any) -> UUIDT:
decoded = (
interchange.strload(val) if inspection.istexttype(val.__class__) else val
)
Expand All @@ -254,10 +272,13 @@ def __call__(self, val: tp.Any) -> uuid.UUID:
return self.t(decoded)


class PatternUnmarshaller(AbstractUnmarshaller[re.Pattern]):
def __call__(self, val: tp.Any) -> re.Pattern:
PatternT = tp.TypeVar("PatternT", bound=re.Pattern)


class PatternUnmarshaller(AbstractUnmarshaller[PatternT], tp.Generic[PatternT]):
def __call__(self, val: tp.Any) -> PatternT:
decoded = interchange.decode(val)
return re.compile(decoded)
return re.compile(decoded) # type: ignore[return-value]


class CastUnmarshaller(AbstractUnmarshaller[T]):
Expand All @@ -283,18 +304,18 @@ def __call__(self, val: tp.Any) -> T:
MappingUnmarshaller = CastUnmarshaller[tp.Mapping]
IterableUnmarshaller = CastUnmarshaller[tp.Iterable]

_LTVT = tp.TypeVarTuple("_LTVT")
_LT = tp.TypeVar("_LT")

LiteralT = tp.TypeVar("LiteralT")


class LiteralUnmarshaller(AbstractUnmarshaller, tp.Generic[*_LTVT]):
class LiteralUnmarshaller(AbstractUnmarshaller[LiteralT], tp.Generic[LiteralT]):
__slots__ = ("values",)

def __init__(self, t: type[T], context: ContextT):
def __init__(self, t: type[LiteralT], context: ContextT):
super().__init__(t, context)
self.values = inspection.get_args(t)

def __call__(self, val: tp.Any) -> tp.Literal[*_LTVT]: # type: ignore[valid-type]
def __call__(self, val: tp.Any) -> LiteralT:
if val in self.values:
return val
decoded = (
Expand All @@ -306,15 +327,18 @@ def __call__(self, val: tp.Any) -> tp.Literal[*_LTVT]: # type: ignore[valid-typ
raise ValueError(f"{decoded!r} is not one of {self.values!r}")


class UnionUnmarshaller(AbstractUnmarshaller, tp.Generic[*_LTVT]):
UnionT = tp.TypeVar("UnionT")


class UnionUnmarshaller(AbstractUnmarshaller[UnionT], tp.Generic[UnionT]):
__slots__ = ("stack", "ordered_routines")

def __init__(self, t: type[T], context: ContextT):
def __init__(self, t: type[UnionT], context: ContextT):
super().__init__(t, context)
self.stack = inspection.get_args(t)
self.ordered_routines = [self.context[typ] for typ in self.stack]

def __call__(self, val: tp.Any) -> tp.Union[*_LTVT]: # type: ignore[valid-type]
def __call__(self, val: tp.Any) -> UnionT:
for routine in self.ordered_routines:
with contextlib.suppress(ValueError, TypeError, SyntaxError):
unmarshalled = routine(val)
Expand All @@ -327,19 +351,24 @@ def __call__(self, val: tp.Any) -> tp.Union[*_LTVT]: # type: ignore[valid-type]
_VT = tp.TypeVar("_VT")


class SubscriptedMappingUnmarshaller(AbstractUnmarshaller[tp.Mapping[_KT, _VT]]):
MappingT = tp.TypeVar("MappingT", bound=tp.Mapping)


class SubscriptedMappingUnmarshaller(
AbstractUnmarshaller[MappingT], tp.Generic[MappingT]
):
__slots__ = (
"keys",
"values",
)

def __init__(self, t: type[tp.Mapping[_KT, _VT]], context: ContextT) -> None:
def __init__(self, t: type[MappingT], context: ContextT):
super().__init__(t=t, context=context)
key_t, value_t = inspection.get_args(t)
self.keys = context[key_t]
self.values = context[value_t]

def __call__(self, val: tp.Any) -> tp.Mapping[_KT, _VT]:
def __call__(self, val: tp.Any) -> MappingT:
# Always decode bytes.
decoded = interchange.strload(val)
keys = self.keys
Expand All @@ -349,34 +378,45 @@ def __call__(self, val: tp.Any) -> tp.Mapping[_KT, _VT]:
)


class SubscriptedIterableUnmarshaller(AbstractUnmarshaller[tp.Iterable[_VT]]):
IterableT = tp.TypeVar("IterableT", bound=tp.Iterable)


class SubscriptedIterableUnmarshaller(
AbstractUnmarshaller[IterableT], tp.Generic[IterableT]
):
__slots__ = ("values",)

def __init__(self, t: type[tp.Iterable[_VT]], context: ContextT) -> None:
def __init__(self, t: type[IterableT], context: ContextT):
super().__init__(t=t, context=context)
(value_t,) = inspection.get_args(t)
self.values = context[value_t]

def __call__(self, val: tp.Any) -> tp.Iterable[_VT]:
def __call__(self, val: tp.Any) -> IterableT:
# Always decode bytes.
decoded = interchange.strload(val)
values = self.values
return self.origin((values(v) for v in interchange.itervalues(decoded)))


class SubscriptedIteratorUnmarshaller(AbstractUnmarshaller[tp.Iterator[_VT]]):
IteratorT = tp.TypeVar("IteratorT", bound=tp.Iterator)


class SubscriptedIteratorUnmarshaller(
AbstractUnmarshaller[IteratorT], tp.Generic[IteratorT]
):
__slots__ = ("values",)

def __init__(self, t: type[tp.Iterator[_VT]], context: ContextT) -> None:
def __init__(self, t: type[IteratorT], context: ContextT):
super().__init__(t=t, context=context)
(value_t,) = inspection.get_args(t)
self.values = context[value_t]

def __call__(self, val: tp.Any) -> tp.Iterator[_VT]:
def __call__(self, val: tp.Any) -> IteratorT:
# Always decode bytes.
decoded = interchange.strload(val)
values = self.values
return (values(v) for v in interchange.itervalues(decoded))
it: IteratorT = (values(v) for v in interchange.itervalues(decoded)) # type: ignore[assignment]
return it


_TVT = tp.TypeVarTuple("_TVT")
Expand All @@ -385,7 +425,7 @@ def __call__(self, val: tp.Any) -> tp.Iterator[_VT]:
class FixedTupleUnmarshaller(AbstractUnmarshaller[tuple[*_TVT]]):
__slots__ = ("ordered_routines", "stack")

def __init__(self, t: type[tuple[*_TVT]], context: ContextT) -> None:
def __init__(self, t: type[tuple[*_TVT]], context: ContextT):
super().__init__(t, context)
self.stack = inspection.get_args(t)
self.ordered_routines = [self.context[vt] for vt in self.stack]
Expand All @@ -406,7 +446,7 @@ def __call__(self, val: tp.Any) -> tuple[*_TVT]:
class StructuredTypeUnmarshaller(AbstractUnmarshaller[_ST]):
__slots__ = ("fields_by_var",)

def __init__(self, t: type[_ST], context: ContextT) -> None:
def __init__(self, t: type[_ST], context: ContextT):
super().__init__(t, context)
self.fields_by_var = {m.var: m for m in self.context.values()}

Expand Down

0 comments on commit 0f96785

Please sign in to comment.