Skip to content

Backport get_origin() and get_args() #698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions typing_extensions/src_py3/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1833,6 +1833,8 @@ def test_typing_extensions_defers_when_possible(self):
'Final',
'get_type_hints'
}
if sys.version_info[:2] == (3, 8):
exclude |= {'get_args', 'get_origin'}
for item in typing_extensions.__all__:
if item not in exclude and hasattr(typing, item):
self.assertIs(
Expand Down
52 changes: 51 additions & 1 deletion typing_extensions/src_py3/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _check_methods_in_mro(C, *methods):
HAVE_ANNOTATED = PEP_560 or SUBS_TREE

if PEP_560:
__all__.append("get_type_hints")
__all__.extend(["get_args", "get_origin", "get_type_hints"])

if HAVE_ANNOTATED:
__all__.append("Annotated")
Expand Down Expand Up @@ -1992,3 +1992,53 @@ class Annotated(metaclass=AnnotatedMeta):
OptimizedList = Annotated[List[T], runtime.Optimize()]
OptimizedList[int] == Annotated[List[int], runtime.Optimize()]
"""

# Python 3.8 has get_origin() and get_args() but those implementations aren't
# Annotated-aware, so we can't use those, only Python 3.9 versions will do.
if sys.version_info[:2] >= (3, 9):
get_origin = typing.get_origin
get_args = typing.get_args
elif PEP_560:
from typing import _GenericAlias # noqa

def get_origin(tp):
"""Get the unsubscripted version of a type.

This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar
and Annotated. Return None for unsupported types. Examples::

get_origin(Literal[42]) is Literal
get_origin(int) is None
get_origin(ClassVar[int]) is ClassVar
get_origin(Generic) is Generic
get_origin(Generic[T]) is Generic
get_origin(Union[T, int]) is Union
get_origin(List[Tuple[T, T]][int]) == list
"""
if isinstance(tp, _AnnotatedAlias):
return Annotated
if isinstance(tp, _GenericAlias):
return tp.__origin__
if tp is Generic:
return Generic
return None

def get_args(tp):
"""Get type arguments with all substitutions performed.

For unions, basic simplifications used by Union constructor are performed.
Examples::
get_args(Dict[str, int]) == (str, int)
get_args(int) == ()
get_args(Union[int, Union[T, int], str][int]) == (int, str)
get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
get_args(Callable[[], T][int]) == ([], int)
"""
if isinstance(tp, _AnnotatedAlias):
return (tp.__origin__,) + tp.__metadata__
if isinstance(tp, _GenericAlias):
res = tp.__args__
if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis:
res = (list(res[:-1]), res[-1])
return res
return ()