Skip to content

Commit

Permalink
fix: Handle TypeAliasType within TypeContext lookups
Browse files Browse the repository at this point in the history
- Resolves an issue where nested type alias types would cause a KeyError with direct lookups from the type context
  • Loading branch information
seandstewart committed Oct 30, 2024
1 parent d4e8a25 commit 1295bed
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 7 deletions.
19 changes: 17 additions & 2 deletions src/typelib/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,30 @@

from __future__ import annotations

import contextlib
import typing as tp
import typing_extensions as te

from typelib.py import refs
from typelib.py import inspection, refs

ValueT = tp.TypeVar("ValueT")
DefaultT = tp.TypeVar("DefaultT")
KeyT = te.TypeAliasType("KeyT", "type | refs.ForwardRef")


class TypeContext(dict[KeyT, ValueT], tp.Generic[ValueT]):
"""A key-value mapping which can map between forward references and real types."""

def get(self, key: KeyT, default: ValueT | DefaultT = None) -> ValueT | DefaultT:
with contextlib.suppress(KeyError):
return self[key]

return default

def __missing__(self, key: type | refs.ForwardRef):
"""Hook to handle missing type references.
Allows for sharing lookup results between forward references and real types.
Allows for sharing lookup results between forward references, type aliases, real types.
Args:
key: The type or reference.
Expand All @@ -26,5 +34,12 @@ def __missing__(self, key: type | refs.ForwardRef):
if isinstance(key, refs.ForwardRef):
raise KeyError(key)

unwrapped = inspection.unwrap(key)
if unwrapped in self:
val = self[unwrapped]
# Store the value at the original key to short-circuit in future
self[key] = val
return val

ref = refs.forwardref(key)
return self[ref]
4 changes: 2 additions & 2 deletions src/typelib/marshals/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,9 +469,9 @@ def _fields_by_var(self):
m = self.context.get(hint) or self.context.get(resolved)
if m is None:
warnings.warn(
"Failed to identify an unmarshaller for the associated type-variable pair: "
"Failed to identify a marshaller for the associated type-variable pair: "
f"Original ref: {hint}, Resolved ref: {resolved}. Will default to no-op.",
stacklevel=4,
stacklevel=5,
)
fields_by_var[name] = NoOpMarshaller(hint, self.context, var=name)
continue
Expand Down
17 changes: 16 additions & 1 deletion src/typelib/py/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,6 @@ def isclassvartype(obj: type) -> bool:

_UNWRAPPABLE = (
isclassvartype,
isoptionaltype,
isfinal,
)

Expand Down Expand Up @@ -1484,6 +1483,22 @@ def istypealiastype(t: tp.Any) -> compat.TypeIs[compat.TypeAliasType]:
return isinstance(t, compat.TypeAliasType)


@compat.cache
def unwrap(t: tp.Any) -> tp.Any:
while True:
if should_unwrap(t):
t = t.__args__[0]
continue
if istypealiastype(t):
t = t.__value__
continue

if hasattr(t, "__supertype__"):
t = t.__supertype__
continue
return t


def _safe_issubclass(__cls: type, __class_or_tuple: type | tuple[type, ...]) -> bool:
try:
return issubclass(__cls, __class_or_tuple)
Expand Down
2 changes: 1 addition & 1 deletion src/typelib/unmarshals/routines.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ def _fields_by_var(self):
warnings.warn(
"Failed to identify an unmarshaller for the associated type-variable pair: "
f"Original ref: {hint}, Resolved ref: {resolved}. Will default to no-op.",
stacklevel=4,
stacklevel=6,
)
fields_by_var[name] = NoOpUnmarshaller(hint, self.context, var=name)
continue
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/marshals/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@
),
expected_output={"intersection": {"a": 0}, "child": {"intersection": {"b": 0}}},
),
nested_type_alias=dict(
given_type=models.NestedTypeAliasType,
given_input=models.NestedTypeAliasType(alias=[1]),
expected_output={"alias": [1]},
),
)
def test_marshal(given_type, given_input, expected_output):
# When
Expand Down
1 change: 0 additions & 1 deletion tests/unit/py/test_inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,6 @@ def test_isclassvartype(given_type, expected_is_classvar_type):

@pytest.mark.suite(
classvar=dict(given_type=t.ClassVar[int], expected_should_unwrap=True),
optional=dict(given_type=t.Optional[str], expected_should_unwrap=True),
final=dict(given_type=t.Final[str], expected_should_unwrap=True),
literal=dict(given_type=t.Literal[1], expected_should_unwrap=False),
)
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/unmarshals/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@
child=models.Child(intersection=models.ChildIntersect(b=0)),
),
),
nested_type_alias=dict(
given_type=models.NestedTypeAliasType,
given_input={"alias": ["1"]},
expected_output=models.NestedTypeAliasType(alias=[1]),
),
)
def test_unmarshal(given_type, given_input, expected_output):
# When
Expand Down

0 comments on commit 1295bed

Please sign in to comment.