Skip to content

Commit 3ba4ffd

Browse files
author
Matej Spiller Muys
committed
Process NameTuple and fix Union types using |
1 parent f85dfa1 commit 3ba4ffd

File tree

2 files changed

+52
-6
lines changed

2 files changed

+52
-6
lines changed

mypy/stubgen.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
TypeList,
134134
TypeStrVisitor,
135135
UnboundType,
136+
UnionType,
136137
get_proper_type,
137138
)
138139
from mypy.visitor import NodeVisitor
@@ -303,6 +304,11 @@ def visit_unbound_type(self, t: UnboundType) -> str:
303304
s += f"[{self.args_str(t.args)}]"
304305
return s
305306

307+
def visit_union_type(self, t: UnionType) -> str:
308+
s = super().visit_union_type(t)
309+
self.stubgen.import_tracker.require_name("Union")
310+
return s
311+
306312
def visit_none_type(self, t: NoneType) -> str:
307313
return "None"
308314

@@ -599,6 +605,7 @@ def __init__(
599605
self.export_less = export_less
600606
# Add imports that could be implicitly generated
601607
self.import_tracker.add_import_from("typing", [("NamedTuple", None)])
608+
self.import_tracker.add_import_from("typing", [("Union", None)])
602609
# Names in __all__ are required
603610
for name in _all_ or ():
604611
if name not in IGNORED_DUNDERS:
@@ -1017,18 +1024,24 @@ def is_namedtuple(self, expr: Expression) -> bool:
10171024
if not isinstance(expr, CallExpr):
10181025
return False
10191026
callee = expr.callee
1020-
return (isinstance(callee, NameExpr) and callee.name.endswith("namedtuple")) or (
1021-
isinstance(callee, MemberExpr) and callee.name == "namedtuple"
1027+
return (
1028+
isinstance(callee, NameExpr)
1029+
and (callee.name.endswith("namedtuple") or callee.name.endswith("NamedTuple"))
1030+
) or (
1031+
isinstance(callee, MemberExpr)
1032+
and (callee.name == "namedtuple" or callee.name == "NamedTuple")
10221033
)
10231034

10241035
def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
10251036
if self._state != EMPTY:
10261037
self.add("\n")
10271038
if isinstance(rvalue.args[1], StrExpr):
1028-
items = rvalue.args[1].value.replace(",", " ").split()
1039+
items: list[tuple[str, str | None] | None] = [
1040+
(key, "Incomplete") for key in rvalue.args[1].value.replace(",", " ").split()
1041+
]
10291042
elif isinstance(rvalue.args[1], (ListExpr, TupleExpr)):
10301043
list_items = cast(List[StrExpr], rvalue.args[1].items)
1031-
items = [item.value for item in list_items]
1044+
items = [self.process_namedtuple_type(item) for item in list_items]
10321045
else:
10331046
self.add(f"{self._indent}{lvalue.name}: Incomplete")
10341047
self.import_tracker.require_name("Incomplete")
@@ -1041,9 +1054,20 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
10411054
self.import_tracker.require_name("Incomplete")
10421055
self.add("\n")
10431056
for item in items:
1044-
self.add(f"{self._indent} {item}: Incomplete\n")
1057+
if item is None:
1058+
continue
1059+
key, rtype = item
1060+
self.add(f"{self._indent} {key}: {rtype}\n")
10451061
self._state = CLASS
10461062

1063+
def process_namedtuple_type(self, item: StrExpr | TupleExpr) -> tuple[str, str | None] | None:
1064+
if isinstance(item, StrExpr):
1065+
return item.value, "Incomplete"
1066+
elif isinstance(item.items[0], StrExpr):
1067+
p = AliasPrinter(self)
1068+
return item.items[0].value, item.items[1].accept(p)
1069+
return None
1070+
10471071
def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool:
10481072
"""Return True for things that look like target for an alias.
10491073

test-data/unit/stubgen.test

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,17 @@ __all__ = ['urllib']
516516
[out]
517517
import urllib as urllib
518518

519+
520+
[case testImportUnion]
521+
def func(a: str | int) -> str | int:
522+
pass
523+
524+
[out]
525+
from typing import Union
526+
527+
def func(a: Union[str, int]) -> Union[str, int]: ...
528+
529+
519530
[case testRelativeImportAll]
520531
from .x import *
521532
[out]
@@ -595,6 +606,17 @@ class X(NamedTuple):
595606
a: Incomplete
596607
b: Incomplete
597608

609+
[case testNamedtupleWithTypes]
610+
import collections, x
611+
X = typing.NamedTuple('X', [('a', str), ('b', str)])
612+
[out]
613+
from _typeshed import Incomplete
614+
from typing import NamedTuple
615+
616+
class X(NamedTuple):
617+
a: str
618+
b: str
619+
598620
[case testEmptyNamedtuple]
599621
import collections
600622
X = collections.namedtuple('X', [])
@@ -915,7 +937,7 @@ T = TypeVar('T')
915937
alias = Union[T, List[T]]
916938

917939
[out]
918-
from typing import TypeVar
940+
from typing import TypeVar, Union
919941

920942
T = TypeVar('T')
921943
alias = Union[T, List[T]]

0 commit comments

Comments
 (0)