-
Notifications
You must be signed in to change notification settings - Fork 5.8k
[Paddle Tensorrt] add pd_op.assign,pd_op.assign_value_,pd_op.assign_out converter #68775
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
def setUp(self): | ||
self.python_api = paddle.assign | ||
self.api_args = { | ||
"x": np.array([[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]], dtype='float32') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
converter里写了4种数据类型, 单测应该覆盖全面
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
dtype = np.bool | ||
|
||
constant_layer = network.add_constant( | ||
shape, np.array(values, dtype=np.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为啥固定死数据类型是float32?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里写错了,我改下
def setUp(self): | ||
self.python_api = paddle.split | ||
self.api_args = { | ||
"x": np.random.randn(1, 2).astype(np.float32), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同样的, 这个单测似乎过于简单, 应该考虑到更多的case, 例如num_or_sections=1或者等于shape的值会是什么结果?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
前面都测过了,这个是补充的有bug场景的组网
return identity_layer.get_output(0) | ||
|
||
|
||
@converter_registry.register("pd_op.assign_value_", trt_version="8.x") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
单测似乎没有测试pd_op.assign_value_
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TestAssignValueFloat32TRTPattern这个是
PR Category
Inference
PR Types
New features
Description
card-71500
添加了pd_op.assign,pd_op.assign_value_,pd_op.assign_out converter,以及修复了pd_op.split_with_num的bug