Skip to content

Commit

Permalink
Add import Union to stub file (python#12929)
Browse files Browse the repository at this point in the history
  • Loading branch information
SilvestrLanik committed Aug 15, 2022
1 parent c8a2289 commit 88ac62e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
19 changes: 19 additions & 0 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
TypeList,
TypeStrVisitor,
UnboundType,
UnionType,
get_proper_type,
)
from mypy.visitor import NodeVisitor
Expand Down Expand Up @@ -461,6 +462,17 @@ def import_lines(self) -> List[str]:
module_map: Mapping[str, List[str]] = defaultdict(list)

for name in sorted(self.required_names):
# We don't want to ignore Union even if it's not listed in the import statement because for PEP 604 style of
# Union we're still generating explicit Union e.g.
#
# def foo(a: int | str):
# print(a)
# ==>
# def foo(a: Union[int | str]): ...

if name == "Union":
self.module_for[name] = 'typing'

# If we haven't seen this name in an import statement, ignore it
if name not in self.module_for:
continue
Expand Down Expand Up @@ -693,6 +705,8 @@ def visit_func_def(
# Luckily, an argument explicitly annotated with "Any" has
# type "UnboundType" and will not match.
if not isinstance(get_proper_type(annotated_type), AnyType):
if isinstance(get_proper_type(annotated_type), UnionType):
self.add_typing_import("Union")
annotation = f": {self.print_annotation(annotated_type)}"

if kind.is_named() and not any(arg.startswith("*") for arg in args):
Expand Down Expand Up @@ -722,6 +736,8 @@ def visit_func_def(
# type "UnboundType" and will enter the else branch.
retname = None # implicit Any
else:
if isinstance(get_proper_type(o.unanalyzed_type.ret_type), UnionType):
self.add_typing_import("Union")
retname = self.print_annotation(o.unanalyzed_type.ret_type)
elif isinstance(o, FuncDef) and (
o.abstract_status == IS_ABSTRACT or o.name in METHODS_WITH_RETURN_VALUE
Expand Down Expand Up @@ -1200,6 +1216,9 @@ def get_init(
return None
self._vars[-1].append(lvalue)
if annotation is not None:
if isinstance(get_proper_type(annotation), UnionType):
self.add_typing_import("Union")

typename = self.print_annotation(annotation)
if (
isinstance(annotation, UnboundType)
Expand Down
25 changes: 25 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -2705,3 +2705,28 @@ def f():
return 0
[out]
def f(): ...

[case testAddUnionImportForArgument]
def foo(a: str | None):
pass
[out]
from typing import Union

def foo(a: Union[str, None]): ...

[case testAddUnionImportForReturn]
def foo() -> str | None:
pass
[out]
from typing import Union

def foo() -> Union[str, None]: ...

[case testAddUnionImportForClassAttribute]
class A:
a: int | str
[out]
from typing import Union

class A:
a: Union[int, str]

0 comments on commit 88ac62e

Please sign in to comment.