Skip to content

Commit 7e1a676

Browse files
move tuple-specific check from guard_tracker to vs.tuplevar (apache#6)
2 parents db8f18b + 880de91 commit 7e1a676

File tree

11 files changed

+168
-66
lines changed

11 files changed

+168
-66
lines changed

frontend/guard_tracker.py

Lines changed: 15 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .c_api import get_value_stack_from_top, get_value_stack_size, set_eval_frame, stack_effect
1414
from .instruction import Instruction, ci
1515
from .cache import CachedGraph, get_frame_cache
16-
from .store_pos import StorePos, StoreInStack, StoreInLocal, StoreInGlobal, StoreInAttr, StoreInTuple
16+
from .store_pos import StorePos, StoreInStack, StoreInLocal, StoreInGlobal, StoreInAttr, StoreInIndex
1717
from . import variables as vs
1818
from .utils import is_scalar, new_random_key, has_force_graph_break, NullObject, is_call_bytecode, fx_graph_functions, is_user_defined_func, UnknownTypeError, get_all_objects_in_stack
1919
from .object_table import ObjectTable
@@ -142,7 +142,9 @@ def from_frame(cls, frame: FrameType, read_stack: bool,
142142
state.start_stack_size = get_value_stack_size(frame)
143143
for i in range(state.start_stack_size):
144144
value = get_value_stack_from_top(frame, i)
145-
var = vs.make_var_from_value(value, True, state.fx_graph,
145+
var = vs.make_var_from_value(value, True,
146+
state.objects.read_only,
147+
state.fx_graph,
146148
[StoreInLocal(f"__stack__{i}")])
147149
state.objects.add(var, value)
148150
# state.written may be assigned inside make_var_from_value
@@ -320,30 +322,6 @@ def init_state(self, read_stack: bool = True) -> None:
320322
self.state = State.from_frame(self.frame, read_stack, self.frame_root)
321323
self.have_error = False
322324

323-
def variable_check(self, var: TupleVar,
324-
extract_code_at_start: StorePos) -> None:
325-
for i, sub_obj in enumerate(var.value):
326-
sub_var = vs.make_var_from_value(
327-
sub_obj, True, self.state.fx_graph,
328-
[StoreInTuple(extract_code_at_start, i)])
329-
self.state.add_object(sub_var, sub_obj)
330-
if isinstance(sub_var, TupleVar):
331-
self.variable_check(sub_var,
332-
StoreInTuple(extract_code_at_start, i))
333-
334-
def variable_output(self, var: Variable, name_in_graph_fn: str,
335-
store_pos: StorePos, codegen: "GraphFnCodegen") -> None:
336-
if isinstance(var, TupleVar):
337-
self.tuple_output(var)
338-
var.make_output(name_in_graph_fn, store_pos, codegen)
339-
340-
def tuple_output(self, var: TupleVar) -> None:
341-
for sub_val in var.value:
342-
sub_obj = self.state.objects.get(sub_val, allow_unexist_const=True)
343-
var.objs.append(sub_obj)
344-
if isinstance(sub_obj, TupleVar):
345-
self.tuple_output(sub_obj)
346-
347325
def record(
348326
self, frame: FrameType, frame_id: int
349327
) -> None: # pass frame and frame_id only for assertion
@@ -446,8 +424,7 @@ def commit(self, break_before_cur_inst: bool) -> None:
446424

447425
for i, value in enumerate(stack_objs):
448426
var = self.state.objects.get(value, allow_unexist_const=True)
449-
self.variable_output(var, f"__stack__{i}", StoreInStack(i),
450-
graph_codegen)
427+
var.make_output(f"__stack__{i}", StoreInStack(i), graph_codegen)
451428

452429
self.state.fx_graph.set_output_nodes(graph_codegen.get_graph_outputs())
453430

@@ -625,11 +602,11 @@ def LOAD_CONST(self, _inst: Instruction) -> None:
625602
def LOAD_FAST(self, inst: Instruction) -> None:
626603
if inst.argval not in self.state.stored_locals:
627604
obj = self.frame.f_locals[inst.argval]
628-
var = vs.make_var_from_value(obj, True, self.state.fx_graph,
605+
var = vs.make_var_from_value(obj, True,
606+
self.state.objects.read_only,
607+
self.state.fx_graph,
629608
[StoreInLocal(inst.argval)])
630609
self.state.add_object(var, obj)
631-
if isinstance(var, TupleVar):
632-
self.variable_check(var, StoreInLocal(inst.argval))
633610

634611
def LOAD_GLOBAL(self, inst: Instruction) -> None:
635612
if inst.argval not in self.state.stored_globals:
@@ -642,19 +619,20 @@ def LOAD_GLOBAL(self, inst: Instruction) -> None:
642619
except Exception as e:
643620
raise UnknownTypeError(inst.argval)
644621

645-
var = vs.make_var_from_value(obj, True, self.state.fx_graph,
622+
var = vs.make_var_from_value(obj, True,
623+
self.state.objects.read_only,
624+
self.state.fx_graph,
646625
[StoreInGlobal(inst.argval)])
647626
self.state.add_object(var, obj)
648-
if isinstance(var, TupleVar):
649-
self.variable_check(var, StoreInGlobal(inst.argval))
650627

651628
# heheda: we need to make sure that no unbound LOAD_METHOD is called by python runtime to avoid NULL in stack
652629
def LOAD_METHOD(self, inst: Instruction) -> None:
653630
self_obj = get_value_stack_from_top(self.frame, 0)
654631
method = getattr(self_obj, inst.argval)
655632
self_var = self.state.objects.get(self_obj)
656633
method_var = vs.make_var_from_value(
657-
method, self_var.need_guard_check, self.state.fx_graph, [
634+
method, self_var.need_guard_check, self.state.objects.read_only,
635+
self.state.fx_graph, [
658636
StoreInAttr(self_var.extract_code_at_start[0], self_obj,
659637
inst.argval)
660638
] if self_var.need_guard_check else [])
@@ -668,7 +646,8 @@ def LOAD_ATTR(self, inst: Instruction) -> None:
668646
attr = getattr(obj, inst.argval)
669647
obj_var = self.state.objects.get(obj)
670648
attr_var = vs.make_var_from_value(
671-
attr, obj_var.need_guard_check, self.state.fx_graph,
649+
attr, obj_var.need_guard_check, self.state.objects.read_only,
650+
self.state.fx_graph,
672651
[StoreInAttr(obj_var.extract_code_at_start[0], obj, inst.argval)]
673652
if obj_var.need_guard_check else [])
674653
if isinstance(obj_var, vs.ModuleVar):

frontend/object_table.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
1-
from typing import Any, get_args, Optional, Tuple
1+
from typing import Any, get_args, Optional, Tuple, Generic
22
from .variables.base import Variable
33
from .variables import CONST_TYPES, ScalarVar, make_var_from_value
44
from .variables.tuple_ import TupleVar
5-
from .utils import NullObject
5+
from .utils import NullObject, ReadOnlyObject
6+
from .store_pos import StorePos
7+
from .fx_graph import FxGraph
68

79

810
class ObjectTable:
911
objs: dict[int, Variable] # id -> object
1012
# Python caches small integers, so int variables don't have unique ids
1113
objs_no_id: list[Variable]
14+
read_only: 'ReadOnlyObjectTable'
1215

1316
def __init__(self) -> None:
1417
self.objs = {}
1518
self.objs_no_id = []
19+
self.read_only = ReadOnlyObjectTable(self)
1620

1721
def add(self, var: Variable, value: Any) -> None:
1822
if isinstance(value, bool):
@@ -23,9 +27,11 @@ def add(self, var: Variable, value: Any) -> None:
2327
old_var.need_guard_check |= var.need_guard_check
2428
else:
2529
self.objs[id(value)] = var
30+
var.add_subvars_to_table(self)
2631

2732
def add_by_id(self, var: Variable, idx: int) -> None:
2833
self.objs[idx] = var
34+
var.add_subvars_to_table(self)
2935

3036
def get_all(self) -> list[Variable]:
3137
return list(self.objs.values()) + self.objs_no_id
@@ -36,9 +42,9 @@ def get(self, value: Any, allow_unexist_const: bool = False) -> Variable:
3642
elif id(value) in self.objs:
3743
return self.objs[id(value)]
3844
elif allow_unexist_const and isinstance(value, get_args(CONST_TYPES)):
39-
return make_var_from_value(value, False)
45+
return make_var_from_value(value, False, self.read_only)
4046
elif isinstance(value, tuple):
41-
return TupleVar(value, False)
47+
return TupleVar(value, False, self.read_only)
4248
raise RuntimeError(f"Object {value} not found in object table")
4349

4450
def get_or_none(self, value: Any) -> Optional[Variable]:
@@ -47,6 +53,19 @@ def get_or_none(self, value: Any) -> Optional[Variable]:
4753
else:
4854
return None
4955

56+
def get_or_make_var(self,
57+
value: Any,
58+
need_guard_check: bool,
59+
fx_graph: Optional[FxGraph] = None,
60+
extract_code_at_start: list[StorePos] = []) -> Variable:
61+
if isinstance(value, bool):
62+
return ScalarVar(value, need_guard_check, extract_code_at_start)
63+
elif id(value) in self.objs:
64+
return self.objs[id(value)]
65+
else:
66+
return make_var_from_value(value, need_guard_check, self.read_only,
67+
fx_graph, extract_code_at_start)
68+
5069
def get_by_id(self, idx: int) -> Variable:
5170
return self.objs[idx]
5271

@@ -55,3 +74,36 @@ def contains(self, value: Any) -> bool:
5574

5675
def contains_by_id(self, idx: int) -> bool:
5776
return idx in self.objs
77+
78+
79+
class ReadOnlyObjectTable:
80+
table: ObjectTable
81+
82+
def __init__(self, table: ObjectTable) -> None:
83+
self.table = table
84+
85+
def get_all(self) -> list[Variable]:
86+
return self.table.get_all()
87+
88+
def get(self, value: Any, allow_unexist_const: bool = False) -> Variable:
89+
return self.table.get(value, allow_unexist_const)
90+
91+
def get_or_none(self, value: Any) -> Optional[Variable]:
92+
return self.table.get_or_none(value)
93+
94+
def get_or_make_var(self,
95+
value: Any,
96+
need_guard_check: bool,
97+
fx_graph: Optional[FxGraph] = None,
98+
extract_code_at_start: list[StorePos] = []) -> Variable:
99+
return self.table.get_or_make_var(value, need_guard_check, fx_graph,
100+
extract_code_at_start)
101+
102+
def get_by_id(self, idx: int) -> Variable:
103+
return self.table.get_by_id(idx)
104+
105+
def contains(self, value: Any) -> bool:
106+
return self.table.contains(value)
107+
108+
def contains_by_id(self, idx: int) -> bool:
109+
return self.table.contains_by_id(idx)

frontend/store_pos.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ def __str__(self) -> str:
5050
return f"{self.self_pos}.{self.attr_name}"
5151

5252

53-
class StoreInTuple(StorePos):
53+
class StoreInIndex(StorePos):
5454
self_pos: StorePos
55-
self_idx: int
55+
self_idx: Any
5656

57-
def __init__(self, self_pos: StorePos, self_idx: int) -> None:
57+
def __init__(self, self_pos: StorePos, self_idx: Any) -> None:
5858
self.self_pos = self_pos
5959
self.self_idx = self_idx
6060

frontend/utils.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22
import dis
3-
from typing import Any, TYPE_CHECKING, Callable
3+
from typing import Any, TYPE_CHECKING, Callable, TypeVar, Generic
44
from types import FrameType
55
import random
66
import operator
@@ -184,3 +184,23 @@ def __enter__(self) -> None:
184184
def __exit__(self, *args: Any) -> None:
185185
if self.old_ld_preload:
186186
os.environ['LD_PRELOAD'] = self.old_ld_preload
187+
188+
189+
T = TypeVar('T')
190+
191+
192+
class ReadOnlyObject(Generic[T]):
193+
obj: T
194+
const_attrs: tuple[str, ...]
195+
196+
def __init__(self, obj: T, const_attrs: tuple[str, ...] = ()) -> None:
197+
self.obj = obj
198+
self.const_attrs = const_attrs
199+
200+
def __getattr__(self, attr: str) -> Any:
201+
if attr in self.const_attrs:
202+
return getattr(self.obj, attr)
203+
else:
204+
raise AttributeError(
205+
f"Attribute {attr} should not be called in reader of {self.obj}"
206+
)

frontend/variables/__init__.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Union, Optional, Tuple
1+
from typing import Any, Union, Optional, Tuple, TYPE_CHECKING
22
from types import ModuleType
33
import torch
44
from .base import Variable
@@ -10,6 +10,8 @@
1010
from ..fx_graph import FxGraph
1111
from ..utils import NullObject, UnknownTypeError
1212
from ..store_pos import StorePos
13+
if TYPE_CHECKING:
14+
from ..object_table import ReadOnlyObjectTable
1315

1416
ty2var: dict[type[Any], type[Variable]] = {
1517
float: ScalarVar,
@@ -27,23 +29,22 @@
2729

2830
def make_var_from_value(value: Any,
2931
need_guard_check: bool,
32+
object_table: 'ReadOnlyObjectTable',
3033
fx_graph: Optional[FxGraph] = None,
3134
extract_code_at_start: list[StorePos] = []) -> Variable:
3235
if type(value) in ty2var:
33-
return ty2var[type(value)].from_value(value, need_guard_check, fx_graph,
36+
return ty2var[type(value)].from_value(value, need_guard_check,
37+
object_table, fx_graph,
3438
extract_code_at_start)
3539
elif isinstance(value, torch.nn.Module):
36-
return TorchModuleVar.from_value(value, need_guard_check, fx_graph,
37-
extract_code_at_start)
40+
return TorchModuleVar.from_value(value, need_guard_check, object_table,
41+
fx_graph, extract_code_at_start)
3842
elif isinstance(value, ModuleType):
39-
return ModuleVar.from_value(value, need_guard_check, fx_graph,
40-
extract_code_at_start)
43+
return ModuleVar.from_value(value, need_guard_check, object_table,
44+
fx_graph, extract_code_at_start)
4145
elif callable(value):
42-
return FunctionVar.from_value(value, need_guard_check, fx_graph,
43-
extract_code_at_start)
44-
elif isinstance(value, tuple):
45-
return TupleVar.from_value(value, need_guard_check, fx_graph,
46-
extract_code_at_start)
46+
return FunctionVar.from_value(value, need_guard_check, object_table,
47+
fx_graph, extract_code_at_start)
4748
raise UnknownTypeError(type(value))
4849

4950

frontend/variables/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch.fx
88
from ..pycode_generator import GraphFnCodegen, GuardFnCodegen
99
from ..fx_graph import FxGraph, NodeArgs
10+
from ..object_table import ReadOnlyObjectTable, ObjectTable
1011

1112

1213
@dataclass
@@ -27,6 +28,7 @@ def __init__(self,
2728
def from_value(self,
2829
value: Any,
2930
need_guard_check: bool,
31+
object_table: 'ReadOnlyObjectTable',
3032
fx_graph: Optional[FxGraph] = None,
3133
extract_code_at_start: list[StorePos] = []) -> 'Variable':
3234
raise NotImplementedError
@@ -55,3 +57,6 @@ def make_temp(self, name_in_graph_fn: str, store_pos: StorePos,
5557
@abstractmethod
5658
def as_fx_node(self) -> "NodeArgs":
5759
raise NotImplementedError
60+
61+
def add_subvars_to_table(self, table: 'ObjectTable') -> None:
62+
pass

frontend/variables/const.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ..store_pos import StorePos
1212
if TYPE_CHECKING:
1313
from ..pycode_generator import GraphFnCodegen, GuardFnCodegen
14+
from ..object_table import ReadOnlyObjectTable
1415

1516

1617
class NoneVar(Variable):
@@ -37,6 +38,7 @@ def make_temp(self, name_in_graph_fn: str, store_pos: StorePos,
3738
def from_value(cls,
3839
value: None,
3940
need_guard_check: bool,
41+
_object_table: 'ReadOnlyObjectTable',
4042
_fx_graph: Optional[FxGraph] = None,
4143
extract_code_at_start: list[StorePos] = []) -> "NoneVar":
4244
return cls(need_guard_check, extract_code_at_start)
@@ -69,6 +71,7 @@ def make_temp(self, name_in_graph_fn: str, store_pos: StorePos,
6971
def from_value(cls,
7072
value: NullObject,
7173
need_guard_check: bool,
74+
_object_table: 'ReadOnlyObjectTable',
7275
_fx_graph: Optional[FxGraph] = None,
7376
extract_code_at_start: list[StorePos] = []) -> "NullVar":
7477
return cls(need_guard_check, extract_code_at_start)
@@ -110,6 +113,7 @@ def make_temp(self, name_in_graph_fn: str, store_pos: StorePos,
110113
def from_value(cls,
111114
value: slice,
112115
need_guard_check: bool,
116+
_object_table: 'ReadOnlyObjectTable',
113117
_fx_graph: Optional[FxGraph] = None,
114118
extract_code_at_start: list[StorePos] = []) -> "SliceVar":
115119
return cls(value.start, value.stop, value.step, need_guard_check,
@@ -154,6 +158,7 @@ def make_output(self, name_in_graph_fn: str, store_pos: StorePos,
154158
def from_value(cls,
155159
value: ModuleType,
156160
need_guard_check: bool,
161+
_object_table: 'ReadOnlyObjectTable',
157162
_fx_graph: Optional[FxGraph] = None,
158163
extract_code_at_start: list[StorePos] = []) -> "ModuleVar":
159164
if value in torch_modules:
@@ -189,6 +194,7 @@ def make_output(self, name_in_graph_fn: str, store_pos: StorePos,
189194
def from_value(cls,
190195
value: Callable[..., Any],
191196
need_guard_check: bool,
197+
_object_table: 'ReadOnlyObjectTable',
192198
_fx_graph: Optional[FxGraph] = None,
193199
extract_code_at_start: list[StorePos] = []) -> "FunctionVar":
194200
return cls(value, ObjectSrc.USER_DEFINED, need_guard_check,

frontend/variables/scalar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ..store_pos import StorePos
88
if TYPE_CHECKING:
99
from ..pycode_generator import GraphFnCodegen, GuardFnCodegen
10+
from ..object_table import ReadOnlyObjectTable
1011

1112
ScalarType = Union[int, float, bool, str]
1213

@@ -46,6 +47,7 @@ def make_temp(self, name_in_graph_fn: str, store_pos: StorePos,
4647
def from_value(cls,
4748
value: ScalarType,
4849
need_guard_check: bool,
50+
_object_table: 'ReadOnlyObjectTable',
4951
_fx_graph: Optional[FxGraph] = None,
5052
extract_code_at_start: list[StorePos] = []) -> "ScalarVar":
5153
return cls(value, need_guard_check, extract_code_at_start)

0 commit comments

Comments
 (0)