Skip to content

Commit cc4ef4d

Browse files
authored
Control flow (apache#13)
2 parents d6f9b69 + a27558a commit cc4ef4d

30 files changed

+1351
-387
lines changed

frontend/bytecode_analysis.py

Lines changed: 164 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import dataclasses
22
import dis
33
import sys
4+
import functools
45
from typing import Union, List
6+
from collections import deque
57
from .instruction import Instruction
68

79
TERMINAL_OPCODES = {
@@ -17,6 +19,10 @@
1719
TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
1820
JUMP_OPCODES = set(dis.hasjrel + dis.hasjabs)
1921
JUMP_OPNAMES = {dis.opname[opcode] for opcode in JUMP_OPCODES}
22+
MUST_JUMP_OPCODES = {
23+
dis.opmap["JUMP_FORWARD"],
24+
dis.opmap["JUMP_ABSOLUTE"],
25+
}
2026
HASLOCAL = set(dis.haslocal)
2127
HASFREE = set(dis.hasfree)
2228

@@ -43,39 +49,80 @@ class ReadsWrites:
4349
def livevars_analysis(instructions: List[Instruction],
                      instruction: Instruction) -> set[str]:
    """Return the variable names that are live immediately before `instruction`.

    Runs a backward may-liveness dataflow analysis with one CFG node per
    instruction: ``live(pc) = (union of live(succ) - kill(pc)) | gen(pc)``.

    Args:
        instructions: the full instruction list of the code object.
        instruction: the instruction at which liveness is queried; it must be
            an element of `instructions` (it is looked up via `get_indexof`).

    Returns:
        A fresh mutable set of local / free variable names.

    Raises:
        KeyError: if `instruction` is never visited, i.e. no instruction in
            TERMINAL_OPCODES is reachable from it through the edges below.
        NotImplementedError: for a HASLOCAL/HASFREE opcode that is neither a
            LOAD, DELETE, STORE nor MAKE_CELL.
    """
    indexof = get_indexof(instructions)

    # Control-flow edges between instruction indices.  A non-terminal
    # instruction falls through to pc + 1; every jump adds an edge to its
    # target.  prev always gets an entry for pc + 1 so that jump targets can
    # append to it unconditionally.
    prev: dict[int, list[int]] = {0: []}
    succ: dict[int, list[int]] = {}
    for i, inst in enumerate(instructions):
        if inst.opcode not in TERMINAL_OPCODES:
            prev[i + 1] = [i]
            succ[i] = [i + 1]
        else:
            prev[i + 1] = []
            succ[i] = []
    for i, inst in enumerate(instructions):
        if inst.opcode in JUMP_OPCODES:
            assert inst.target is not None
            target_pc = indexof[inst.target]
            prev[target_pc].append(i)
            succ[i].append(target_pc)

    def gen_kill(inst: Instruction) -> tuple[frozenset[str], frozenset[str]]:
        """Return (names read by `inst`, names overwritten by `inst`)."""
        gen: set[str] = set()
        kill: set[str] = set()
        if inst.opcode in HASLOCAL or inst.opcode in HASFREE:
            # DELETE_* needs the variable to exist, so it counts as a read.
            if "LOAD" in inst.opname or "DELETE" in inst.opname:
                assert isinstance(inst.argval, str)
                gen.add(inst.argval)
            elif "STORE" in inst.opname:
                assert isinstance(inst.argval, str)
                kill.add(inst.argval)
            elif inst.opname == "MAKE_CELL":
                pass
            else:
                raise NotImplementedError(f"unhandled {inst.opname}")
        return frozenset(gen), frozenset(kill)

    live_vars: dict[int, frozenset[str]] = {}
    start_pc = indexof[instruction]

    # Worklist seeded with the terminal instructions; information then flows
    # backwards along prev edges until a fixed point is reached.
    to_visit = deque(pc for pc, inst in enumerate(instructions)
                     if inst.opcode in TERMINAL_OPCODES)
    in_progress: set[int] = set(to_visit)

    while to_visit:
        pc = to_visit.popleft()
        in_progress.remove(pc)
        before = live_vars.get(pc)  # None on the first visit

        # Join (set union) over the successors that already have a result.
        incoming: frozenset[str] = frozenset().union(
            *(live_vars[succ_pc] for succ_pc in succ[pc]
              if succ_pc in live_vars))

        gen, kill = gen_kill(instructions[pc])
        out = (incoming - kill) | gen
        live_vars[pc] = out

        # Compare the sets themselves: comparing only their hashes could
        # treat two different sets as equal on a collision and stop
        # propagating too early.
        if out != before:
            for prev_pc in prev[pc]:
                if prev_pc not in in_progress:
                    to_visit.append(prev_pc)
                    in_progress.add(prev_pc)
    return set(live_vars[start_pc])
79126

80127

81128
stack_effect = dis.stack_effect
@@ -145,3 +192,88 @@ def stacksize_analysis(instructions: List[Instruction]) -> int:
145192
assert low >= 0
146193
assert isinstance(high, int) # not infinity
147194
return high
195+
196+
197+
def end_of_control_flow(instructions: List[Instruction], start_pc: int) -> int:
    """
    Find the end of the control flow block starting at the given instruction.

    A jump target `end_pc` is a candidate end when no RETURN_VALUE can be
    reached from `start_pc` without passing through `end_pc`.  Among the
    candidates, the one closest to `start_pc` (fewest control-flow steps) is
    returned; ties are broken by the smallest pc.

    Args:
        instructions: the full instruction list of the code object.
        start_pc: index of the jump instruction that opens the block; leading
            EXTENDED_ARG instructions are skipped first.

    Returns:
        The index of the instruction ending the block, or -1 when no
        reachable candidate exists.

    Raises:
        NotImplementedError: if a visited instruction has a jump target but
            its opcode is not one of the jump kinds handled below.
    """
    while instructions[start_pc].opname == 'EXTENDED_ARG':
        start_pc += 1
    assert instructions[start_pc].opcode in JUMP_OPCODES
    assert instructions[start_pc].target is not None
    indexof = get_indexof(instructions)
    num_insts = len(instructions)
    jump_only_opcodes = {
        dis.opmap[opname] for opname in ('JUMP_FORWARD', 'JUMP_ABSOLUTE')
    }
    jump_or_next_opcodes = {
        dis.opmap[opname]
        for opname in ('POP_JUMP_IF_TRUE', 'POP_JUMP_IF_FALSE',
                       'JUMP_IF_NOT_EXC_MATCH', 'JUMP_IF_TRUE_OR_POP',
                       'JUMP_IF_FALSE_OR_POP', 'FOR_ITER')
    }
    return_value_opcode = dis.opmap['RETURN_VALUE']

    def successors(pc: int) -> list[int]:
        """Return the in-range instruction indices control may reach from pc."""
        inst = instructions[pc]
        if inst.target is None:
            targets = [pc + 1]
        elif inst.opcode in jump_only_opcodes:
            targets = [indexof[inst.target]]
        elif inst.opcode in jump_or_next_opcodes:
            targets = [indexof[inst.target], pc + 1]
        else:
            raise NotImplementedError(f"unhandled {inst.opname}")
        # A fall-through past the last instruction has nowhere to go; drop it
        # instead of indexing out of range.
        return [target for target in targets if target < num_insts]

    def returns_without_passing(end_pc: int) -> bool:
        """True if a RETURN_VALUE is reachable from start_pc avoiding end_pc."""
        visited = {start_pc}
        queue = deque([start_pc])
        while queue:
            for target in successors(queue.popleft()):
                # Checked before the end_pc test on purpose: an end_pc that
                # is itself a RETURN_VALUE is not a valid block end.
                if instructions[target].opcode == return_value_opcode:
                    return True
                if target in visited or target == end_pc:
                    continue
                visited.add(target)
                queue.append(target)
        return False

    possible_end_pcs = [
        end_pc for end_pc, inst in enumerate(instructions)
        if end_pc != start_pc and inst.is_jump_target
        and not returns_without_passing(end_pc)
    ]

    # Breadth-first distances from start_pc; paths stop at RETURN_VALUE.
    dist: dict[int, int] = {start_pc: 0}
    queue = deque([start_pc])
    while queue:
        pc = queue.popleft()
        if instructions[pc].opcode == return_value_opcode:
            continue
        for target in successors(pc):
            if target in dist:
                continue
            dist[target] = dist[pc] + 1
            queue.append(target)

    # Candidates that are not reachable from start_pc have no distance and
    # cannot be the end of this block.
    reachable_ends = [pc for pc in possible_end_pcs if pc in dist]
    if not reachable_ends:
        return -1
    return min(reachable_ends, key=lambda pc: (dist[pc], pc))

frontend/c_api.pyi

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,12 @@ def get_code_map(frame: FrameType) -> 'ProcessedCode':
7272

7373

7474
def is_bound_method(obj: Any, name: str) -> bool:
75+
pass
76+
77+
78+
# Type stub for the C function registered under this name in the extension's
# method table.  The C implementation returns the tuple
# (index, start, step, len) read from a range iterator and raises TypeError
# when `obj` is not a range iterator.
def parse_rangeiterobject(obj: Any) -> Tuple[int, int, int, int]:
79+
pass
80+
81+
82+
# Type stub for the C function registered under this name in the extension's
# method table.
# NOTE(review): the C implementation parses four longs (format "llll":
# index, start, step, len), so this three-parameter (start, stop, step)
# signature does not appear to match it -- confirm and align stub and callers.
def make_rangeiterobject(start: int, stop: int, step: int) -> Any:
7583
pass

frontend/code.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ def get_inst(self, lasti: int) -> Instruction:
145145
def get_pc_by_inst(self, inst: Instruction) -> int:
146146
return self.guarded_pc[inst]
147147

148+
# Return True when the entry stored for `guard_pc` in
# self.pc_guarded_to_origin equals `original_pc`.  The lookup is a plain
# subscript, so a `guard_pc` missing from that mapping raises.
def is_match(self, original_pc: int, guard_pc: int) -> bool:
149+
return self.pc_guarded_to_origin[guard_pc] == original_pc
150+
148151
def get_dependence_of_stack_var(self, original_inst: Instruction,
149152
stack_depth: int) -> list[Instruction]:
150153
raise NotImplementedError

frontend/compile.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,6 @@ def reset() -> None:
8181
from . import utils
8282
utils.reset()
8383
from . import fx_graph
84-
fx_graph.reset()
84+
fx_graph.reset()
85+
from . import dynamic
86+
dynamic.reset()

frontend/config.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
11
from typing import Callable, Any, Union
22

3-
backend: Union[str, Callable[..., Any]] = "inductor"
3+
CONFIG = {
4+
"backend": "inductor", # Union[str, Callable[..., Any]]
5+
}
6+
7+
8+
# Store `value` under `key` in the module-level CONFIG dict.  Keys are not
# validated: an unknown key is simply added.
def set_config(key: str, value: Any) -> None:
9+
CONFIG[key] = value
10+
11+
12+
# Return the value stored under `key` in the module-level CONFIG dict.
# Raises KeyError when `key` has never been set.
def get_config(key: str) -> Any:
13+
return CONFIG[key]

frontend/control_flow.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import dataclasses
2+
from typing import Any, Optional
3+
import torch
4+
from .store_pos import StorePos
5+
6+
7+
# Record of three lists of (str, StorePos) pairs describing a loop.
# NOTE(review): presumably each pair is (variable name, where it is stored),
# split into values only read by the loop, values both read and written, and
# values only written -- confirm against the code that builds this object.
@dataclasses.dataclass
8+
class LoopPosMap:
9+
input_only_pos: list[tuple[str, StorePos]]
10+
joint_pos: list[tuple[str, StorePos]]
11+
output_only_pos: list[tuple[str, StorePos]]
12+
13+
14+
# torch.nn.Module that runs a traced loop body a fixed number of times.
#   body: graph module called once per iteration.
#   num_read_only_param: how many leading `forward` arguments are passed to
#       every iteration unchanged.
#   num_iter: number of times `body` is called.
class LoopModule(torch.nn.Module): #type: ignore
15+
body: torch.fx.GraphModule
16+
num_read_only_param: int
17+
num_iter: int
18+
19+
def __init__(self, body: torch.fx.GraphModule, num_read_only_param: int,
20+
num_iter: int):
21+
super(LoopModule, self).__init__()
22+
self.body = body
23+
self.num_read_only_param = num_read_only_param
24+
self.num_iter = num_iter
25+
26+
# def forward(self, num_iter: Optional[int], cond: torch.Tensor, *values:
27+
# Any) -> Any:
28+
# Split `values` into the read-only prefix and the loop-carried rest, then
# call body(iter_num, *read_only, *loop_carry) for iter_num = 0..num_iter-1,
# rebinding loop_carry to each result.  Returns the last result, or the
# untouched tuple slice when num_iter is 0.
# NOTE(review): assumes `body` returns an iterable of the carried values,
# since the result is star-unpacked on the next iteration -- confirm.
def forward(self, *values: Any) -> Any:
29+
iter_num = 0
30+
# assert cond.dtype == torch.bool
31+
read_only = values[:self.num_read_only_param]
32+
loop_carry = values[self.num_read_only_param:]
33+
while iter_num < self.num_iter:
34+
# and cond.item():
35+
loop_carry = self.body(iter_num, *read_only, *loop_carry)
36+
# cond, *loop_carry = self.body(iter_num, cond, *read_only,
37+
# *loop_carry)
38+
iter_num += 1
39+
return loop_carry
40+
41+
42+
# Base record for a control-flow region, identified by two instruction
# indices.
# NOTE(review): whether end_pc is inclusive or exclusive is not visible
# here -- confirm against the code that constructs this object.
class ControlFlowInfo:
43+
start_pc: int
44+
end_pc: int
45+
46+
# Store both indices unchanged on the instance.
def __init__(self, start_pc: int, end_pc: int) -> None:
47+
self.start_pc = start_pc
48+
self.end_pc = end_pc
49+
50+
51+
class ForLoopInfo(ControlFlowInfo):
    """Control-flow record for a `for` loop.

    Attributes are documented with comments below; `pos_map` and
    `inner_graph` start as None and are expected to be filled in later by
    the code that processes the loop.
    """
    # Iteration count supplied by the caller.
    num_iter: int
    # Iteration counter; starts at 0.
    cur_iter: int
    # Position map of the loop's values, or None until it is assigned.
    pos_map: Optional[LoopPosMap]
    # Graph for the loop body, or None until it is assigned.
    inner_graph: Optional[torch.fx.Graph]

    def __init__(self, start_pc: int, end_pc: int, num_iter: int) -> None:
        """Record the loop's pc range and iteration count.

        Args:
            start_pc: forwarded to ControlFlowInfo.
            end_pc: forwarded to ControlFlowInfo.
            num_iter: iteration count stored on the instance.
        """
        super().__init__(start_pc, end_pc)
        self.num_iter = num_iter
        self.cur_iter = 0
        # Both attributes are annotated Optional; give them their None
        # default so reading them before they are populated yields None
        # instead of raising AttributeError.
        self.pos_map = None
        self.inner_graph = None

frontend/csrc/csrc.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,7 @@ struct StackEffect {
2727
bool local_effect, global_effect;
2828
};
2929
StackEffect stack_effect(int opcode, int oparg, int jump);
30+
PyObject *parse_rangeiterobject(PyObject *self, PyObject *args);
31+
PyObject *make_rangeiterobject(PyObject *self, PyObject *args);
3032

3133
} // namespace frontend_csrc

frontend/csrc/frame_evaluation.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,10 @@ static PyMethodDef _methods[] = {
491491
METH_VARARGS, NULL},
492492
{"get_code_map", get_code_map, METH_VARARGS, NULL},
493493
{"is_bound_method", is_bound_method, METH_VARARGS, NULL},
494+
{"parse_rangeiterobject", frontend_csrc::parse_rangeiterobject,
495+
METH_VARARGS, NULL},
496+
{"make_rangeiterobject", frontend_csrc::make_rangeiterobject, METH_VARARGS,
497+
NULL},
494498
{NULL, NULL, 0, NULL}};
495499

496500
static struct PyModuleDef _module = {

frontend/csrc/parse_types.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#include <Python.h>
2+
#include <object.h>
3+
4+
namespace frontend_csrc {

// Local redeclaration of the range-iterator object layout so its fields can
// be read and written directly.
// NOTE(review): presumably mirrors CPython's private `rangeiterobject`
// struct for the interpreter version this extension targets -- confirm the
// field order whenever the supported Python version changes.
typedef struct {
  PyObject_HEAD
  long index;
  long start;
  long step;
  long len;
} rangeiterobject;

// Python-callable: parse_rangeiterobject(obj) -> (index, start, step, len).
// Returns a new reference to a 4-tuple of ints, or NULL with an exception
// set (TypeError when `obj` is not exactly a range iterator).
PyObject *parse_rangeiterobject(PyObject *self, PyObject *args) {
  PyObject *obj;
  if (!PyArg_ParseTuple(args, "O", &obj)) {
    return NULL;
  }
  if (Py_TYPE(obj) != &PyRangeIter_Type) {
    PyErr_SetString(PyExc_TypeError, "Expected rangeiterobject");
    return NULL;
  }
  rangeiterobject *robj = (rangeiterobject *)obj;
  // Py_BuildValue creates the ints and the tuple in one step and owns the
  // intermediate objects.  PyTuple_Pack does not steal references, so
  // passing it fresh PyLong_FromLong() results leaked four ints per call.
  return Py_BuildValue("(llll)", robj->index, robj->start, robj->step,
                       robj->len);
}

// Python-callable: make_rangeiterobject(index, start, step, len) -> iterator.
// Allocates a range iterator and fills its fields from the four integer
// arguments.  Returns a new reference, or NULL with an exception set.
PyObject *make_rangeiterobject(PyObject *self, PyObject *args) {
  long index, start, step, len;
  if (!PyArg_ParseTuple(args, "llll", &index, &start, &step, &len)) {
    return NULL;
  }
  rangeiterobject *robj = PyObject_New(rangeiterobject, &PyRangeIter_Type);
  if (robj == NULL) {
    // Allocation failed; PyObject_New has already set MemoryError.
    return NULL;
  }
  robj->index = index;
  robj->start = start;
  robj->step = step;
  robj->len = len;
  return (PyObject *)robj;
}

} // namespace frontend_csrc

0 commit comments

Comments
 (0)