Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

add a basic BuiltinVariable dispatch mechanism #157

Merged
merged 15 commits into from
Jun 15, 2023
Prev Previous commit
Next Next commit
add a basic dispatcher mechanism
  • Loading branch information
SigureMo committed Jun 13, 2023
commit bd7c86bdd6004c86d7cca711796f700aee54323b
74 changes: 70 additions & 4 deletions sot/opcode_translator/executor/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,79 @@
from __future__ import annotations

from typing import Any
from typing import Any, Callable

from sot.utils import Singleton

class Handler:
def __init__(self, *types: type[Any]):
from .variables import DictVariable, VariableBase


class Pattern:
def __init__(self, *types: type[Any], **kwtypes: type[Any]):
self.types = types
self.kwtypes = kwtypes

def match_types(self, *args: *Any) -> bool:
def match_inputs(self, *args: Any, **kwargs: Any) -> bool:
if len(args) != len(self.types):
return False
if any(name not in kwargs for name in self.kwtypes.keys()):
return False
return all(
isinstance(arg, type_) for arg, type_ in zip(args, self.types)
) and all(
isinstance(kwargs[name], type_)
for name, type_ in self.kwtypes.items()
)

def __repr__(self) -> str:
types_repr = ", ".join([type_.__name__ for type_ in self.types])
kwtypes_repr = ", ".join(
[f"{name}={type_.__name__}" for name, type_ in self.kwtypes.items()]
)
return f"Pattern({types_repr}, {kwtypes_repr})"


@Singleton
class Dispatcher:
handlers: dict[Callable[..., Any], list[tuple[Pattern, Callable[..., Any]]]]

def __init__(self):
self.handlers = {}
self.register(
dict.keys,
(DictVariable,),
{},
lambda var: var.override_method_keys(),
)
self.register(
dict.update,
(DictVariable, DictVariable),
{},
lambda var, other: var.override_method_update(other),
)
self.register(
getattr,
(VariableBase, str),
{},
lambda var, name: var.getattr(name),
)

def register(
self,
fn: Callable[..., Any],
types: tuple[type[Any], ...],
kwtypes: dict[str, type[Any]],
handler: Callable[..., Any],
):
if fn not in self.handlers:
self.handlers[fn] = []
self.handlers[fn].append((Pattern(*types, **kwtypes), handler))

def dispatch(
self, fn: Callable[..., Any], *args: Any, **kwargs: Any
) -> Callable[..., Any] | None:
if fn not in self.handlers:
return None
for pattern, handler in self.handlers[fn]:
if pattern.match_inputs(*args, **kwargs):
return handler
return None
26 changes: 21 additions & 5 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@
from .tracker import (
BuiltinTracker,
ConstTracker,
DanglingTracker,
DummyTracker,
GetItemTracker,
GetIterTracker,
GlobalTracker,
LocalTracker,
)
from .variables import (
BuiltinVariable,
CallableVariable,
ConstantVariable,
ContainerVariable,
Expand Down Expand Up @@ -413,7 +415,11 @@ def NOP(self, instr):
def LOAD_ATTR(self, instr):
attr_name = instr.argval
obj = self.pop()
self.push(getattr(obj, attr_name))
self.push(
BuiltinVariable(
getattr, graph=self._graph, tracker=DanglingTracker()
)(obj, attr_name)
)

def LOAD_CONST(self, instr):
var = self._co_consts[instr.arg]
Expand All @@ -435,7 +441,9 @@ def LOAD_GLOBAL(self, instr):
def LOAD_METHOD(self, instr):
method_name = instr.argval
obj = self.pop()
method = getattr(obj, method_name)
method = BuiltinVariable(
getattr, graph=self._graph, tracker=DanglingTracker()
)(obj, method_name)
if isinstance(method, MethodVariable):
# bound method, push the unbound method and the self
self.push(method.fn)
Expand Down Expand Up @@ -958,7 +966,9 @@ def FORMAT_VALUE(self, instr):
def DICT_UPDATE(self, instr):
dict_value = self.pop()
assert instr.argval > 0
self._stack[-instr.arg].update(dict_value)
BuiltinVariable(dict.update, self._graph, tracker=DanglingTracker())(
self._stack[-instr.arg], dict_value
)

def DICT_MERGE(self, instr):
dict_value = self.pop()
Expand All @@ -969,7 +979,9 @@ def DICT_MERGE(self, instr):
raise InnerError(
f"got multiple values for keyword argument '{key}'"
)
self._stack[-instr.arg].update(dict_value)
BuiltinVariable(dict.update, self._graph, tracker=DanglingTracker())(
self._stack[-instr.arg], dict_value
)

def LIST_EXTEND(self, instr):
list_value = self.pop()
Expand Down Expand Up @@ -1292,7 +1304,11 @@ def _inline_call_for_loop(self, iterator, for_iter):
fn, inputs = pycode_gen.gen_for_loop_fn_between(
iterator, self.indexof(for_iter), self.indexof(for_iter.jump_to)
)
fn = UserDefinedFunctionVariable(fn, self._graph, DummyTracker([]))
fn = UserDefinedFunctionVariable(
fn,
self._graph,
DanglingTracker(),
)
input_vars = [self._locals[name] for name in inputs[:-1]] + [iterator]
ret = fn(*input_vars)
for name, val in zip(inputs[:-1], ret[:-1]):
Expand Down
17 changes: 17 additions & 0 deletions sot/opcode_translator/executor/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,23 @@ def __repr__(self) -> str:
return f"DummyTracker(num_inputs={len(self.inputs)})"


class DanglingTracker(Tracker):
def __init__(self):
super().__init__([])

def gen_instructions(self, codegen: PyCodeGen):
raise InnerError("DanglingTracker has no instructions")

def trace_value_from_frame(self):
raise InnerError("DanglingTracker can't trace value from frame")

def is_traceable(self):
return False

def __repr__(self) -> str:
return "DanglingTracker()"


class LocalTracker(Tracker):
def __init__(self, name: str):
super().__init__([])
Expand Down
2 changes: 1 addition & 1 deletion sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def flatten_traceable_inputs(self) -> list[VariableBase]:
def call_function(self, *args, **kwargs):
pass

def __getattr__(self, name: str):
def getattr(self, name: str):
if not hasattr(self.value, name):
raise InnerError(
f"{self.__class__.__name__} {self} has no attribute {name}"
Expand Down
14 changes: 10 additions & 4 deletions sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
from ....utils.exceptions import InnerError
from ..guard import StringifyExpression, union_free_vars
from ..pycode_generator import PyCodeGen
from ..tracker import ConstTracker, DummyTracker, GetAttrTracker, Tracker
from ..tracker import (
ConstTracker,
DanglingTracker,
DummyTracker,
GetAttrTracker,
Tracker,
)
from .base import ConstTypes, VariableBase, VariableFactory

if TYPE_CHECKING:
Expand Down Expand Up @@ -217,7 +223,7 @@ def size(self):
elements = reduce(operator.mul, self.meta.shape, 1)
return ConstantVariable.wrap_literal(elements)

def __getattr__(self, name: str):
def getattr(self, name: str):
if name in ["shape", "dtype", "stop_gradient"]:
return VariableFactory.from_value(
getattr(self.meta, name),
Expand All @@ -228,7 +234,7 @@ def __getattr__(self, name: str):
from .callable import TensorFunctionVariable

fn_var = TensorFunctionVariable(
name, graph=self.graph, tracker=DummyTracker([])
name, graph=self.graph, tracker=DanglingTracker()
)
return fn_var.bind(self, name)
elif name in ["T", "ndim", "size"]:
Expand Down Expand Up @@ -351,7 +357,7 @@ def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):

class DummyVariable(VariableBase):
def __init__(self):
super().__init__(DummyTracker([]))
super().__init__(DanglingTracker())

def reconstruct(self, codegen: PyCodeGen):
codegen.gen_push_null()
15 changes: 11 additions & 4 deletions sot/opcode_translator/executor/variables/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ..guard import StringifyExpression, union_free_vars
from ..tracker import (
ConstTracker,
DanglingTracker,
DummyTracker,
GetAttrTracker,
GetItemTracker,
Expand Down Expand Up @@ -203,11 +204,11 @@ def wrap_method(
# to DummyTracker, and set it to GetAttrTracker after method_var is created.
if instance is None:
instance_var = VariableFactory.from_value(
value.__self__, graph, DummyTracker([])
value.__self__, graph, DanglingTracker()
)
if fn is None:
fn_var = VariableFactory.from_value(
value.__func__, graph, DummyTracker([])
value.__func__, graph, DanglingTracker()
)
assert isinstance(instance_var, VariableBase)
assert isinstance(fn_var, FunctionVariable)
Expand Down Expand Up @@ -308,14 +309,20 @@ def main_info(self) -> dict[str, Any]:
}


class BuiltinVariable(CallableVariable):
class BuiltinVariable(FunctionVariable):
def __init__(
self, fn: Callable[..., Any], graph: FunctionGraph, tracker: Tracker
):
super().__init__(graph, tracker)
from ..dispatcher import Dispatcher

super().__init__(fn, graph, tracker)
self.value = fn
self.dispatcher = Dispatcher()

def call_function(self, *args, **kwargs):
handler = self.dispatcher.dispatch(self.value, *args, **kwargs)
if handler is not None:
return handler(*args, **kwargs)
# TODO(0x45f): For builtin functions, may have 3 different ways to process as below:
# 1. Simulation execution: ensure correct simulation execution and handle trackers with care
# 2. Trigger the paddle api call
Expand Down
19 changes: 15 additions & 4 deletions sot/opcode_translator/executor/variables/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@

from ....utils.exceptions import InnerError, NotImplementException
from ..pycode_generator import PyCodeGen
from ..tracker import ConstTracker, DummyTracker, GetItemTracker, Tracker
from ..tracker import (
ConstTracker,
DanglingTracker,
DummyTracker,
GetItemTracker,
Tracker,
)
from .base import ConstTypes, VariableBase, VariableFactory
from .basic import ConstantVariable

Expand Down Expand Up @@ -325,16 +331,21 @@ def override_method_update(self, data):
self.value.update(data.get_wrapped_items())
return self

def __getattr__(self, name):
from .callable import DirectlyCallFunctionVariable
def getattr(self, name):
from .callable import BuiltinVariable, DirectlyCallFunctionVariable

if name == "keys":
return BuiltinVariable(
dict.keys, self.graph, DanglingTracker()
).bind(self, name)

name_ = "override_method_" + name
if hasattr(self, name_):
method = getattr(self, name_)
return DirectlyCallFunctionVariable(
method.__func__,
graph=self.graph,
tracker=DummyTracker([]),
tracker=DanglingTracker(),
).bind(self, name)
else:
raise NotImplementException(
Expand Down