-
Notifications
You must be signed in to change notification settings - Fork 5.9k
【Hackathon 5th No.27】为 Paddle 新增 select_scatter API -part #59343
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
从后续适配组合算子、AI编译器、接入新硬件等角度考虑,目前尽量是非必要不新增算子。由于这个算子的功能实际上是set_value
的子集,不建议再新增算子,这个我们可以在RFC中进一步讨论,统一意见以后再实际写实现代码。
当然,从这个PR的kernel逻辑实现来看,预估会比set_value
的kernel逻辑更高效。如果有兴趣的话,可以基于这个PR中kernel的实现逻辑,直接在set_value kernel中修改及优化,不过这已经超出了当前任务的范畴。
… select_scatter rebase and impl as RFC
@zoooo0820 已经按RFC实现了该接口,麻烦您有空的时候review一下。 |
@zoooo0820 CI都过了,麻烦您review一下 |
python/paddle/tensor/manipulation.py
Outdated
) | ||
|
||
# map var to the new output | ||
paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分是动转静用来适配python__setitem__
内置返回值为None的,这里API可以正常返回output,所以这句不需要了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
res = paddle.select_scatter(x_tensor, value_tensor, 1, 1) | ||
except Exception as error: | ||
self.assertIsInstance(error, RuntimeError) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里查错辛苦直接用 self.assertRaises
吧,否则后续如果某些修改导致try中语句能过了,这个单测会检测不到。
此外,这两个case能拆成两个函数吗,命名上再清晰一些说明是检查的什么case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
python/paddle/tensor/manipulation.py
Outdated
dtype=x.dtype | ||
) | ||
else: | ||
output = helper.create_variable_for_type_inference(dtype=x.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分也是动转静用来适配python__setitem__
内置返回值为None的,参考其他API写法,用else的分支即可
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@zoooo0820 CI已过,辛苦您再review一下 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
python/paddle/tensor/manipulation.py
Outdated
Embeds the values of the values tensor into x at the given index of axis. | ||
Args: | ||
x (Tensor) : The Destination Tensor. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should decribe x
support what kinds of data types
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@jeff41404 修改doc之后LLM的CI不知道为什么一直过不了,coverage显示是我没修改的部分没覆盖到 |
|
python/paddle/tensor/manipulation.py
Outdated
Embeds the values of the values tensor into x at the given index of axis. | ||
Args: | ||
x (Tensor) : The Destination Tensor. Supported data types are `bool`, `float16`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `bfloat16`, `complex64`, `complex64`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
complex64
, complex64
should be complex64
, complex128
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
python/paddle/tensor/manipulation.py
Outdated
Args: | ||
x (Tensor) : The Destination Tensor. Supported data types are `bool`, `float16`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `bfloat16`, `complex64`, `complex64`. | ||
values (Tensor) : The tensor to embed into x. Supported data types are `bool`, `float16`, `float32`, `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`, `bfloat16`, `complex64`, `complex64`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
complex64
, complex64
should be complex64
, complex128
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
可以提交对应的中文文档 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
APIs
Description
#57262
新增select_scatter API
RFC:PaddlePaddle/community#757