diff --git a/src/typelib/unmarshal/routines.py b/src/typelib/unmarshal/routines.py index 9079817..3e1fae5 100644 --- a/src/typelib/unmarshal/routines.py +++ b/src/typelib/unmarshal/routines.py @@ -76,7 +76,7 @@ def __call__(self, val: tp.Any) -> BytesT: StrT = tp.TypeVar("StrT", bound=str) -class StrUnmarshaller(AbstractUnmarshaller[StrT]): +class StrUnmarshaller(AbstractUnmarshaller[StrT], tp.Generic[StrT]): def __call__(self, val: tp.Any) -> StrT: # Always decode bytes. decoded = interchange.decode(val) @@ -91,7 +91,7 @@ def __call__(self, val: tp.Any) -> StrT: NumberT = tp.TypeVar("NumberT", bound=numbers.Number) -class NumberUnmarshaller(AbstractUnmarshaller[NumberT]): +class NumberUnmarshaller(AbstractUnmarshaller[NumberT], tp.Generic[NumberT]): def __call__(self, val: tp.Any) -> NumberT: # Always decode bytes. decoded = interchange.decode(val) @@ -119,8 +119,11 @@ def __call__(self, val: tp.Any) -> NumberT: FractionUnmarshaller = NumberUnmarshaller[FractionT] -class DateUnmarshaller(AbstractUnmarshaller[datetime.date]): - def __call__(self, val: tp.Any) -> datetime.date: +DateT = tp.TypeVar("DateT", bound=datetime.date) + + +class DateUnmarshaller(AbstractUnmarshaller[DateT], tp.Generic[DateT]): + def __call__(self, val: tp.Any) -> DateT: if isinstance(val, self.t) and not isinstance(val, datetime.datetime): return val @@ -140,12 +143,17 @@ def __call__(self, val: tp.Any) -> datetime.date: return self.t.today() # Exact class matching - the parser returns subclasses. if date.__class__ is self.t: - return date + return date # type: ignore[return-value] # Reconstruct as the exact type. return self.t(year=date.year, month=date.month, day=date.day) -class DateTimeUnmarshaller(AbstractUnmarshaller[datetime.datetime]): +DateTimeT = tp.TypeVar("DateTimeT", bound=datetime.datetime) + + +class DateTimeUnmarshaller( + AbstractUnmarshaller[datetime.datetime], tp.Generic[DateTimeT] +): def __call__(self, val: tp.Any) -> datetime.datetime: if isinstance(val, self.t): return val @@ -190,18 +198,22 @@ def __call__(self, val: tp.Any) -> datetime.datetime: return self.t(year=dt.year, month=dt.month, day=dt.day) -class TimeUnmarshaller(AbstractUnmarshaller[datetime.time]): - def __call__(self, val: tp.Any) -> datetime.time: +TimeT = tp.TypeVar("TimeT", bound=datetime.time) + + +class TimeUnmarshaller(AbstractUnmarshaller[TimeT], tp.Generic[TimeT]): + def __call__(self, val: tp.Any) -> TimeT: if isinstance(val, self.t): return val - if isinstance(val, (int, float)): - return ( + + decoded = interchange.decode(val) + if isinstance(decoded, (int, float)): + decoded = ( datetime.datetime.fromtimestamp(val, tz=datetime.timezone.utc) .time() # datetime.time() strips tzinfo... .replace(tzinfo=datetime.timezone.utc) ) - decoded = interchange.decode(val) dt: datetime.datetime | datetime.date | datetime.time = ( interchange.dateparse(decoded, self.t) if isinstance(decoded, str) @@ -214,7 +226,7 @@ def __call__(self, val: tp.Any) -> datetime.time: dt = self.t() if dt.__class__ is self.t: - return dt + return dt # type: ignore[return-value] return self.t( hour=dt.hour, @@ -226,8 +238,11 @@ def __call__(self, val: tp.Any) -> datetime.time: ) -class TimeDeltaUnmarshaller(AbstractUnmarshaller[datetime.timedelta]): - def __call__(self, val: tp.Any) -> datetime.timedelta: +TimeDeltaT = tp.TypeVar("TimeDeltaT", bound=datetime.timedelta) + + +class TimeDeltaUnmarshaller(AbstractUnmarshaller[TimeDeltaT], tp.Generic[TimeDeltaT]): + def __call__(self, val: tp.Any) -> TimeDeltaT: if isinstance(val, (int, float)): return self.t(seconds=int(val)) @@ -239,13 +254,16 @@ def __call__(self, val: tp.Any) -> datetime.timedelta: ) if td.__class__ is self.t: - return td + return td # type: ignore[return-value] return self.t(seconds=td.total_seconds()) -class UUIDUnmarshaller(AbstractUnmarshaller[uuid.UUID]): - def __call__(self, val: tp.Any) -> uuid.UUID: +UUIDT = tp.TypeVar("UUIDT", bound=uuid.UUID) + + +class UUIDUnmarshaller(AbstractUnmarshaller[UUIDT], tp.Generic[UUIDT]): + def __call__(self, val: tp.Any) -> UUIDT: decoded = ( interchange.strload(val) if inspection.istexttype(val.__class__) else val ) @@ -254,10 +272,13 @@ def __call__(self, val: tp.Any) -> uuid.UUID: return self.t(decoded) -class PatternUnmarshaller(AbstractUnmarshaller[re.Pattern]): - def __call__(self, val: tp.Any) -> re.Pattern: +PatternT = tp.TypeVar("PatternT", bound=re.Pattern) + + +class PatternUnmarshaller(AbstractUnmarshaller[PatternT], tp.Generic[PatternT]): + def __call__(self, val: tp.Any) -> PatternT: decoded = interchange.decode(val) - return re.compile(decoded) + return re.compile(decoded) # type: ignore[return-value] class CastUnmarshaller(AbstractUnmarshaller[T]): @@ -283,18 +304,18 @@ def __call__(self, val: tp.Any) -> T: MappingUnmarshaller = CastUnmarshaller[tp.Mapping] IterableUnmarshaller = CastUnmarshaller[tp.Iterable] -_LTVT = tp.TypeVarTuple("_LTVT") -_LT = tp.TypeVar("_LT") + +LiteralT = tp.TypeVar("LiteralT") -class LiteralUnmarshaller(AbstractUnmarshaller, tp.Generic[*_LTVT]): +class LiteralUnmarshaller(AbstractUnmarshaller[LiteralT], tp.Generic[LiteralT]): __slots__ = ("values",) - def __init__(self, t: type[T], context: ContextT): + def __init__(self, t: type[LiteralT], context: ContextT): super().__init__(t, context) self.values = inspection.get_args(t) - def __call__(self, val: tp.Any) -> tp.Literal[*_LTVT]: # type: ignore[valid-type] + def __call__(self, val: tp.Any) -> LiteralT: if val in self.values: return val decoded = ( @@ -306,15 +327,18 @@ def __call__(self, val: tp.Any) -> tp.Literal[*_LTVT]: # type: ignore[valid-typ raise ValueError(f"{decoded!r} is not one of {self.values!r}") -class UnionUnmarshaller(AbstractUnmarshaller, tp.Generic[*_LTVT]): +UnionT = tp.TypeVar("UnionT") + + +class UnionUnmarshaller(AbstractUnmarshaller[UnionT], tp.Generic[UnionT]): __slots__ = ("stack", "ordered_routines") - def __init__(self, t: type[T], context: ContextT): + def __init__(self, t: type[UnionT], context: ContextT): super().__init__(t, context) self.stack = inspection.get_args(t) self.ordered_routines = [self.context[typ] for typ in self.stack] - def __call__(self, val: tp.Any) -> tp.Union[*_LTVT]: # type: ignore[valid-type] + def __call__(self, val: tp.Any) -> UnionT: for routine in self.ordered_routines: with contextlib.suppress(ValueError, TypeError, SyntaxError): unmarshalled = routine(val) @@ -327,19 +351,24 @@ def __call__(self, val: tp.Any) -> tp.Union[*_LTVT]: # type: ignore[valid-type] _VT = tp.TypeVar("_VT") -class SubscriptedMappingUnmarshaller(AbstractUnmarshaller[tp.Mapping[_KT, _VT]]): +MappingT = tp.TypeVar("MappingT", bound=tp.Mapping) + + +class SubscriptedMappingUnmarshaller( + AbstractUnmarshaller[MappingT], tp.Generic[MappingT] +): __slots__ = ( "keys", "values", ) - def __init__(self, t: type[tp.Mapping[_KT, _VT]], context: ContextT) -> None: + def __init__(self, t: type[MappingT], context: ContextT): super().__init__(t=t, context=context) key_t, value_t = inspection.get_args(t) self.keys = context[key_t] self.values = context[value_t] - def __call__(self, val: tp.Any) -> tp.Mapping[_KT, _VT]: + def __call__(self, val: tp.Any) -> MappingT: # Always decode bytes. decoded = interchange.strload(val) keys = self.keys @@ -349,34 +378,45 @@ def __call__(self, val: tp.Any) -> tp.Mapping[_KT, _VT]: ) -class SubscriptedIterableUnmarshaller(AbstractUnmarshaller[tp.Iterable[_VT]]): +IterableT = tp.TypeVar("IterableT", bound=tp.Iterable) + + +class SubscriptedIterableUnmarshaller( + AbstractUnmarshaller[IterableT], tp.Generic[IterableT] +): __slots__ = ("values",) - def __init__(self, t: type[tp.Iterable[_VT]], context: ContextT) -> None: + def __init__(self, t: type[IterableT], context: ContextT): super().__init__(t=t, context=context) (value_t,) = inspection.get_args(t) self.values = context[value_t] - def __call__(self, val: tp.Any) -> tp.Iterable[_VT]: + def __call__(self, val: tp.Any) -> IterableT: # Always decode bytes. decoded = interchange.strload(val) values = self.values return self.origin((values(v) for v in interchange.itervalues(decoded))) -class SubscriptedIteratorUnmarshaller(AbstractUnmarshaller[tp.Iterator[_VT]]): +IteratorT = tp.TypeVar("IteratorT", bound=tp.Iterator) + + +class SubscriptedIteratorUnmarshaller( + AbstractUnmarshaller[IteratorT], tp.Generic[IteratorT] +): __slots__ = ("values",) - def __init__(self, t: type[tp.Iterator[_VT]], context: ContextT) -> None: + def __init__(self, t: type[IteratorT], context: ContextT): super().__init__(t=t, context=context) (value_t,) = inspection.get_args(t) self.values = context[value_t] - def __call__(self, val: tp.Any) -> tp.Iterator[_VT]: + def __call__(self, val: tp.Any) -> IteratorT: # Always decode bytes. decoded = interchange.strload(val) values = self.values - return (values(v) for v in interchange.itervalues(decoded)) + it: IteratorT = (values(v) for v in interchange.itervalues(decoded)) # type: ignore[assignment] + return it _TVT = tp.TypeVarTuple("_TVT") @@ -385,7 +425,7 @@ def __call__(self, val: tp.Any) -> tp.Iterator[_VT]: class FixedTupleUnmarshaller(AbstractUnmarshaller[tuple[*_TVT]]): __slots__ = ("ordered_routines", "stack") - def __init__(self, t: type[tuple[*_TVT]], context: ContextT) -> None: + def __init__(self, t: type[tuple[*_TVT]], context: ContextT): super().__init__(t, context) self.stack = inspection.get_args(t) self.ordered_routines = [self.context[vt] for vt in self.stack] @@ -406,7 +446,7 @@ def __call__(self, val: tp.Any) -> tuple[*_TVT]: class StructuredTypeUnmarshaller(AbstractUnmarshaller[_ST]): __slots__ = ("fields_by_var",) - def __init__(self, t: type[_ST], context: ContextT) -> None: + def __init__(self, t: type[_ST], context: ContextT): super().__init__(t, context) self.fields_by_var = {m.var: m for m in self.context.values()}