diff --git a/src/typelib/ctx.py b/src/typelib/ctx.py index fad9067..0af5d6b 100644 --- a/src/typelib/ctx.py +++ b/src/typelib/ctx.py @@ -2,22 +2,30 @@ from __future__ import annotations +import contextlib import typing as tp import typing_extensions as te -from typelib.py import refs +from typelib.py import inspection, refs ValueT = tp.TypeVar("ValueT") +DefaultT = tp.TypeVar("DefaultT") KeyT = te.TypeAliasType("KeyT", "type | refs.ForwardRef") class TypeContext(dict[KeyT, ValueT], tp.Generic[ValueT]): """A key-value mapping which can map between forward references and real types.""" + def get(self, key: KeyT, default: ValueT | DefaultT = None) -> ValueT | DefaultT: + with contextlib.suppress(KeyError): + return self[key] + + return default + def __missing__(self, key: type | refs.ForwardRef): """Hook to handle missing type references. - Allows for sharing lookup results between forward references and real types. + Allows for sharing lookup results between forward references, type aliases, real types. Args: key: The type or reference. @@ -26,5 +34,12 @@ def __missing__(self, key: type | refs.ForwardRef): if isinstance(key, refs.ForwardRef): raise KeyError(key) + unwrapped = inspection.unwrap(key) + if unwrapped in self: + val = self[unwrapped] + # Store the value at the original key to short-circuit in future + self[key] = val + return val + ref = refs.forwardref(key) return self[ref] diff --git a/src/typelib/marshals/routines.py b/src/typelib/marshals/routines.py index 93774d8..a95addc 100644 --- a/src/typelib/marshals/routines.py +++ b/src/typelib/marshals/routines.py @@ -469,9 +469,9 @@ def _fields_by_var(self): m = self.context.get(hint) or self.context.get(resolved) if m is None: warnings.warn( - "Failed to identify an unmarshaller for the associated type-variable pair: " + "Failed to identify a marshaller for the associated type-variable pair: " f"Original ref: {hint}, Resolved ref: {resolved}. Will default to no-op.", - stacklevel=4, + stacklevel=5, ) fields_by_var[name] = NoOpMarshaller(hint, self.context, var=name) continue diff --git a/src/typelib/py/inspection.py b/src/typelib/py/inspection.py index 85a68c6..fff5a5c 100644 --- a/src/typelib/py/inspection.py +++ b/src/typelib/py/inspection.py @@ -933,7 +933,6 @@ def isclassvartype(obj: type) -> bool: _UNWRAPPABLE = ( isclassvartype, - isoptionaltype, isfinal, ) @@ -1484,6 +1483,22 @@ def istypealiastype(t: tp.Any) -> compat.TypeIs[compat.TypeAliasType]: return isinstance(t, compat.TypeAliasType) +@compat.cache +def unwrap(t: tp.Any) -> tp.Any: + while True: + if should_unwrap(t): + t = t.__args__[0] + continue + if istypealiastype(t): + t = t.__value__ + continue + + if hasattr(t, "__supertype__"): + t = t.__supertype__ + continue + return t + + def _safe_issubclass(__cls: type, __class_or_tuple: type | tuple[type, ...]) -> bool: try: return issubclass(__cls, __class_or_tuple) diff --git a/src/typelib/unmarshals/routines.py b/src/typelib/unmarshals/routines.py index 03d6a04..7f6f55c 100644 --- a/src/typelib/unmarshals/routines.py +++ b/src/typelib/unmarshals/routines.py @@ -985,7 +985,7 @@ def _fields_by_var(self): warnings.warn( "Failed to identify an unmarshaller for the associated type-variable pair: " f"Original ref: {hint}, Resolved ref: {resolved}. Will default to no-op.", - stacklevel=4, + stacklevel=6, ) fields_by_var[name] = NoOpUnmarshaller(hint, self.context, var=name) continue diff --git a/tests/unit/marshals/test_api.py b/tests/unit/marshals/test_api.py index 7209f3f..8c21172 100644 --- a/tests/unit/marshals/test_api.py +++ b/tests/unit/marshals/test_api.py @@ -155,6 +155,11 @@ ), expected_output={"intersection": {"a": 0}, "child": {"intersection": {"b": 0}}}, ), + nested_type_alias=dict( + given_type=models.NestedTypeAliasType, + given_input=models.NestedTypeAliasType(alias=[1]), + expected_output={"alias": [1]}, + ), ) def test_marshal(given_type, given_input, expected_output): # When diff --git a/tests/unit/py/test_inspection.py b/tests/unit/py/test_inspection.py index 4fbda76..e743948 100644 --- a/tests/unit/py/test_inspection.py +++ b/tests/unit/py/test_inspection.py @@ -592,7 +592,6 @@ def test_isclassvartype(given_type, expected_is_classvar_type): @pytest.mark.suite( classvar=dict(given_type=t.ClassVar[int], expected_should_unwrap=True), - optional=dict(given_type=t.Optional[str], expected_should_unwrap=True), final=dict(given_type=t.Final[str], expected_should_unwrap=True), literal=dict(given_type=t.Literal[1], expected_should_unwrap=False), ) diff --git a/tests/unit/unmarshals/test_api.py b/tests/unit/unmarshals/test_api.py index c00dfaa..9297d8a 100644 --- a/tests/unit/unmarshals/test_api.py +++ b/tests/unit/unmarshals/test_api.py @@ -162,6 +162,11 @@ child=models.Child(intersection=models.ChildIntersect(b=0)), ), ), + nested_type_alias=dict( + given_type=models.NestedTypeAliasType, + given_input={"alias": ["1"]}, + expected_output=models.NestedTypeAliasType(alias=[1]), + ), ) def test_unmarshal(given_type, given_input, expected_output): # When