133
133
TypeList ,
134
134
TypeStrVisitor ,
135
135
UnboundType ,
136
+ UnionType ,
136
137
get_proper_type ,
137
138
)
138
139
from mypy .visitor import NodeVisitor
@@ -303,6 +304,11 @@ def visit_unbound_type(self, t: UnboundType) -> str:
303
304
s += f"[{ self .args_str (t .args )} ]"
304
305
return s
305
306
307
+ def visit_union_type (self , t : UnionType ) -> str :
308
+ s = super ().visit_union_type (t )
309
+ self .stubgen .import_tracker .require_name ("Union" )
310
+ return s
311
+
306
312
def visit_none_type (self , t : NoneType ) -> str :
307
313
return "None"
308
314
@@ -599,6 +605,7 @@ def __init__(
599
605
self .export_less = export_less
600
606
# Add imports that could be implicitly generated
601
607
self .import_tracker .add_import_from ("typing" , [("NamedTuple" , None )])
608
+ self .import_tracker .add_import_from ("typing" , [("Union" , None )])
602
609
# Names in __all__ are required
603
610
for name in _all_ or ():
604
611
if name not in IGNORED_DUNDERS :
@@ -1017,18 +1024,24 @@ def is_namedtuple(self, expr: Expression) -> bool:
1017
1024
if not isinstance (expr , CallExpr ):
1018
1025
return False
1019
1026
callee = expr .callee
1020
- return (isinstance (callee , NameExpr ) and callee .name .endswith ("namedtuple" )) or (
1021
- isinstance (callee , MemberExpr ) and callee .name == "namedtuple"
1027
+ return (
1028
+ isinstance (callee , NameExpr )
1029
+ and (callee .name .endswith ("namedtuple" ) or callee .name .endswith ("NamedTuple" ))
1030
+ ) or (
1031
+ isinstance (callee , MemberExpr )
1032
+ and (callee .name == "namedtuple" or callee .name == "NamedTuple" )
1022
1033
)
1023
1034
1024
1035
def process_namedtuple (self , lvalue : NameExpr , rvalue : CallExpr ) -> None :
1025
1036
if self ._state != EMPTY :
1026
1037
self .add ("\n " )
1027
1038
if isinstance (rvalue .args [1 ], StrExpr ):
1028
- items = rvalue .args [1 ].value .replace ("," , " " ).split ()
1039
+ items : list [tuple [str , str | None ] | None ] = [
1040
+ (key , "Incomplete" ) for key in rvalue .args [1 ].value .replace ("," , " " ).split ()
1041
+ ]
1029
1042
elif isinstance (rvalue .args [1 ], (ListExpr , TupleExpr )):
1030
1043
list_items = cast (List [StrExpr ], rvalue .args [1 ].items )
1031
- items = [item . value for item in list_items ]
1044
+ items = [self . process_namedtuple_type ( item ) for item in list_items ]
1032
1045
else :
1033
1046
self .add (f"{ self ._indent } { lvalue .name } : Incomplete" )
1034
1047
self .import_tracker .require_name ("Incomplete" )
@@ -1041,9 +1054,20 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
1041
1054
self .import_tracker .require_name ("Incomplete" )
1042
1055
self .add ("\n " )
1043
1056
for item in items :
1044
- self .add (f"{ self ._indent } { item } : Incomplete\n " )
1057
+ if item is None :
1058
+ continue
1059
+ key , rtype = item
1060
+ self .add (f"{ self ._indent } { key } : { rtype } \n " )
1045
1061
self ._state = CLASS
1046
1062
1063
+ def process_namedtuple_type (self , item : StrExpr | TupleExpr ) -> tuple [str , str | None ] | None :
1064
+ if isinstance (item , StrExpr ):
1065
+ return item .value , "Incomplete"
1066
+ elif isinstance (item .items [0 ], StrExpr ):
1067
+ p = AliasPrinter (self )
1068
+ return item .items [0 ].value , item .items [1 ].accept (p )
1069
+ return None
1070
+
1047
1071
def is_alias_expression (self , expr : Expression , top_level : bool = True ) -> bool :
1048
1072
"""Return True for things that look like target for an alias.
1049
1073
0 commit comments