Skip to content

Commit 2ce3c8f

Browse files
committed
use update(T, arg1, arg2, ...) for points-to effect
also update(T, *arg) for copy points to Signed-off-by: Elazar Gershuni <elazarg@gmail.com>
1 parent d4d3c05 commit 2ce3c8f

File tree

4 files changed

+106
-76
lines changed

4 files changed

+106
-76
lines changed

pythia/dom_typed_pointer.py

+35-10
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,10 @@ def expr(
682682
func_type = predefined(var)
683683
else:
684684
assert False, f"Expected Var or PredefinedFunction, got {var}"
685-
if isinstance(
686-
func_type, ts.Instantiation
687-
) and func_type.generic == ts.Ref("builtins.type"):
685+
if (
686+
isinstance(func_type, ts.Instantiation)
687+
and func_type.generic == ts.TYPE
688+
):
688689
func_type = ts.get_init_func(func_type)
689690
assert isinstance(
690691
func_type, ts.Overloaded
@@ -703,7 +704,7 @@ def expr(
703704

704705
side_effect = ts.get_side_effect(applied)
705706
dirty = make_dirty()
706-
if side_effect.update is not None:
707+
if side_effect.update[0] is not None:
707708
func_obj = pythia.dom_concrete.Set[Object].squeeze(func_objects)
708709
if isinstance(func_obj, pythia.dom_concrete.Set):
709710
raise RuntimeError(
@@ -732,15 +733,39 @@ def expr(
732733
]
733734
# Expected two objects: self argument and locals
734735

735-
if new_tp.types[self_obj] != side_effect.update:
736+
if True or new_tp.types[self_obj] != side_effect.update[0]:
736737
if monomorophized:
737738
raise RuntimeError(
738739
f"Update with aliased objects: {aliasing_pointers} (not: {func_obj, LOCALS})"
739740
)
740-
new_tp.types[self_obj] = side_effect.update
741-
if side_effect.name == "append":
742-
x = arg_objects[0]
743-
new_tp.pointers.update(self_obj, tac.Var("*"), x)
741+
new_tp.types[self_obj] = side_effect.update[0]
742+
arg_indices_to_point = side_effect.update[1]
743+
if arg_indices_to_point:
744+
for i in arg_indices_to_point:
745+
starred = False
746+
if isinstance(i, ts.Star):
747+
assert len(i.items) == 1
748+
i = i.items[0]
749+
starred = True
750+
751+
if isinstance(i, ts.Literal) and isinstance(
752+
i.value, int
753+
):
754+
# TODO: minus one only for self. Should be fixed on binding
755+
v = i.value - 1
756+
assert v < len(
757+
arg_objects
758+
), f"{v} >= {len(arg_objects)}"
759+
targets = arg_objects[v]
760+
if starred:
761+
targets = prev_tp.pointers[
762+
targets, tac.Var("*")
763+
]
764+
new_tp.pointers.update(
765+
self_obj, tac.Var("*"), targets
766+
)
767+
else:
768+
assert False, i
744769

745770
t = ts.get_return(applied)
746771
assert t != ts.BOTTOM, f"Expected non-bottom return type for {locals()}"
@@ -792,7 +817,7 @@ def expr(
792817
assert isinstance(applied, ts.Overloaded)
793818
side_effect = ts.get_side_effect(applied)
794819
dirty = make_dirty()
795-
if side_effect.update is not None:
820+
if side_effect.update[0] is not None:
796821
dirty = make_dirty_from_keys(
797822
value_objects, pythia.dom_concrete.Set[tac.Var].top()
798823
)

pythia/type_system.py

+56-46
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ def __repr__(self) -> str:
2020
return self.name
2121

2222

23+
ANY = Ref("typing.Any")
24+
LIST = Ref("builtins.list")
25+
TUPLE = Ref("builtins.tuple")
26+
SET = Ref("builtins.set")
27+
TYPE = Ref("builtins.type")
28+
NONE_TYPE = Ref("builtins.NoneType")
29+
BOOL = Ref("builtins.bool")
30+
INT = Ref("builtins.int")
31+
FLOAT = Ref("builtins.float")
32+
STR = Ref("builtins.str")
33+
34+
2335
@dataclass(frozen=True, slots=True)
2436
class TypeVar:
2537
name: str
@@ -132,25 +144,28 @@ def literal(value: int | str | bool | float | tuple | list | None) -> Literal:
132144
case value if value is NULL:
133145
ref = Ref("builtins.ellipsis")
134146
case int():
135-
ref = Ref("builtins.int")
147+
ref = INT
136148
case float():
137-
ref = Ref("builtins.float")
149+
ref = FLOAT
138150
case str():
139-
ref = Ref("builtins.str")
151+
ref = STR
140152
case bool():
141-
ref = Ref("builtins.bool")
153+
ref = BOOL
142154
case None:
143-
ref = Ref("builtins.NoneType")
155+
ref = NONE_TYPE
144156
case tuple():
145-
ref = Ref("builtins.tuple")
157+
ref = TUPLE
146158
case list():
147159
value = tuple(value)
148-
ref = Ref("builtins.list")
160+
ref = LIST
149161
case _:
150162
assert False, f"Unknown literal type {value!r}"
151163
return Literal(value, ref)
152164

153165

166+
NONE = literal(None)
167+
168+
154169
@dataclass(frozen=True, slots=True)
155170
class TypedDict:
156171
items: frozenset[Row]
@@ -227,9 +242,8 @@ def __repr__(self) -> str:
227242
class SideEffect:
228243
new: bool
229244
bound_method: bool = False
230-
update: typing.Optional[TypeExpr] = None
245+
update: tuple[typing.Optional[TypeExpr], tuple[int, ...]] = (None, ())
231246
points_to_args: bool = False
232-
name: typing.Optional[str] = None # ad hoc effects
233247

234248

235249
@dataclass(frozen=True, slots=True)
@@ -248,7 +262,7 @@ def __repr__(self) -> str:
248262
new = "new " if self.new() else ""
249263
update = (
250264
"{update " + str(self.side_effect.update) + "}@"
251-
if self.side_effect.update
265+
if self.side_effect.update[0]
252266
else ""
253267
)
254268
return f"[{type_params}]({self.params} -> {update}{new}{self.return_type})"
@@ -340,9 +354,9 @@ def bind_typevars(t: TypeExpr, context: dict[TypeVar, TypeExpr]) -> TypeExpr:
340354
case Row() as row:
341355
return replace(row, type=bind_typevars(row.type, context))
342356
case SideEffect() as s:
343-
if s.update is None:
357+
if s.update[0] is None:
344358
return s
345-
return replace(s, update=bind_typevars(s.update, context))
359+
return replace(s, update=(bind_typevars(s.update[0], context), s.update[1]))
346360
case Class(
347361
type_params=type_params, class_dict=class_dict, inherits=inherits
348362
) as klass:
@@ -372,9 +386,9 @@ def bind_typevars(t: TypeExpr, context: dict[TypeVar, TypeExpr]) -> TypeExpr:
372386
return choices.items[actual_arg.value]
373387
return Access(choices, actual_arg)
374388
case SideEffect() as s:
375-
if s.update is None:
389+
if s.update[0] is None:
376390
return s
377-
return replace(s, update=bind_typevars(s.update, context))
391+
return replace(s, update=(bind_typevars(s.update[0], context), s.update[1]))
378392
raise NotImplementedError(f"{t!r}, {type(t)}")
379393

380394

@@ -414,7 +428,6 @@ def union(items: typing.Iterable[TypeExpr], squeeze=True) -> TypeExpr:
414428

415429
TOP = typed_dict([])
416430
BOTTOM = Union(frozenset())
417-
ANY = Ref("typing.Any")
418431

419432

420433
def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
@@ -461,7 +474,7 @@ def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
461474
tuple(join(t1, t2) for t1, t2 in zip(l1.value, l2.value)),
462475
l1.ref,
463476
)
464-
if l1.ref == Ref("builtins.list"):
477+
if l1.ref == LIST:
465478
return Instantiation(
466479
l1.ref, (join_all([*l1.value, *l2.value]),)
467480
)
@@ -477,7 +490,7 @@ def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
477490
Instantiation() as inst,
478491
Literal(tuple() as value, ref=ref),
479492
) if ref == inst.generic:
480-
if ref.name == "builtins.list":
493+
if ref == LIST:
481494
value = (join_all(value),)
482495
return join(inst, Instantiation(ref, value))
483496
case (TypedDict(items1), TypedDict(items2)): # type: ignore
@@ -497,21 +510,17 @@ def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
497510
if index1 == index2:
498511
return Row(index1, join(t1, t2))
499512
return BOTTOM
500-
case (Row(_, t1), (Instantiation(Ref("builtins.list"), type_args))) | (
501-
Instantiation(Ref("builtins.list"), type_args),
513+
case (Row(_, t1), (Instantiation(Ref("builtings.list"), type_args))) | (
514+
Instantiation(Ref("builtings.list"), type_args),
502515
Row(_, t1),
503516
):
504-
return Instantiation(
505-
Ref("builtins.list"), tuple(join(t1, t) for t in type_args)
506-
)
517+
return Instantiation(LIST, tuple(join(t1, t) for t in type_args))
507518
case (Row(_, t1), (Instantiation(Ref("builtins.tuple"), type_args))) | (
508519
Instantiation(Ref("builtins.tuple"), type_args),
509520
Row(_, t1),
510521
):
511522
# not exact; should only join at the index of the row
512-
return Instantiation(
513-
Ref("builtins.tuple"), tuple(join(t1, t) for t in type_args)
514-
)
523+
return Instantiation(TUPLE, tuple(join(t1, t) for t in type_args))
515524
case Class(), Class():
516525
return TOP
517526
case (Class(name="int") | Ref("builtins.int") as c, Literal(int())) | (
@@ -520,10 +529,11 @@ def join(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
520529
):
521530
return c
522531
case (SideEffect() as s1, SideEffect() as s2):
532+
assert s1.update[1] == s2.update[1]
523533
return SideEffect(
524534
new=s1.new | s2.new,
525535
bound_method=s1.bound_method | s2.bound_method,
526-
update=join(s1.update, s2.update),
536+
update=(join(s1.update[0], s2.update[0]), s1.update[1]),
527537
points_to_args=s1.points_to_args | s2.points_to_args,
528538
)
529539
case x, y:
@@ -556,7 +566,7 @@ def meet(t1: TypeExpr, t2: TypeExpr) -> TypeExpr:
556566
tuple(meet(t1, t2) for t1, t2 in zip(l1.value, l2.value)),
557567
l1.ref,
558568
)
559-
if l1.ref == Ref("builtins.list"):
569+
if l1.ref == LIST:
560570
return Instantiation(
561571
l1.ref, (meet_all([*l1.value, *l2.value]),)
562572
)
@@ -1054,13 +1064,13 @@ def get_init_func(callable: TypeExpr) -> TypeExpr:
10541064
]
10551065
)
10561066
side_effect = SideEffect(
1057-
new=True, bound_method=True, update=None, points_to_args=True
1067+
new=True, bound_method=True, update=(None, ()), points_to_args=True
10581068
)
10591069
res = bind_self(
10601070
overload(
10611071
replace(
10621072
f,
1063-
return_type=f.side_effect.update or selftype,
1073+
return_type=f.side_effect.update[0] or selftype,
10641074
side_effect=side_effect,
10651075
)
10661076
for f in init.items
@@ -1268,7 +1278,7 @@ def make_list_constructor() -> Overloaded:
12681278
FunctionType(
12691279
params=typed_dict([make_row(0, "args", args)]),
12701280
return_type=return_type,
1271-
side_effect=SideEffect(new=True, points_to_args=True, name="[]"),
1281+
side_effect=SideEffect(new=True, points_to_args=True),
12721282
is_property=False,
12731283
type_params=(args,),
12741284
)
@@ -1277,13 +1287,13 @@ def make_list_constructor() -> Overloaded:
12771287

12781288

12791289
def make_set_constructor() -> Overloaded:
1280-
return_type = Instantiation(Ref("builtins.set"), (union([]),))
1290+
return_type = Instantiation(SET, (union([]),))
12811291
return overload(
12821292
[
12831293
FunctionType(
12841294
params=typed_dict([]),
12851295
return_type=return_type,
1286-
side_effect=SideEffect(new=True, points_to_args=True, name="{}"),
1296+
side_effect=SideEffect(new=True, points_to_args=True),
12871297
is_property=False,
12881298
type_params=(),
12891299
)
@@ -1293,13 +1303,13 @@ def make_set_constructor() -> Overloaded:
12931303

12941304
def make_tuple_constructor() -> Overloaded:
12951305
args = TypeVar("Args", is_args=True)
1296-
return_type = Instantiation(Ref("builtins.tuple"), (args,))
1306+
return_type = Instantiation(TUPLE, (args,))
12971307
return overload(
12981308
[
12991309
FunctionType(
13001310
params=typed_dict([make_row(0, "args", args)]),
13011311
return_type=return_type,
1302-
side_effect=SideEffect(new=True, points_to_args=True, name="()"),
1312+
side_effect=SideEffect(new=True, points_to_args=True),
13031313
is_property=False,
13041314
type_params=(args,),
13051315
)
@@ -1309,8 +1319,6 @@ def make_tuple_constructor() -> Overloaded:
13091319

13101320
def make_slice_constructor() -> Overloaded:
13111321
return_type = Ref("builtins.slice")
1312-
NONE = literal(None)
1313-
INT = Ref("builtins.int")
13141322
both = union([NONE, INT])
13151323
return overload(
13161324
[
@@ -1319,7 +1327,7 @@ def make_slice_constructor() -> Overloaded:
13191327
[make_row(0, "start", both), make_row(1, "end", both)]
13201328
),
13211329
return_type=return_type,
1322-
side_effect=SideEffect(new=True, name="[:]"),
1330+
side_effect=SideEffect(new=True),
13231331
is_property=False,
13241332
type_params=(),
13251333
)
@@ -1619,8 +1627,9 @@ def visit_Name(self, name) -> TypeExpr:
16191627
return self.symtable.lookup(name.id)
16201628

16211629
def visit_Starred(self, starred: ast.Starred) -> TypeExpr:
1622-
assert isinstance(starred.value, ast.Name), f"{starred!r}"
1623-
return TypeVar(starred.value.id, is_args=True)
1630+
if isinstance(starred.value, ast.Name):
1631+
return TypeVar(starred.value.id, is_args=True)
1632+
return Star((self.to_type(starred.value),))
16241633

16251634
def visit_Subscript(self, subscr: ast.Subscript) -> TypeExpr:
16261635
generic = self.to_type(subscr.value)
@@ -1677,7 +1686,7 @@ def is_immutable(value: TypeExpr) -> bool:
16771686
case Row(type=value):
16781687
return is_immutable(value)
16791688
case FunctionType() as f:
1680-
if f.side_effect.update is not None:
1689+
if f.side_effect.update[0] is not None:
16811690
return False
16821691
return True
16831692
case Ref(name):
@@ -1805,21 +1814,21 @@ def visit_FunctionDef(self, fdef: ast.FunctionDef) -> FunctionType:
18051814
update = call_decorators.get("update")
18061815
if update is not None:
18071816
assert isinstance(update, ast.Call)
1808-
assert len(update.args) == 1
18091817
update_arg = update.args[0]
18101818
if isinstance(update_arg, ast.Constant) and isinstance(
18111819
update_arg.value, str
18121820
):
18131821
update_arg = ast.parse(update_arg.s).body[0].value
18141822
update_type = self.expr_to_type(update_arg)
1823+
update_args = tuple(self.expr_to_type(x) for x in update.args[1:])
18151824
else:
18161825
update_type = None
1826+
update_args = ()
18171827
# side_effect = parse_side_effect(fdef.body)
18181828
side_effect = SideEffect(
18191829
new="new" in name_decorators and not is_immutable(returns),
1820-
update=update_type,
1830+
update=(update_type, update_args),
18211831
points_to_args="points_to_args" in name_decorators,
1822-
name=fdef.name,
18231832
)
18241833
is_property = "property" in name_decorators
18251834

@@ -1946,13 +1955,14 @@ def is_bound_method(t: TypeExpr) -> bool:
19461955

19471956

19481957
def get_side_effect(applied: Overloaded) -> SideEffect:
1949-
[name] = {x.side_effect.name for x in applied.items}
19501958
return SideEffect(
19511959
new=any(x.side_effect.new for x in applied.items),
1952-
update=join_all(x.side_effect.update for x in applied.items),
1960+
update=(
1961+
join_all(x.side_effect.update[0] for x in applied.items),
1962+
applied.items[0].side_effect.update[1],
1963+
),
19531964
bound_method=any(is_bound_method(x) for x in applied.items),
19541965
points_to_args=any(x.side_effect.points_to_args for x in applied.items),
1955-
name=name,
19561966
)
19571967

19581968

0 commit comments

Comments
 (0)