-
Notifications
You must be signed in to change notification settings - Fork 24
[Dispatch] use dispatch replace getitem
, setitem
, delitem
and bool
#198
Changes from 5 commits
82d41ad
239bcec
7a1bdd9
0b27dcd
8855b0e
f539545
5129c96
c595f30
3c73305
d828024
32dc2dc
61bb321
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,12 +90,44 @@ | |
{}, | ||
lambda var: var.bool(), | ||
) | ||
Dispatcher.register( | ||
bool, | ||
("ConstantVariable",), | ||
{}, | ||
lambda var: var.bool(), | ||
) | ||
Dispatcher.register( | ||
operator.truth, | ||
("ContainerVariable",), | ||
{}, | ||
lambda var: var.bool(), | ||
) | ||
Dispatcher.register( | ||
operator.truth, | ||
("ConstantVariable",), | ||
{}, | ||
lambda var: var.bool(), | ||
) | ||
|
||
# getitem | ||
Dispatcher.register( | ||
operator.getitem, | ||
( | ||
"VariableBase | TensorVariable | ContainerVariable", | ||
"int | str | TensorVariable | slice", | ||
), | ||
{}, | ||
lambda var, key: var.getitem(key), | ||
) | ||
Dispatcher.register( | ||
operator.getitem, | ||
( | ||
"VariableBase | TensorVariable | ContainerVariable", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
"ConstantVariable | SliceVariable", | ||
), | ||
{}, | ||
lambda var, key: var.getitem(key.get_value()), | ||
) | ||
|
||
|
||
# VariableBase | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -291,9 +291,10 @@ def __repr__(self): | |
def __str__(self): | ||
return self.__repr__() | ||
|
||
def __getitem__(self, item): | ||
# TODO: Remove this function after we use builtin dispatcher instead | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个函数不一定能删吧,这里只是我早期认为可以直接替换掉,实际上可能不是这样 |
||
def __getitem__(self, idx): | ||
return self.getitem(idx) | ||
|
||
def getitem(self, item): | ||
class_var = VariableFactory.from_value( | ||
self.get_value().__class__, | ||
self.graph, | ||
|
@@ -331,9 +332,6 @@ def __call__(self, *args, **kwargs): | |
output = fn_var(*args, **kwargs) | ||
return output | ||
|
||
def getitem(self, *args, **kwargs): | ||
pass | ||
|
||
@VariableFactory.register_from_value() | ||
def from_value( | ||
value: Any, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,7 @@ def len(self): | |
len(self), self.graph, DummyTracker([self]) | ||
) | ||
|
||
def __bool__(self): | ||
def __bool__(self) -> bool: | ||
return len(self) > 0 | ||
|
||
def bool(self): | ||
|
@@ -78,7 +78,7 @@ def __init__( | |
self.value = val_list | ||
|
||
def get_value(self): | ||
return [self[i].get_value() for i in range(len(self))] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为了方便,这里就不要改了 def __getitem__(self, idx): # 看情况是只给容器加还是 VariableBase 加
return self.getitem(idx) 与 getattr 不同,getattr 只要有 |
||
return [self[idx].get_value() for idx in range(len(self))] | ||
|
||
def _reconstruct(self, codegen: PyCodeGen): | ||
size = len(self) | ||
|
@@ -102,7 +102,7 @@ def main_info(self) -> dict[str, Any]: | |
def __len__(self): | ||
return len(self.value) | ||
|
||
def __getitem__(self, key): | ||
def getitem(self, key): | ||
''' | ||
we need to make sure that: | ||
before an inplace change happens to ListVariable, | ||
|
@@ -188,17 +188,17 @@ def from_value(value: Any, graph: FunctionGraph | None, tracker: Tracker): | |
class TupleVariable(ContainerVariable): | ||
def __init__( | ||
self, | ||
val_tuple: list[VariableBase], | ||
val_tuple: list[VariableBase] | tuple[VariableBase], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里就不要允许传入 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
graph: FunctionGraph, | ||
tracker: Tracker, | ||
): | ||
super().__init__(tracker) | ||
self.graph = graph | ||
# exactly it is a list (need replace item with VariableBase) | ||
self.value = list(val_tuple) | ||
self.value = val_tuple | ||
|
||
def get_value(self): | ||
return tuple(self[i].get_value() for i in range(len(self))) | ||
return tuple(self[idx].get_value() for idx in range(len(self))) | ||
|
||
def _reconstruct(self, codegen: PyCodeGen): | ||
size = len(self) | ||
|
@@ -222,7 +222,7 @@ def main_info(self) -> dict[str, Any]: | |
def __len__(self): | ||
return len(self.value) | ||
|
||
def __getitem__(self, key): | ||
def getitem(self, key): | ||
if isinstance(key, VariableBase): | ||
raise InnerError( | ||
f"[{self.__class__.__name__}]: recieved {key} as key." | ||
|
@@ -312,7 +312,7 @@ def main_info(self) -> dict[str, Any]: | |
def __len__(self): | ||
return len(self.value) | ||
|
||
def __getitem__(self, key): | ||
def getitem(self, key): | ||
if isinstance(key, VariableBase): | ||
raise InnerError( | ||
f"[{self.__class__.__name__}]: recieved {key} as key." | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. #182 我忘记加 binary op、unary op 相关的了,可以在 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 需不需要把 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以的 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里 TensorVariable 是不是会出问题呢?现在是不是没有相关的 case,这个问题算是一个遗留问题,和 66 行的 getattr 应该算同一个问题,在 BINARY_SUBSCR 里我们直接将 value 取出来传进 getitem 了,所以
tensor1[tensor2]
假如 tensor2 是一个中间变量的话(value 为 None),是会有问题的,这里应该将其 Variable 传入 getitem,并在 getitem 里做取 value 等操作这个 PR 先不改这个问题吧,可以下个 PR 统一修改下 getitem、getattr 的问题,这个 PR 记录一下 TODO 即可
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done