Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Dispatch] use dispatch replace getitem, setitem, delitem and bool #198

Merged
merged 12 commits into from
Jun 27, 2023
Merged
6 changes: 5 additions & 1 deletion sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,11 @@ def BINARY_SUBSCR(self, instr):
container = self.pop()
assert isinstance(key, VariableBase)
self._graph.add_global_guarded_variable(key)
self.push(container[key.value])
self.push(
BuiltinVariable(operator.getitem, self._graph, DanglingTracker())(
container, key.value
)
)

# inplace operators
# paddle variable do not have inplace operators. For example when call `y **= x`, will call var.__pow__
Expand Down
32 changes: 32 additions & 0 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,44 @@
{},
lambda var: var.bool(),
)
Dispatcher.register(
bool,
("ConstantVariable",),
{},
lambda var: var.bool(),
)
Dispatcher.register(
operator.truth,
("ContainerVariable",),
{},
lambda var: var.bool(),
)
Dispatcher.register(
operator.truth,
("ConstantVariable",),
{},
lambda var: var.bool(),
)

# getitem
Dispatcher.register(
operator.getitem,
(
"VariableBase | TensorVariable | ContainerVariable",
"int | str | TensorVariable | slice",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里 TensorVariable 是不是会出问题呢?现在是不是没有相关的 case,这个问题算是一个遗留问题,和 66 行的 getattr 应该算同一个问题,在 BINARY_SUBSCR 里我们直接将 value 取出来传进 getitem 了,所以 tensor1[tensor2] 假如 tensor2 是一个中间变量的话(value 为 None),是会有问题的,这里应该将其 Variable 传入 getitem,并在 getitem 里做取 value 等操作

这个 PR 先不改这个问题吧,可以下个 PR 统一修改下 getitem、getattr 的问题,这个 PR 记录一下 TODO 即可

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

),
{},
lambda var, key: var.getitem(key),
)
Dispatcher.register(
operator.getitem,
(
"VariableBase | TensorVariable | ContainerVariable",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里 VariableBase 已经包含了 TensorVariable | ContainerVariable,所以后面的需要删掉

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

"ConstantVariable | SliceVariable",
),
{},
lambda var, key: var.getitem(key.get_value()),
)


# VariableBase
Expand Down
8 changes: 3 additions & 5 deletions sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,10 @@ def __repr__(self):
def __str__(self):
return self.__repr__()

def __getitem__(self, item):
# TODO: Remove this function after we use builtin dispatcher instead
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数不一定能删吧,这里只是我早期认为可以直接替换掉,实际上可能不是这样

def __getitem__(self, idx):
return self.getitem(idx)

def getitem(self, item):
class_var = VariableFactory.from_value(
self.get_value().__class__,
self.graph,
Expand Down Expand Up @@ -331,9 +332,6 @@ def __call__(self, *args, **kwargs):
output = fn_var(*args, **kwargs)
return output

def getitem(self, *args, **kwargs):
pass

@VariableFactory.register_from_value()
def from_value(
value: Any,
Expand Down
7 changes: 6 additions & 1 deletion sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def main_info(self) -> dict[str, Any]:
def __bool__(self) -> bool:
return bool(self.value)

def bool(self):
return VariableFactory.from_value(
bool(self), self.graph, DummyTracker([self])
)

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
if isinstance(value, ConstTypes):
Expand Down Expand Up @@ -167,7 +172,7 @@ def main_info(self) -> dict[str, Any]:
"var_name": self.var_name,
}

def __getitem__(self, key):
def getitem(self, key):
return self.graph.call_tensor_method(
'__getitem__',
self,
Expand Down
16 changes: 8 additions & 8 deletions sot/opcode_translator/executor/variables/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def len(self):
len(self), self.graph, DummyTracker([self])
)

def __bool__(self):
def __bool__(self) -> bool:
return len(self) > 0

def bool(self):
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(
self.value = val_list

def get_value(self):
return [self[i].get_value() for i in range(len(self))]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为了方便,这里就不要改了

def __getitem__(self, idx): # 看情况是只给容器加还是 VariableBase 加
    return self.getitem(idx)

与 getattr 不同,getattr 只要有 . 就会进入,很容易出问题,但 __getitem__ 不会

return [self[idx].get_value() for idx in range(len(self))]

def _reconstruct(self, codegen: PyCodeGen):
size = len(self)
Expand All @@ -102,7 +102,7 @@ def main_info(self) -> dict[str, Any]:
def __len__(self):
return len(self.value)

def __getitem__(self, key):
def getitem(self, key):
'''
we need to make sure that:
before an inplace change happens to ListVariable,
Expand Down Expand Up @@ -188,17 +188,17 @@ def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
class TupleVariable(ContainerVariable):
def __init__(
self,
val_tuple: list[VariableBase],
val_tuple: list[VariableBase] | tuple[VariableBase],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里就不要允许传入 list[VariableBase] 了,直接强制 tuple[VariableBase] 吧,应该没有什么需求是传入 list?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

graph: FunctionGraph,
tracker: Tracker,
):
super().__init__(tracker)
self.graph = graph
# exactly it is a list (need replace item with VariableBase)
self.value = list(val_tuple)
self.value = val_tuple

def get_value(self):
return tuple(self[i].get_value() for i in range(len(self)))
return tuple(self[idx].get_value() for idx in range(len(self)))

def _reconstruct(self, codegen: PyCodeGen):
size = len(self)
Expand All @@ -222,7 +222,7 @@ def main_info(self) -> dict[str, Any]:
def __len__(self):
return len(self.value)

def __getitem__(self, key):
def getitem(self, key):
if isinstance(key, VariableBase):
raise InnerError(
f"[{self.__class__.__name__}]: recieved {key} as key."
Expand Down Expand Up @@ -312,7 +312,7 @@ def main_info(self) -> dict[str, Any]:
def __len__(self):
return len(self.value)

def __getitem__(self, key):
def getitem(self, key):
if isinstance(key, VariableBase):
raise InnerError(
f"[{self.__class__.__name__}]: recieved {key} as key."
Expand Down
15 changes: 15 additions & 0 deletions tests/test_03_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# BINARY_SUBSCR


import operator
import unittest

from test_case_base import TestCaseBase
Expand All @@ -15,10 +16,24 @@ def foo(x: int, y: paddle.Tensor):
return x[1] + 1


def foo1(x: int, y: paddle.Tensor):
x = (x, y, 1)
return operator.getitem(x, slice(0, 2))


def foo2(x: int, y: paddle.Tensor):
z = (x, y, 3, 4)
return z[0:5:1]


class TestExecutor(TestCaseBase):
def test_simple(self):
self.assert_results(foo, 1, paddle.to_tensor(2))
self.assert_results(foo1, 1, paddle.to_tensor(2))
self.assert_results(foo2, 1, paddle.to_tensor(2))


if __name__ == "__main__":
unittest.main()
# a = foo2(1, paddle.to_tensor(2))
# print(a)
8 changes: 8 additions & 0 deletions tests/test_04_list.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# New Supported Instructions:
# BUILD_LIST (new)
# BINARY_SUBSCR


import operator
import unittest

from test_case_base import TestCaseBase
Expand All @@ -14,9 +16,15 @@ def foo(x: int, y: paddle.Tensor):
return x[1] + 1


def list_getitem(x: int, y: paddle.Tensor):
z = [x, y]
return operator.getitem(z, 1) + 1


class TestExecutor(TestCaseBase):
def test_simple(self):
self.assert_results(foo, 1, paddle.to_tensor(2))
self.assert_results(list_getitem, 1, paddle.to_tensor(2))


if __name__ == "__main__":
Expand Down
9 changes: 9 additions & 0 deletions tests/test_05_dict.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#182 我忘记加 binary op、unary op 相关的了,可以在 test_14_operators 加一下 operator.add(x, y) 这类的

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需不需要把list那些也移动到test_14_operators, 比如test_04_list里的list_getitem

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以的

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# BUILD_MAP (new)
# BUILD_CONST_KEY_MAP (new)

import operator
import unittest

from test_case_base import TestCaseBase
Expand All @@ -25,6 +26,11 @@ def dict_set_item(x: int, y: paddle.Tensor):
return z[1]


def dict_get_item(x: int, y: paddle.Tensor):
z = {1: y, 2: y + 1}
return operator.getitem(z, 1)


class TestExecutor(TestCaseBase):
def test_build_map(self):
self.assert_results(build_map, 1, paddle.to_tensor(2))
Expand All @@ -35,6 +41,9 @@ def test_build_const_key_map(self):
def test_dict_set_item(self):
self.assert_results(dict_set_item, 1, paddle.to_tensor(2))

def test_dict_get_item(self):
self.assert_results(dict_get_item, 1, paddle.to_tensor(2))


if __name__ == "__main__":
unittest.main()
20 changes: 16 additions & 4 deletions tests/test_15_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,34 @@
import paddle


def build_slice(x: list, y: paddle.Tensor):
def build_list_slice(x: list, y: paddle.Tensor):
x[2:4] = [0, 1]
return x[0] + y


def build_slice_with_step(x: list, y: paddle.Tensor):
def build_list_slice_with_step(x: list, y: paddle.Tensor):
x[1:5:2] = [0, 1]
return x[0] + y


def build_tuple_slice(x: list, y: paddle.Tensor):
x[2:4] = (0, 1)
return x[0] + y


def build_tuple_slice_with_step(x: list, y: paddle.Tensor):
x[1:5:2] = (0, 1)
return x[0] + y


class TestExecutor(TestCaseBase):
def test_simple(self):
x = list(range(10))
y = paddle.arange(10)
self.assert_results(build_slice, x, y)
self.assert_results(build_slice_with_step, x, y)
self.assert_results(build_list_slice, x, y)
self.assert_results(build_list_slice_with_step, x, y)
self.assert_results(build_tuple_slice, x, y)
self.assert_results(build_tuple_slice_with_step, x, y)


if __name__ == "__main__":
Expand Down