diff --git a/src/typelib/graph.py b/src/typelib/graph.py index 391f1c7..9f9b20a 100644 --- a/src/typelib/graph.py +++ b/src/typelib/graph.py @@ -112,6 +112,9 @@ def get_type_graph(t: type) -> graphlib.TopologicalSorter[TypeNode]: resolve one level deep on each attempt, otherwise we will find ourselves stuck in a closed loop which never terminates (infinite recursion). """ + if inspection.istypealiastype(t): + t = t.__value__ + graph: graphlib.TopologicalSorter = graphlib.TopologicalSorter() root = TypeNode(t) stack = collections.deque([root]) @@ -127,6 +130,8 @@ def get_type_graph(t: type) -> graphlib.TopologicalSorter[TypeNode]: # If no type was provided, there's no reason to do further processing. if child in (constants.empty, typing.Any): continue + if inspection.istypealiastype(child): + child = child.__value__ # Only subscripted generics or non-stdlib types can be cyclic. # i.e., we may get `str` or `datetime` any number of times, diff --git a/tests/models.py b/tests/models.py index 13622b3..6b6b455 100644 --- a/tests/models.py +++ b/tests/models.py @@ -5,6 +5,8 @@ import enum import typing +from typelib.py import compat + @dataclasses.dataclass class RecursiveType: @@ -83,3 +85,11 @@ class ParentIntersect: @dataclasses.dataclass class ChildIntersect: b: int + + +ListAlias = compat.TypeAliasType("ListAlias", list[int]) + + +@dataclasses.dataclass +class NestedTypeAliasType: + alias: ListAlias diff --git a/tests/unit/test_graph.py b/tests/unit/test_graph.py index a4993b5..5c410fa 100644 --- a/tests/unit/test_graph.py +++ b/tests/unit/test_graph.py @@ -9,6 +9,9 @@ from typelib import graph from typelib.py import refs +from tests import models +from tests.models import NestedTypeAliasType + @dataclasses.dataclass class Simple: @@ -87,6 +90,14 @@ class NoTypes: ], ), any_type=dict(given_type=NoTypes, expected_nodes=[graph.TypeNode(type=NoTypes)]), + nested_type_alias=dict( + given_type=models.NestedTypeAliasType, + expected_nodes=[ + graph.TypeNode(type=int), + graph.TypeNode(type=list[int], var="alias"), + graph.TypeNode(type=NestedTypeAliasType), + ], + ), ) @pytest.mark.skipif(sys.version_info < (3, 10), reason="py3.10+") def test_static_order(given_type, expected_nodes):