-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SOT][PIR] support numpy
and register_hook
attr
#66008
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
elif name in paddle_tensor_methods: | ||
elif name in get_tensor_methods(): | ||
if name in ["numpy", "regisiter_hook"]: | ||
raise FallbackError(f"no support {name}.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里应该是可以跑的,不是说获取这个方法有问题,而是这个方法调用有问题,应该在 call 的时候触发 BreakGraph
不过值得注意的是 place 应该是一个不支持的属性,在获取时就应该 BreakGraph
如果你想问 numpy 和 regisiter_hook 没支持这里难道不会挂么?那么 numpy 的支持方式就是答案,fake interface
python/paddle/pir/math_op_patch.py
Outdated
@@ -142,6 +142,7 @@ def cuda(self, device_id=None, blocking=True): | |||
return _C_ops.memcpy(self, 1) | |||
|
|||
@property | |||
@fake_interface_only |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
?这个为啥要 fake,这个 fake AST 不直接挂了吗
"persistable", | ||
"stop_gradient", | ||
"place", | ||
]: | ||
if name == "name" and self.meta.name.startswith( | ||
"infer_meta_variable_tmp" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注意在 PIR 下中间变量名应该永远是 Value@NoName
,所以这两种情况都需要 breakgraph
可以在 infermeta 里定义一个常量 SOT_INFER_META_INNER_VAR_PREFIX = "___SOT_INFER_META_INNER_VAR"
,以便老 IR 和 PIR name 统一,这里只需要 import 并判断一次就可以了
""" | ||
pass | ||
|
||
def register_hook(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fake_interface_only 装饰呢?
if name in ["name", "place"] and self.meta.is_inner_var(): | ||
raise BreakGraphError( | ||
f"{self.meta.name} is a middle tensor. get {name} property." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这俩可以直接提取到外面,没必要放在这个分支内
python/paddle/jit/sot/infer_meta.py
Outdated
name = "Value@NoName" | ||
else: | ||
name = value.name | ||
name = SOT_INFER_META_INNER_VAR_PREFIX |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
其实这样有点怪怪的,因为这样的话所有中间变量 name 都是固定的 SOT_INFER_META_INNER_VAR_PREFIX
,不可能会有后缀
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
要不这个常量改名 SOT_INFER_META_INNER_VAR
吧,162 行写法改为 self.name == SOT_INFER_META_INNER_VAR
,同时 206 行改为 self.var_name_generator = UniqueNameGenerator(SOT_INFER_META_INNER_VAR)
@@ -626,6 +660,8 @@ def value_hash(self): | |||
('to_dense', to_dense), | |||
('indices', indices), | |||
('values', values), | |||
("numpy", numpy), | |||
("register_hook", register_hook), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test/dygraph_to_static/test_tensor_attr_consistency.py
需要同步清理
python/paddle/pir/math_op_patch.py
Outdated
Value don't have 'register_hook' interface in static graph mode | ||
But this interface can greatly facilitate dy2static. | ||
So we give a warning here and return None. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不是 warning 吧?这里是直接报错吧?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
numpy
regisiter_hook
attrnumpy
and regisiter_hook
attr
numpy
and regisiter_hook
attrnumpy
and regisiter_hook
attr
numpy
and regisiter_hook
attrnumpy
and register_hook
attr
PR Category
Others
PR Types
Others
Description
paddle.pir.Value
替代paddle.static.Variable
Value.numpy
和Value.register_hook