@@ -20,6 +20,18 @@ def __repr__(self) -> str:
20
20
return self .name
21
21
22
22
23
+ ANY = Ref ("typing.Any" )
24
+ LIST = Ref ("builtins.list" )
25
+ TUPLE = Ref ("builtins.tuple" )
26
+ SET = Ref ("builtins.set" )
27
+ TYPE = Ref ("builtins.type" )
28
+ NONE_TYPE = Ref ("builtins.NoneType" )
29
+ BOOL = Ref ("builtins.bool" )
30
+ INT = Ref ("builtins.int" )
31
+ FLOAT = Ref ("builtins.float" )
32
+ STR = Ref ("builtins.str" )
33
+
34
+
23
35
@dataclass (frozen = True , slots = True )
24
36
class TypeVar :
25
37
name : str
@@ -132,25 +144,28 @@ def literal(value: int | str | bool | float | tuple | list | None) -> Literal:
132
144
case value if value is NULL :
133
145
ref = Ref ("builtins.ellipsis" )
134
146
case int ():
135
- ref = Ref ( "builtins.int" )
147
+ ref = INT
136
148
case float ():
137
- ref = Ref ( "builtins.float" )
149
+ ref = FLOAT
138
150
case str ():
139
- ref = Ref ( "builtins.str" )
151
+ ref = STR
140
152
case bool ():
141
- ref = Ref ( "builtins.bool" )
153
+ ref = BOOL
142
154
case None :
143
- ref = Ref ( "builtins.NoneType" )
155
+ ref = NONE_TYPE
144
156
case tuple ():
145
- ref = Ref ( "builtins.tuple" )
157
+ ref = TUPLE
146
158
case list ():
147
159
value = tuple (value )
148
- ref = Ref ( "builtins.list" )
160
+ ref = LIST
149
161
case _:
150
162
assert False , f"Unknown literal type { value !r} "
151
163
return Literal (value , ref )
152
164
153
165
166
+ NONE = literal (None )
167
+
168
+
154
169
@dataclass (frozen = True , slots = True )
155
170
class TypedDict :
156
171
items : frozenset [Row ]
@@ -227,9 +242,8 @@ def __repr__(self) -> str:
227
242
class SideEffect :
228
243
new : bool
229
244
bound_method : bool = False
230
- update : typing .Optional [TypeExpr ] = None
245
+ update : tuple [ typing .Optional [TypeExpr ], tuple [ int , ...]] = ( None , ())
231
246
points_to_args : bool = False
232
- name : typing .Optional [str ] = None # ad hoc effects
233
247
234
248
235
249
@dataclass (frozen = True , slots = True )
@@ -248,7 +262,7 @@ def __repr__(self) -> str:
248
262
new = "new " if self .new () else ""
249
263
update = (
250
264
"{update " + str (self .side_effect .update ) + "}@"
251
- if self .side_effect .update
265
+ if self .side_effect .update [ 0 ]
252
266
else ""
253
267
)
254
268
return f"[{ type_params } ]({ self .params } -> { update } { new } { self .return_type } )"
@@ -340,9 +354,9 @@ def bind_typevars(t: TypeExpr, context: dict[TypeVar, TypeExpr]) -> TypeExpr:
340
354
case Row () as row :
341
355
return replace (row , type = bind_typevars (row .type , context ))
342
356
case SideEffect () as s :
343
- if s .update is None :
357
+ if s .update [ 0 ] is None :
344
358
return s
345
- return replace (s , update = bind_typevars (s .update , context ))
359
+ return replace (s , update = ( bind_typevars (s .update [ 0 ] , context ), s . update [ 1 ] ))
346
360
case Class (
347
361
type_params = type_params , class_dict = class_dict , inherits = inherits
348
362
) as klass :
@@ -372,9 +386,9 @@ def bind_typevars(t: TypeExpr, context: dict[TypeVar, TypeExpr]) -> TypeExpr:
372
386
return choices .items [actual_arg .value ]
373
387
return Access (choices , actual_arg )
374
388
case SideEffect () as s :
375
- if s .update is None :
389
+ if s .update [ 0 ] is None :
376
390
return s
377
- return replace (s , update = bind_typevars (s .update , context ))
391
+ return replace (s , update = ( bind_typevars (s .update [ 0 ] , context ), s . update [ 1 ] ))
378
392
raise NotImplementedError (f"{ t !r} , { type (t )} " )
379
393
380
394
@@ -414,7 +428,6 @@ def union(items: typing.Iterable[TypeExpr], squeeze=True) -> TypeExpr:
414
428
415
429
TOP = typed_dict ([])
416
430
BOTTOM = Union (frozenset ())
417
- ANY = Ref ("typing.Any" )
418
431
419
432
420
433
def join (t1 : TypeExpr , t2 : TypeExpr ) -> TypeExpr :
@@ -461,7 +474,7 @@ def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
461
474
tuple (join (t1 , t2 ) for t1 , t2 in zip (l1 .value , l2 .value )),
462
475
l1 .ref ,
463
476
)
464
- if l1 .ref == Ref ( "builtins.list" ) :
477
+ if l1 .ref == LIST :
465
478
return Instantiation (
466
479
l1 .ref , (join_all ([* l1 .value , * l2 .value ]),)
467
480
)
@@ -477,7 +490,7 @@ def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
477
490
Instantiation () as inst ,
478
491
Literal (tuple () as value , ref = ref ),
479
492
) if ref == inst .generic :
480
- if ref . name == "builtins.list" :
493
+ if ref == LIST :
481
494
value = (join_all (value ),)
482
495
return join (inst , Instantiation (ref , value ))
483
496
case (TypedDict (items1 ), TypedDict (items2 )): # type: ignore
@@ -497,21 +510,17 @@ def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
497
510
if index1 == index2 :
498
511
return Row (index1 , join (t1 , t2 ))
499
512
return BOTTOM
500
- case (Row (_, t1 ), (Instantiation (Ref ("builtins .list" ), type_args ))) | (
501
- Instantiation (Ref ("builtins .list" ), type_args ),
513
+ case (Row (_, t1 ), (Instantiation (Ref ("builtings .list" ), type_args ))) | (
514
+ Instantiation (Ref ("builtings .list" ), type_args ),
502
515
Row (_, t1 ),
503
516
):
504
- return Instantiation (
505
- Ref ("builtins.list" ), tuple (join (t1 , t ) for t in type_args )
506
- )
517
+ return Instantiation (LIST , tuple (join (t1 , t ) for t in type_args ))
507
518
case (Row (_, t1 ), (Instantiation (Ref ("builtins.tuple" ), type_args ))) | (
508
519
Instantiation (Ref ("builtins.tuple" ), type_args ),
509
520
Row (_, t1 ),
510
521
):
511
522
# not exact; should only join at the index of the row
512
- return Instantiation (
513
- Ref ("builtins.tuple" ), tuple (join (t1 , t ) for t in type_args )
514
- )
523
+ return Instantiation (TUPLE , tuple (join (t1 , t ) for t in type_args ))
515
524
case Class (), Class ():
516
525
return TOP
517
526
case (Class (name = "int" ) | Ref ("builtins.int" ) as c , Literal (int ())) | (
@@ -520,10 +529,11 @@ def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
520
529
):
521
530
return c
522
531
case (SideEffect () as s1 , SideEffect () as s2 ):
532
+ assert s1 .update [1 ] == s2 .update [1 ]
523
533
return SideEffect (
524
534
new = s1 .new | s2 .new ,
525
535
bound_method = s1 .bound_method | s2 .bound_method ,
526
- update = join (s1 .update , s2 .update ),
536
+ update = ( join (s1 .update [ 0 ] , s2 .update [ 0 ]), s1 . update [ 1 ] ),
527
537
points_to_args = s1 .points_to_args | s2 .points_to_args ,
528
538
)
529
539
case x , y :
@@ -556,7 +566,7 @@ def meet(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
556
566
tuple (meet (t1 , t2 ) for t1 , t2 in zip (l1 .value , l2 .value )),
557
567
l1 .ref ,
558
568
)
559
- if l1 .ref == Ref ( "builtins.list" ) :
569
+ if l1 .ref == LIST :
560
570
return Instantiation (
561
571
l1 .ref , (meet_all ([* l1 .value , * l2 .value ]),)
562
572
)
@@ -1054,13 +1064,13 @@ def get_init_func(callable: TypeExpr) -> TypeExpr:
1054
1064
]
1055
1065
)
1056
1066
side_effect = SideEffect (
1057
- new = True , bound_method = True , update = None , points_to_args = True
1067
+ new = True , bound_method = True , update = ( None , ()) , points_to_args = True
1058
1068
)
1059
1069
res = bind_self (
1060
1070
overload (
1061
1071
replace (
1062
1072
f ,
1063
- return_type = f .side_effect .update or selftype ,
1073
+ return_type = f .side_effect .update [ 0 ] or selftype ,
1064
1074
side_effect = side_effect ,
1065
1075
)
1066
1076
for f in init .items
@@ -1268,7 +1278,7 @@ def make_list_constructor() -> Overloaded:
1268
1278
FunctionType (
1269
1279
params = typed_dict ([make_row (0 , "args" , args )]),
1270
1280
return_type = return_type ,
1271
- side_effect = SideEffect (new = True , points_to_args = True , name = "[]" ),
1281
+ side_effect = SideEffect (new = True , points_to_args = True ),
1272
1282
is_property = False ,
1273
1283
type_params = (args ,),
1274
1284
)
@@ -1277,13 +1287,13 @@ def make_list_constructor() -> Overloaded:
1277
1287
1278
1288
1279
1289
def make_set_constructor () -> Overloaded :
1280
- return_type = Instantiation (Ref ( "builtins.set" ) , (union ([]),))
1290
+ return_type = Instantiation (SET , (union ([]),))
1281
1291
return overload (
1282
1292
[
1283
1293
FunctionType (
1284
1294
params = typed_dict ([]),
1285
1295
return_type = return_type ,
1286
- side_effect = SideEffect (new = True , points_to_args = True , name = "{}" ),
1296
+ side_effect = SideEffect (new = True , points_to_args = True ),
1287
1297
is_property = False ,
1288
1298
type_params = (),
1289
1299
)
@@ -1293,13 +1303,13 @@ def make_set_constructor() -> Overloaded:
1293
1303
1294
1304
def make_tuple_constructor () -> Overloaded :
1295
1305
args = TypeVar ("Args" , is_args = True )
1296
- return_type = Instantiation (Ref ( "builtins.tuple" ) , (args ,))
1306
+ return_type = Instantiation (TUPLE , (args ,))
1297
1307
return overload (
1298
1308
[
1299
1309
FunctionType (
1300
1310
params = typed_dict ([make_row (0 , "args" , args )]),
1301
1311
return_type = return_type ,
1302
- side_effect = SideEffect (new = True , points_to_args = True , name = "()" ),
1312
+ side_effect = SideEffect (new = True , points_to_args = True ),
1303
1313
is_property = False ,
1304
1314
type_params = (args ,),
1305
1315
)
@@ -1309,8 +1319,6 @@ def make_tuple_constructor() -> Overloaded:
1309
1319
1310
1320
def make_slice_constructor () -> Overloaded :
1311
1321
return_type = Ref ("builtins.slice" )
1312
- NONE = literal (None )
1313
- INT = Ref ("builtins.int" )
1314
1322
both = union ([NONE , INT ])
1315
1323
return overload (
1316
1324
[
@@ -1319,7 +1327,7 @@ def make_slice_constructor() -> Overloaded:
1319
1327
[make_row (0 , "start" , both ), make_row (1 , "end" , both )]
1320
1328
),
1321
1329
return_type = return_type ,
1322
- side_effect = SideEffect (new = True , name = "[:]" ),
1330
+ side_effect = SideEffect (new = True ),
1323
1331
is_property = False ,
1324
1332
type_params = (),
1325
1333
)
@@ -1619,8 +1627,9 @@ def visit_Name(self, name) -> TypeExpr:
1619
1627
return self .symtable .lookup (name .id )
1620
1628
1621
1629
def visit_Starred (self , starred : ast .Starred ) -> TypeExpr :
1622
- assert isinstance (starred .value , ast .Name ), f"{ starred !r} "
1623
- return TypeVar (starred .value .id , is_args = True )
1630
+ if isinstance (starred .value , ast .Name ):
1631
+ return TypeVar (starred .value .id , is_args = True )
1632
+ return Star ((self .to_type (starred .value ),))
1624
1633
1625
1634
def visit_Subscript (self , subscr : ast .Subscript ) -> TypeExpr :
1626
1635
generic = self .to_type (subscr .value )
@@ -1677,7 +1686,7 @@ def is_immutable(value: TypeExpr) -> bool:
1677
1686
case Row (type = value ):
1678
1687
return is_immutable (value )
1679
1688
case FunctionType () as f :
1680
- if f .side_effect .update is not None :
1689
+ if f .side_effect .update [ 0 ] is not None :
1681
1690
return False
1682
1691
return True
1683
1692
case Ref (name ):
@@ -1805,21 +1814,21 @@ def visit_FunctionDef(self, fdef: ast.FunctionDef) -> FunctionType:
1805
1814
update = call_decorators .get ("update" )
1806
1815
if update is not None :
1807
1816
assert isinstance (update , ast .Call )
1808
- assert len (update .args ) == 1
1809
1817
update_arg = update .args [0 ]
1810
1818
if isinstance (update_arg , ast .Constant ) and isinstance (
1811
1819
update_arg .value , str
1812
1820
):
1813
1821
update_arg = ast .parse (update_arg .s ).body [0 ].value
1814
1822
update_type = self .expr_to_type (update_arg )
1823
+ update_args = tuple (self .expr_to_type (x ) for x in update .args [1 :])
1815
1824
else :
1816
1825
update_type = None
1826
+ update_args = ()
1817
1827
# side_effect = parse_side_effect(fdef.body)
1818
1828
side_effect = SideEffect (
1819
1829
new = "new" in name_decorators and not is_immutable (returns ),
1820
- update = update_type ,
1830
+ update = ( update_type , update_args ) ,
1821
1831
points_to_args = "points_to_args" in name_decorators ,
1822
- name = fdef .name ,
1823
1832
)
1824
1833
is_property = "property" in name_decorators
1825
1834
@@ -1946,13 +1955,14 @@ def is_bound_method(t: TypeExpr) -> bool:
1946
1955
1947
1956
1948
1957
def get_side_effect (applied : Overloaded ) -> SideEffect :
1949
- [name ] = {x .side_effect .name for x in applied .items }
1950
1958
return SideEffect (
1951
1959
new = any (x .side_effect .new for x in applied .items ),
1952
- update = join_all (x .side_effect .update for x in applied .items ),
1960
+ update = (
1961
+ join_all (x .side_effect .update [0 ] for x in applied .items ),
1962
+ applied .items [0 ].side_effect .update [1 ],
1963
+ ),
1953
1964
bound_method = any (is_bound_method (x ) for x in applied .items ),
1954
1965
points_to_args = any (x .side_effect .points_to_args for x in applied .items ),
1955
- name = name ,
1956
1966
)
1957
1967
1958
1968
0 commit comments