11import dataclasses
22import dis
33import sys
4+ import functools
45from typing import Union , List
6+ from collections import deque
57from .instruction import Instruction
68
79TERMINAL_OPCODES = {
1719 TERMINAL_OPCODES .add (dis .opmap ["JUMP_ABSOLUTE" ])
1820JUMP_OPCODES = set (dis .hasjrel + dis .hasjabs )
1921JUMP_OPNAMES = {dis .opname [opcode ] for opcode in JUMP_OPCODES }
22+ MUST_JUMP_OPCODES = {
23+ dis .opmap ["JUMP_FORWARD" ],
24+ dis .opmap ["JUMP_ABSOLUTE" ],
25+ }
2026HASLOCAL = set (dis .haslocal )
2127HASFREE = set (dis .hasfree )
2228
@@ -43,39 +49,80 @@ class ReadsWrites:
4349def livevars_analysis (instructions : List [Instruction ],
4450 instruction : Instruction ) -> set [str ]:
4551 indexof = get_indexof (instructions )
46- must = ReadsWrites (set (), set (), set ())
47- may = ReadsWrites (set (), set (), set ())
48-
49- def walk (state : ReadsWrites , start : int ) -> None :
50- if start in state .visited :
51- return
52- state .visited .add (start )
53-
54- for i in range (start , len (instructions )):
55- inst = instructions [i ]
56- if inst .opcode in HASLOCAL or inst .opcode in HASFREE :
57- if "LOAD" in inst .opname or "DELETE" in inst .opname :
58- assert isinstance (inst .argval , str )
59- if inst .argval not in must .writes :
60- state .reads .add (inst .argval )
61- elif "STORE" in inst .opname :
62- assert isinstance (inst .argval , str )
63- state .writes .add (inst .argval )
64- elif inst .opname == "MAKE_CELL" :
65- pass
66- else :
67- raise NotImplementedError (f"unhandled { inst .opname } " )
68- # if inst.exn_tab_entry:
69- # walk(may, indexof[inst.exn_tab_entry.target])
70- if inst .opcode in JUMP_OPCODES :
71- assert inst .target is not None
72- walk (may , indexof [inst .target ])
73- state = may
74- if inst .opcode in TERMINAL_OPCODES :
75- return
7652
77- walk (must , indexof [instruction ])
78- return must .reads | may .reads
53+ prev : dict [int , list [int ]] = {}
54+ succ : dict [int , list [int ]] = {}
55+ prev [0 ] = []
56+ for i , inst in enumerate (instructions ):
57+ if inst .opcode not in TERMINAL_OPCODES :
58+ prev [i + 1 ] = [i ]
59+ succ [i ] = [i + 1 ]
60+ else :
61+ prev [i + 1 ] = []
62+ succ [i ] = []
63+ for i , inst in enumerate (instructions ):
64+ if inst .opcode in JUMP_OPCODES :
65+ assert inst .target is not None
66+ target_pc = indexof [inst .target ]
67+ prev [target_pc ].append (i )
68+ succ [i ].append (target_pc )
69+
70+ live_vars : dict [int , frozenset [str ]] = {}
71+
72+ start_pc = indexof [instruction ]
73+ to_visit = deque ([
74+ pc for pc in range (len (instructions ))
75+ if instructions [pc ].opcode in TERMINAL_OPCODES
76+ ])
77+ in_progress : set [int ] = set (to_visit )
78+
79+ def join_fn (a : frozenset [str ], b : frozenset [str ]) -> frozenset [str ]:
80+ return frozenset (a | b )
81+
82+ def gen_fn (
83+ inst : Instruction ,
84+ incoming : frozenset [str ]) -> tuple [frozenset [str ], frozenset [str ]]:
85+ gen = set ()
86+ kill = set ()
87+ if inst .opcode in HASLOCAL or inst .opcode in HASFREE :
88+ if "LOAD" in inst .opname or "DELETE" in inst .opname :
89+ assert isinstance (inst .argval , str )
90+ gen .add (inst .argval )
91+ elif "STORE" in inst .opname :
92+ assert isinstance (inst .argval , str )
93+ kill .add (inst .argval )
94+ elif inst .opname == "MAKE_CELL" :
95+ pass
96+ else :
97+ raise NotImplementedError (f"unhandled { inst .opname } " )
98+
99+ return frozenset (gen ), frozenset (kill )
100+
101+ while len (to_visit ) > 0 :
102+ pc = to_visit .popleft ()
103+ in_progress .remove (pc )
104+ if pc in live_vars :
105+ before = hash (live_vars [pc ])
106+ else :
107+ before = None
108+ succs = [
109+ live_vars [succ_pc ] for succ_pc in succ [pc ] if succ_pc in live_vars
110+ ]
111+ if len (succs ) > 0 :
112+ incoming = functools .reduce (join_fn , succs )
113+ else :
114+ incoming = frozenset ()
115+
116+ gen , kill = gen_fn (instructions [pc ], incoming )
117+
118+ out = (incoming - kill ) | gen
119+ live_vars [pc ] = out
120+ if hash (out ) != before :
121+ for prev_pc in prev [pc ]:
122+ if prev_pc not in in_progress :
123+ to_visit .append (prev_pc )
124+ in_progress .add (prev_pc )
125+ return set (live_vars [start_pc ])
79126
80127
81128stack_effect = dis .stack_effect
@@ -145,3 +192,88 @@ def stacksize_analysis(instructions: List[Instruction]) -> int:
145192 assert low >= 0
146193 assert isinstance (high , int ) # not infinity
147194 return high
195+
196+
197+ def end_of_control_flow (instructions : List [Instruction ], start_pc : int ) -> int :
198+ """
199+ Find the end of the control flow block starting at the given instruction.
200+ """
201+ while instructions [start_pc ].opname == 'EXTENDED_ARG' :
202+ start_pc += 1
203+ assert instructions [start_pc ].opcode in JUMP_OPCODES
204+ assert instructions [start_pc ].target is not None
205+ indexof = get_indexof (instructions )
206+ jump_only_opnames = ['JUMP_FORWARD' , 'JUMP_ABSOLUTE' ]
207+ jump_or_next_opnames = [
208+ 'POP_JUMP_IF_TRUE' , 'POP_JUMP_IF_FALSE' , 'JUMP_IF_NOT_EXC_MATCH' ,
209+ 'JUMP_IF_TRUE_OR_POP' , 'JUMP_IF_FALSE_OR_POP' , 'FOR_ITER'
210+ ]
211+ jump_only_opcodes = [dis .opmap [opname ] for opname in jump_only_opnames ]
212+ jump_or_next_opcodes = [
213+ dis .opmap [opname ] for opname in jump_or_next_opnames
214+ ]
215+ return_value_opcode = dis .opmap ['RETURN_VALUE' ]
216+ possible_end_pcs = set ()
217+ for end_pc , inst in enumerate (instructions ):
218+ if end_pc == start_pc :
219+ continue
220+ inst = instructions [end_pc ]
221+ if not inst .is_jump_target :
222+ continue
223+ visited = set ()
224+ queue = deque ([start_pc ])
225+ reach_end = False
226+ while queue and not reach_end :
227+ pc = queue .popleft ()
228+ inst = instructions [pc ]
229+ targets : list [int ] = []
230+ if inst .target is not None :
231+ if inst .opcode in jump_only_opcodes :
232+ targets = [indexof [inst .target ]]
233+ elif inst .opcode in jump_or_next_opcodes :
234+ targets = [indexof [inst .target ], pc + 1 ]
235+ else :
236+ raise NotImplementedError (f"unhandled { inst .opname } " )
237+ else :
238+ targets = [pc + 1 ]
239+ for target in targets :
240+ if instructions [target ].opcode == return_value_opcode :
241+ reach_end = True
242+ break
243+ if target in visited :
244+ continue
245+ if target == end_pc :
246+ continue
247+ visited .add (target )
248+ queue .append (target )
249+ if not reach_end :
250+ possible_end_pcs .add (end_pc )
251+ visited = set ()
252+ dist : dict [int , int ] = {start_pc : 0 }
253+ queue = deque ([start_pc ])
254+ while queue :
255+ pc = queue .popleft ()
256+ inst = instructions [pc ]
257+ if inst .opcode == return_value_opcode :
258+ continue
259+ targets = []
260+ if inst .target is not None :
261+ if inst .opcode in jump_only_opcodes :
262+ targets = [indexof [inst .target ]]
263+ elif inst .opcode in jump_or_next_opcodes :
264+ targets = [indexof [inst .target ], pc + 1 ]
265+ else :
266+ raise NotImplementedError (f"unhandled { inst .opname } " )
267+ else :
268+ targets = [pc + 1 ]
269+ for target in targets :
270+ if target in visited :
271+ continue
272+ visited .add (target )
273+ dist [target ] = dist [pc ] + 1
274+ queue .append (target )
275+ min_dist = min ([dist [end_pc ] for end_pc in possible_end_pcs ])
276+ for end_pc in possible_end_pcs :
277+ if dist [end_pc ] == min_dist :
278+ return end_pc
279+ return - 1
0 commit comments