Modify reduce ops #8085
Conversation
signature: "Tensor (Tensor x, Int32List axis, Bool keepdims=False) => ReduceAll" | ||
signature: [ | ||
"Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False) => ReduceAll", | ||
"Tensor (Tensor x) => ReduceAllAll" |
This suffix reads a bit awkwardly; how about picking a different one, e.g. ReducexxxFlatten?
Changed to ReducexxxWhole.
        one::OpBuilder("reduce_sum").Input("input_tensor").Output("output_tensor").Build());
  }
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {
    // const DataType dtype = x->dtype()->data_type();
Suggested change (delete this commented-out line):
-    // const DataType dtype = x->dtype()->data_type();
Deleted.
    MutableAttrMap attrs;
    if (axis.empty()) {
      std::vector<int32_t> reduce_axis(x->shape()->NumAxes());
      const int32_t naxis = x->shape()->NumAxes();
Just a suggestion!

Suggested change:
-    const int32_t naxis = x->shape()->NumAxes();
+    const int32_t ndim = x->ndim();
👍 Agreed.
      std::vector<int32_t> reduce_axis(axis.size());
      for (int i = 0; i < axis.size(); i++) {
        CHECK_GE_OR_RETURN(naxis, axis[i])
Please standardize how the exception is written, and add corresponding tests; see #8080 for reference.
Got it.
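For illustration, a minimal sketch of such an exception test (assuming the reduce functor rejects an out-of-range dim; the exact exception type and message depend on the final implementation):

```python
import unittest

import oneflow as flow


class TestReduceDimException(unittest.TestCase):
    def test_sum_dim_out_of_range(self):
        x = flow.ones(2, 3)
        # dim=5 is out of range for a 2-D tensor, so the reduce functor is
        # expected to reject it with a descriptive error message.
        with self.assertRaises(Exception):
            flow.sum(x, dim=5)


if __name__ == "__main__":
    unittest.main()
```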
      std::iota(reduce_axis.begin(), reduce_axis.end(), 0);
      JUST(attrs.SetAttr<std::vector<int32_t>>("axis", reduce_axis));
    } else {
      JUST(attrs.SetAttr<std::vector<int32_t>>("axis", axis));
      CHECK_GE_OR_RETURN(naxis, axis.size())
Same comment as above regarding the exception.
Got it.
@@ -703,11 +703,11 @@ def _unbind(self, dim=0):
     return flow._C.unbind(self, dim)


-def _all(self, dim=None, keepdim=False):
+def _all(self, dim=[], keepdim=False):
Does this one really need to change?
Take the following interface as an example:

signature: [
  "Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False) => ReduceAll",
  "Tensor (Tensor x) => ReduceAllWhole"
]

None cannot be passed as an Int32List, and the other overload does not take a dim argument at all, so passing None raises an error; an empty list should be passed instead.
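A small Python sketch of the dispatch rule described above (hypothetical helpers, not the generated functional-API code; reduce_all and reduce_all_whole stand in for the ReduceAll and ReduceAllWhole overloads):

```python
from collections.abc import Sequence
from typing import List


def reduce_all(x, dim: List[int], keepdim: bool):
    """Placeholder for the ReduceAll overload (reduce over the given axes)."""
    raise NotImplementedError


def reduce_all_whole(x):
    """Placeholder for the ReduceAllWhole overload (reduce over every axis)."""
    raise NotImplementedError


def _to_int32_list(dim) -> List[int]:
    # Mimics the Int32List[1] conversion rule: a single int or a sequence of
    # ints converts; None has no valid conversion, which is why dim=None fails.
    if isinstance(dim, int):
        return [dim]
    if isinstance(dim, Sequence) and all(isinstance(d, int) for d in dim):
        return list(dim)
    raise TypeError(f"cannot convert {dim!r} to Int32List")


def _all(x, dim=[], keepdim=False):
    # An empty list means "no dim given": route to the overload that takes no
    # dim argument at all and reduces over the whole tensor.
    if isinstance(dim, Sequence) and len(dim) == 0:
        return reduce_all_whole(x)
    return reduce_all(x, _to_int32_list(dim), keepdim)
```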
@@ -38,6 +38,22 @@ namespace oneflow {
 namespace one {
 namespace functional {

+namespace {
+std::string exception_check(int32_t base, int32_t value, bool check_ge = true,
Why does this need to be added?
Still working on it.
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/8085/
CI failed when running job: cpu-misc. PR label automerge has been removed
CI failed when running job: cuda-speed-test. PR label automerge has been removed
CI failed when running job: cuda-benchmark. PR label automerge has been removed
The max, min, sum, and other ops in reduce_ops.py are changed to be exported directly as Functors. This also resolves Oneflow-Inc/oneflow-documentation#480.
The original prod code aligns the keyword-only `*, dtype=None` parameter, while sum and mean do not; however, oneflow does not generally align PyTorch's parameters after `*` (e.g. the typical `out=None`), and aligning them would add little value, so they were not filled in.
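A short usage sketch of the functor-exported reduce ops (assuming the PyTorch-style dim/keepdim signatures described above; illustrative only, not taken from this PR's tests):

```python
import oneflow as flow

x = flow.tensor([[1.0, 2.0], [3.0, 4.0]])

flow.sum(x)                        # reduce over all axes -> scalar tensor
flow.sum(x, dim=1)                 # reduce along dim 1
flow.mean(x, dim=0, keepdim=True)  # keep the reduced axis with size 1
flow.prod(x, dim=1, dtype=flow.float64)  # prod keeps the keyword-only dtype argument
```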

