-
Notifications
You must be signed in to change notification settings - Fork 267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5 No.16】为 Paddle 新增 EmbeddingBag API #688
Conversation
add api design for embedding_bag
仿照之前PR中的设计: | ||
|
||
```Python | ||
def embedding_bag(input, params, weight, mode, name=None) -> Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这儿是否缺少了一个sparse参数,params参数是否可以改成padding_idx命名 跟embedding保持一致
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的,缺少的sparse参数后续会补上;
params参数为了和embedding保持一致预期改为weight;
原本的weight为了和pytorch一致预期更名为per_sample_weight
""" | ||
``` | ||
|
||
## API 实现方案 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
之前的pr中没有考虑sparse的情况,这儿实现需要考虑sparse相关实现
# 五、设计思路与实现方案 | ||
|
||
## 命名与参数设计 | ||
仿照之前PR中的设计: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这儿补充一下class相关的描述吧
## API 实现方案 | ||
总体实现参考[增加C++算子教程](https://www.paddlepaddle.org.cn/documentation/docs/zh/dev_guides/api_contributing_guides/new_cpp_op_cn.html) | ||
|
||
1. 分别对CPU和GPU环境下增加kernel实现,确保CI均编译通过 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测试列一下方案吧,比如对哪些边界条件的测试。
PR types
Others
PR changes
Docs
Description
【Hackathon 5 No.16】为 Paddle 新增 EmbeddingBag API