Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,11 @@ def BINARY_SUBSCR(self, instr):
container = self.pop()
assert isinstance(key, VariableBase)
self._graph.add_global_guarded_variable(key)
self.push(container[key.value])
self.push(
BuiltinVariable(operator.getitem, self._graph, DanglingTracker())(
container, key.value
)
)

# inplace operators
# paddle variable do not have inplace operators. For example when call `y **= x`, will call var.__pow__
Expand Down Expand Up @@ -512,6 +516,15 @@ def STORE_SUBSCR(self, instr):
container[key.get_value()] = value
value.debug_name = f"{container.debug_name}[{key.debug_name}]"

def DELETE_SUBSCR(self, instr):
key = self.pop()
container = self.pop()
assert isinstance(key, VariableBase)
self._graph.add_global_guarded_variable(key)
BuiltinVariable(operator.delitem, self._graph, DanglingTracker())(
container, key
)

def BUILD_LIST(self, instr):
list_size = instr.arg
assert list_size <= len(
Expand Down
66 changes: 66 additions & 0 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,78 @@
{},
lambda var: var.bool(),
)
Dispatcher.register(
bool,
("ConstantVariable",),
{},
lambda var: var.bool(),
)
Dispatcher.register(
operator.truth,
("ContainerVariable",),
{},
lambda var: var.bool(),
)
Dispatcher.register(
operator.truth,
("ConstantVariable",),
{},
lambda var: var.bool(),
)

# getitem
# TODO: Should pass its Variable into the getitem and perform operations such as getting value in the getitem. like this:https://github.com/PaddlePaddle/PaddleSOT/pull/198#discussion_r1241110949
Dispatcher.register(
operator.getitem,
(
"VariableBase",
"int | str | TensorVariable | slice",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里 TensorVariable 是不是会出问题呢?现在是不是没有相关的 case,这个问题算是一个遗留问题,和 66 行的 getattr 应该算同一个问题,在 BINARY_SUBSCR 里我们直接将 value 取出来传进 getitem 了,所以 tensor1[tensor2] 假如 tensor2 是一个中间变量的话(value 为 None),是会有问题的,这里应该将其 Variable 传入 getitem,并在 getitem 里做取 value 等操作

这个 PR 先不改这个问题吧,可以下个 PR 统一修改下 getitem、getattr 的问题,这个 PR 记录一下 TODO 即可

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

),
{},
lambda var, key: var.getitem(key),
)
Dispatcher.register(
operator.getitem,
(
"VariableBase",
"ConstantVariable | SliceVariable",
),
{},
lambda var, key: var.getitem(key.get_value()),
)

# setitem
Dispatcher.register(
operator.setitem,
(
"VariableBase",
"int | str | ConstantVariable | TensorVariable",
"int | str | ConstantVariable | TensorVariable",
),
{},
lambda var, key, value: var.setitem(key.get_value(), value),
)

# delitem
Dispatcher.register(
operator.delitem,
(
"VariableBase",
"int | str | TensorVariable",
),
{},
lambda var, key: var.delitem(key),
)
Dispatcher.register(
operator.delitem,
(
"VariableBase",
"ConstantVariable",
),
{},
lambda var, key: var.delitem(key.get_value()),
)


# TensorVariable
Dispatcher.register(
Expand Down
11 changes: 6 additions & 5 deletions sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,9 @@ def getattr(self, name: str):
)

def __setitem__(self, key, value):
return self.setitem(key, value)

def setitem(self, key, value):
raise NotImplementException(f"{self} is not support setitem.")

def __repr__(self):
Expand All @@ -291,9 +294,10 @@ def __repr__(self):
def __str__(self):
return self.__repr__()

def __getitem__(self, item):
# TODO: Remove this function after we use builtin dispatcher instead
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数不一定能删吧,这里只是我早期认为可以直接替换掉,实际上可能不是这样

def __getitem__(self, idx):
return self.getitem(idx)

def getitem(self, item):
class_var = VariableFactory.from_value(
self.get_value().__class__,
self.graph,
Expand Down Expand Up @@ -331,9 +335,6 @@ def __call__(self, *args, **kwargs):
output = fn_var(*args, **kwargs)
return output

def getitem(self, *args, **kwargs):
pass

@VariableFactory.register_from_value()
def from_value(
value: Any,
Expand Down
9 changes: 7 additions & 2 deletions sot/opcode_translator/executor/variables/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def main_info(self) -> dict[str, Any]:
def __bool__(self) -> bool:
return bool(self.value)

def bool(self):
return VariableFactory.from_value(
bool(self), self.graph, DummyTracker([self])
)

@VariableFactory.register_from_value()
def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
if isinstance(value, ConstTypes):
Expand Down Expand Up @@ -188,7 +193,7 @@ def main_info(self) -> dict[str, Any]:
"var_name": self.var_name,
}

def __getitem__(self, key):
def getitem(self, key):
return self.graph.call_tensor_method(
'__getitem__',
self,
Expand All @@ -197,7 +202,7 @@ def __getitem__(self, key):
),
)

def __setitem__(self, key, value):
def setitem(self, key, value):
return self.graph.call_tensor_method(
'__setitem__',
self,
Expand Down
39 changes: 26 additions & 13 deletions sot/opcode_translator/executor/variables/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def len(self):
len(self), self.graph, DummyTracker([self])
)

def __bool__(self):
def __bool__(self) -> bool:
return len(self) > 0

def bool(self):
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(
self.value = val_list

def get_value(self):
return [self[i].get_value() for i in range(len(self))]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为了方便,这里就不要改了

def __getitem__(self, idx): # 看情况是只给容器加还是 VariableBase 加
    return self.getitem(idx)

与 getattr 不同,getattr 只要有 . 就会进入,很容易出问题,但 __getitem__ 不会

return [self[idx].get_value() for idx in range(len(self))]

def _reconstruct(self, codegen: PyCodeGen):
size = len(self)
Expand All @@ -102,7 +102,7 @@ def main_info(self) -> dict[str, Any]:
def __len__(self):
return len(self.value)

def __getitem__(self, key):
def getitem(self, key):
'''
we need to make sure that:
before an inplace change happens to ListVariable,
Expand All @@ -124,9 +124,9 @@ def __getitem__(self, key):

return retval

def __setitem__(self, key, value):
def setitem(self, key, value):
'''
why __setitem__ is ok:
why setitem is ok:

case:
def f(x = [t0, t1])
Expand All @@ -147,13 +147,18 @@ def f(x = [t0, t1])
f"[{self.__class__.__name__}]: received {value} to set value."
)
self.value[key] = value
return ConstantVariable.wrap_literal(None)

def __delitem__(self, key):
return self.delitem(key)

def delitem(self, key):
if isinstance(key, VariableBase):
raise InnerError(
f"[{self.__class__.__name__}]: received {key} as key to delete."
)
del self.value[key]
return ConstantVariable.wrap_literal(None)

def extend(self, data):
self.value.extend(data.get_wrapped_items())
Expand Down Expand Up @@ -188,17 +193,16 @@ def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker):
class TupleVariable(ContainerVariable):
def __init__(
self,
val_tuple: list[VariableBase],
val_tuple: tuple[VariableBase],
graph: FunctionGraph,
tracker: Tracker,
):
super().__init__(tracker)
self.graph = graph
# exactly it is a list (need replace item with VariableBase)
self.value = list(val_tuple)
self.value = val_tuple

def get_value(self):
return tuple(self[i].get_value() for i in range(len(self)))
return tuple(self[idx].get_value() for idx in range(len(self)))

def _reconstruct(self, codegen: PyCodeGen):
size = len(self)
Expand All @@ -222,7 +226,7 @@ def main_info(self) -> dict[str, Any]:
def __len__(self):
return len(self.value)

def __getitem__(self, key):
def getitem(self, key):
if isinstance(key, VariableBase):
raise InnerError(
f"[{self.__class__.__name__}]: recieved {key} as key."
Expand All @@ -233,12 +237,15 @@ def __getitem__(self, key):
retval, graph=self.graph, tracker=GetItemTracker(self, key)
)

def __setitem__(self, key, value):
def setitem(self, key, value):
raise InnerError(
f"[{self.__class__.__name__}]: setitem is not allowed."
)

def __delitem__(self, key):
return self.delitem(key)

def delitem(self, key):
raise InnerError(
f"[{self.__class__.__name__}]: delitem is not allowed."
)
Expand Down Expand Up @@ -312,7 +319,7 @@ def main_info(self) -> dict[str, Any]:
def __len__(self):
return len(self.value)

def __getitem__(self, key):
def getitem(self, key):
if isinstance(key, VariableBase):
raise InnerError(
f"[{self.__class__.__name__}]: recieved {key} as key."
Expand All @@ -324,7 +331,7 @@ def __getitem__(self, key):
retval, self.graph, tracker=GetItemTracker(self, key)
)

def __setitem__(self, key, value):
def setitem(self, key, value):
if isinstance(key, VariableBase):
raise InnerError(
f"[{self.__class__.__name__}]: recieved {key} as key."
Expand All @@ -337,12 +344,18 @@ def __setitem__(self, key, value):

self.value[key] = value

return ConstantVariable.wrap_literal(None)

def __delitem__(self, key):
return self.delitem(key)

def delitem(self, key):
if isinstance(key, VariableBase):
raise InnerError(
f"[{self.__class__.__name__}]: recieved {key} as key to delete."
)
del self.value[key]
return ConstantVariable.wrap_literal(None)

def keys(self):
from .iter import SequenceIterVariable
Expand Down
6 changes: 6 additions & 0 deletions tests/test_03_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,15 @@ def foo(x: int, y: paddle.Tensor):
return x[1] + 1


def foo1(x: int, y: paddle.Tensor):
z = (x, y, 3, 4)
return z[0:5:1]


class TestExecutor(TestCaseBase):
def test_simple(self):
self.assert_results(foo, 1, paddle.to_tensor(2))
self.assert_results(foo1, 1, paddle.to_tensor(2))


if __name__ == "__main__":
Expand Down
41 changes: 39 additions & 2 deletions tests/test_04_list.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# New Supported Instructions:
# BUILD_LIST (new)
# BINARY_SUBSCR
# DELETE_SUBSCR


import unittest
Expand All @@ -9,14 +11,49 @@
import paddle


def foo(x: int, y: paddle.Tensor):
def list_getitem_int(x: int, y: paddle.Tensor):
x = [x, y]
return x[0] + 1


def list_getitem_tensor(x: int, y: paddle.Tensor):
x = [x, y]
return x[1] + 1


def list_setitem_int(x: int, y: paddle.Tensor):
z = [x, y]
z[0] = 3
return z


def list_setitem_tensor(x: int, y: paddle.Tensor):
z = [x, y]
z[1] = paddle.to_tensor(3)
return z


def list_delitem_int(x: int, y: paddle.Tensor):
z = [x, y]
del z[0]
return z


def list_delitem_tensor(x: int, y: paddle.Tensor):
z = [x, y]
del z[1]
return z


class TestExecutor(TestCaseBase):
def test_simple(self):
self.assert_results(foo, 1, paddle.to_tensor(2))
self.assert_results(list_getitem_int, 1, paddle.to_tensor(2))
self.assert_results(list_getitem_tensor, 1, paddle.to_tensor(2))
self.assert_results(list_setitem_int, 1, paddle.to_tensor(2))
# TODO(SigureMo) SideEffects have not been implemented yet, we need to skip them
# self.assert_results(list_setitem_tensor, 1, paddle.to_tensor(2))
self.assert_results(list_delitem_int, 1, paddle.to_tensor(2))
self.assert_results(list_delitem_tensor, 1, paddle.to_tensor(2))


if __name__ == "__main__":
Expand Down
Loading