Skip to content

Commit e4c35ce

Browse files
author
Matej Spiller Muys
committed
Process NameTuple and fix Union types using |
1 parent f85dfa1 commit e4c35ce

File tree

2 files changed

+71
-8
lines changed

2 files changed

+71
-8
lines changed

mypy/stubgen.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
TypeList,
134134
TypeStrVisitor,
135135
UnboundType,
136+
UnionType,
136137
get_proper_type,
137138
)
138139
from mypy.visitor import NodeVisitor
@@ -303,6 +304,11 @@ def visit_unbound_type(self, t: UnboundType) -> str:
303304
s += f"[{self.args_str(t.args)}]"
304305
return s
305306

307+
def visit_union_type(self, t: UnionType) -> str:
308+
s = super().visit_union_type(t)
309+
self.stubgen.import_tracker.require_name("Union")
310+
return s
311+
306312
def visit_none_type(self, t: NoneType) -> str:
307313
return "None"
308314

@@ -599,6 +605,7 @@ def __init__(
599605
self.export_less = export_less
600606
# Add imports that could be implicitly generated
601607
self.import_tracker.add_import_from("typing", [("NamedTuple", None)])
608+
self.import_tracker.add_import_from("typing", [("Union", None)])
602609
# Names in __all__ are required
603610
for name in _all_ or ():
604611
if name not in IGNORED_DUNDERS:
@@ -1017,33 +1024,51 @@ def is_namedtuple(self, expr: Expression) -> bool:
10171024
if not isinstance(expr, CallExpr):
10181025
return False
10191026
callee = expr.callee
1020-
return (isinstance(callee, NameExpr) and callee.name.endswith("namedtuple")) or (
1021-
isinstance(callee, MemberExpr) and callee.name == "namedtuple"
1027+
return (
1028+
isinstance(callee, NameExpr)
1029+
and (callee.name.endswith("namedtuple") or callee.name.endswith("NamedTuple"))
1030+
) or (
1031+
isinstance(callee, MemberExpr)
1032+
and (callee.name == "namedtuple" or callee.name == "NamedTuple")
10221033
)
10231034

10241035
def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
10251036
if self._state != EMPTY:
10261037
self.add("\n")
10271038
if isinstance(rvalue.args[1], StrExpr):
1028-
items = rvalue.args[1].value.replace(",", " ").split()
1039+
items: list[tuple[str | None, str | None]] = [
1040+
(key, "Incomplete") for key in rvalue.args[1].value.replace(",", " ").split()
1041+
]
10291042
elif isinstance(rvalue.args[1], (ListExpr, TupleExpr)):
10301043
list_items = cast(List[StrExpr], rvalue.args[1].items)
1031-
items = [item.value for item in list_items]
1044+
items = [self.process_namedtuple_type(item) for item in list_items]
10321045
else:
10331046
self.add(f"{self._indent}{lvalue.name}: Incomplete")
10341047
self.import_tracker.require_name("Incomplete")
10351048
return
10361049
self.import_tracker.require_name("NamedTuple")
10371050
self.add(f"{self._indent}class {lvalue.name}(NamedTuple):")
1038-
if len(items) == 0:
1051+
1052+
validitems = list(filter(lambda v: v[0] is not None and v[0].isidentifier(), items))
1053+
1054+
if len(validitems) == 0:
10391055
self.add(" ...\n")
10401056
else:
10411057
self.import_tracker.require_name("Incomplete")
10421058
self.add("\n")
1043-
for item in items:
1044-
self.add(f"{self._indent} {item}: Incomplete\n")
1059+
for item in validitems:
1060+
key, rtype = item
1061+
self.add(f"{self._indent} {key}: {rtype}\n")
10451062
self._state = CLASS
10461063

1064+
def process_namedtuple_type(self, item: StrExpr | TupleExpr) -> tuple[str | None, str | None]:
1065+
if isinstance(item, StrExpr):
1066+
return item.value, "Incomplete"
1067+
elif isinstance(item.items[0], StrExpr):
1068+
p = AliasPrinter(self)
1069+
return item.items[0].value, item.items[1].accept(p)
1070+
return None, None
1071+
10471072
def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool:
10481073
"""Return True for things that look like target for an alias.
10491074

test-data/unit/stubgen.test

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,17 @@ __all__ = ['urllib']
516516
[out]
517517
import urllib as urllib
518518

519+
520+
[case testImportUnion]
521+
def func(a: str | int) -> str | int:
522+
pass
523+
524+
[out]
525+
from typing import Union
526+
527+
def func(a: Union[str, int]) -> Union[str, int]: ...
528+
529+
519530
[case testRelativeImportAll]
520531
from .x import *
521532
[out]
@@ -595,6 +606,33 @@ class X(NamedTuple):
595606
a: Incomplete
596607
b: Incomplete
597608

609+
[case testNamedtupleUsingInvalidIdent]
610+
import collections, x
611+
X = collections.namedtuple('X', ['@'])
612+
[out]
613+
from typing import NamedTuple
614+
615+
class X(NamedTuple): ...
616+
617+
[case testNamedtupleWithTypesInvalidIdent]
618+
import collections, x
619+
X = typing.NamedTuple('X', [('@', int)])
620+
[out]
621+
from typing import NamedTuple
622+
623+
class X(NamedTuple): ...
624+
625+
[case testNamedtupleWithTypes]
626+
import collections, x
627+
X = typing.NamedTuple('X', [('a', str), ('b', str)])
628+
[out]
629+
from _typeshed import Incomplete
630+
from typing import NamedTuple
631+
632+
class X(NamedTuple):
633+
a: str
634+
b: str
635+
598636
[case testEmptyNamedtuple]
599637
import collections
600638
X = collections.namedtuple('X', [])
@@ -915,7 +953,7 @@ T = TypeVar('T')
915953
alias = Union[T, List[T]]
916954

917955
[out]
918-
from typing import TypeVar
956+
from typing import TypeVar, Union
919957

920958
T = TypeVar('T')
921959
alias = Union[T, List[T]]

0 commit comments

Comments
 (0)