Modify reduce ops #8085

Merged: zhongshsh merged 70 commits into master from modify_reduce_ops on May 5, 2022
Conversation

zhongshsh (Contributor) commented Apr 24, 2022

The max, min, sum, and similar ops in reduce_ops.py are changed to be exported directly as Functors. This also resolves Oneflow-Inc/oneflow-documentation#480. The following reduce ops are covered:

  • sum
  • mean
  • all
  • any
  • prod

The original prod code aligns the keyword-only parameter (*, dtype=None), while sum and mean do not. However, oneflow does not align the parameters that follow * in PyTorch (such as the typical out=None), and aligning them would add little value, so they have not been filled in.
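
For context, here is a minimal sketch of what exporting one of these reduce ops directly as a Functor looks like on the C++ side, pieced together from the fragments quoted in the review below. The class name, the keepdims attribute, and the OpInterpUtil::Dispatch call are assumptions for illustration, not necessarily the exact code in this PR.

```cpp
// Hedged sketch of a direct Functor export for reduce_sum, reconstructed from
// the review fragments; ReduceSumFunctor and OpInterpUtil::Dispatch are assumed names.
class ReduceSumFunctor {
 public:
  ReduceSumFunctor() {
    // Build the underlying user op once, in the functor constructor.
    op_ = CHECK_JUST(
        one::OpBuilder("reduce_sum").Input("input_tensor").Output("output_tensor").Build());
  }

  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x,
                           const std::vector<int32_t>& axis, bool keepdims) const {
    MutableAttrMap attrs;
    JUST(attrs.SetAttr<std::vector<int32_t>>("axis", axis));
    JUST(attrs.SetAttr<bool>("keepdims", keepdims));
    // Run the op on the input tensor with the collected attributes.
    return OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs);
  }

 private:
  std::shared_ptr<OpExpr> op_;
};
```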

@zhongshsh zhongshsh requested a review from hjchen2 as a code owner April 24, 2022 13:23
@zhongshsh zhongshsh requested a review from doombeaker as a code owner April 25, 2022 02:29
signature: "Tensor (Tensor x, Int32List axis, Bool keepdims=False) => ReduceAll"
signature: [
"Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False) => ReduceAll",
"Tensor (Tensor x) => ReduceAllAll"

Contributor

This suffix looks a bit odd here. How about choosing a different suffix instead, e.g. ReducexxxFlatten?

zhongshsh (Contributor, Author) Apr 25, 2022

Changed to ReducexxxWhole.

one::OpBuilder("reduce_sum").Input("input_tensor").Output("output_tensor").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {
// const DataType dtype = x->dtype()->data_type();

Contributor
Suggested change
// const DataType dtype = x->dtype()->data_type();

zhongshsh (Contributor, Author)

Deleted.

@zhongshsh zhongshsh added the op label Apr 25, 2022
MutableAttrMap attrs;
if (axis.empty()) {
std::vector<int32_t> reduce_axis(x->shape()->NumAxes());
const int32_t naxis = x->shape()->NumAxes();

Contributor
just a suggestion!

Suggested change
const int32_t naxis = x->shape()->NumAxes();
const int32_t ndim = x->ndim();

zhongshsh (Contributor, Author)

👍 Agreed.

zhongshsh (Contributor, Author)

Since the op's attribute name is axis, it is more consistent to keep the original naming.


std::vector<int32_t> reduce_axis(axis.size());
for (int i = 0; i < axis.size(); i++) {
CHECK_GE_OR_RETURN(naxis, axis[i])

Contributor

Please follow the standard exception style and add corresponding tests; see #8080 for reference.

zhongshsh (Contributor, Author)

Got it.

std::iota(reduce_axis.begin(), reduce_axis.end(), 0);
JUST(attrs.SetAttr<std::vector<int32_t>>("axis", reduce_axis));
} else {
JUST(attrs.SetAttr<std::vector<int32_t>>("axis", axis));
CHECK_GE_OR_RETURN(naxis, axis.size())

Contributor

Same comment about exceptions as above.

zhongshsh (Contributor, Author)

Got it.
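
To make the axis handling under discussion concrete, here is a standalone sketch of the behaviour: an empty axis list means reduce over every dimension, otherwise each axis is validated against the number of dimensions. This is plain C++ for illustration, not the oneflow Maybe<>/CHECK_GE_OR_RETURN code itself.

```cpp
#include <cstdint>
#include <numeric>
#include <stdexcept>
#include <vector>

// Standalone illustration of the axis handling discussed above. The real code
// reports errors via CHECK_GE_OR_RETURN and Maybe<>; exceptions stand in here.
std::vector<int32_t> NormalizeReduceAxes(const std::vector<int32_t>& axis, int32_t ndim) {
  std::vector<int32_t> reduce_axis;
  if (axis.empty()) {
    // No axis given: reduce over all dimensions 0, 1, ..., ndim - 1.
    reduce_axis.resize(ndim);
    std::iota(reduce_axis.begin(), reduce_axis.end(), 0);
  } else {
    if (static_cast<int32_t>(axis.size()) > ndim) {
      throw std::invalid_argument("more reduce axes than tensor dimensions");
    }
    for (int32_t a : axis) {
      if (a >= ndim) { throw std::out_of_range("reduce axis out of range"); }
      reduce_axis.push_back(a);
    }
  }
  return reduce_axis;
}
```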

@@ -703,11 +703,11 @@ def _unbind(self, dim=0):
    return flow._C.unbind(self, dim)


def _all(self, dim=None, keepdim=False):
def _all(self, dim=[], keepdim=False):

Contributor

This doesn't seem to need changing?

zhongshsh (Contributor, Author)

Take the following interface as an example:

  signature: [
    "Tensor (Tensor x, Int32List[1] dim, Bool keepdim=False) => ReduceAll",
    "Tensor (Tensor x) => ReduceAllWhole"
]

None cannot be converted to an Int32List, and the other overload takes no dim argument at all, so passing None would raise an error; an empty list should be passed instead.
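
A hedged sketch of what the Whole overload amounts to (the class name and the call into the dim-based functional API are illustrative assumptions): it always reduces over every axis, so the remaining overload only ever needs a real Int32List, and the Python wrapper can default to an empty list instead of None.

```cpp
// Hedged sketch of the variant behind "Tensor (Tensor x) => ReduceAllWhole".
// The class name and the functional::ReduceAll call are illustrative assumptions.
class ReduceAllWholeFunctorSketch {
 public:
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {
    std::vector<int32_t> reduce_axis(x->ndim());
    std::iota(reduce_axis.begin(), reduce_axis.end(), 0);  // reduce over all axes
    return functional::ReduceAll(x, reduce_axis, /*keepdim=*/false);
  }
};
```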

@@ -38,6 +38,22 @@ namespace oneflow {
namespace one {
namespace functional {

namespace {
std::string exception_check(int32_t base, int32_t value, bool check_ge = true,

Contributor

Why is this being added?

zhongshsh (Contributor, Author)

Not finished yet.

github-actions bot commented May 2, 2022

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/8085/

github-actions bot commented May 2, 2022

CI failed when running job: cpu-misc. PR label automerge has been removed

@github-actions github-actions bot removed the automerge label May 2, 2022
github-actions bot commented May 2, 2022

Speed stats:

@zhongshsh zhongshsh requested review from oneflow-ci-bot and removed request for oneflow-ci-bot May 3, 2022 11:48
github-actions bot commented May 3, 2022

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/8085/

github-actions bot commented May 3, 2022

Speed stats:
GPU Name: NVIDIA GeForce GTX 1080 

❌ OneFlow resnet50 time: 129.3ms (= 12925.9ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 142.8ms (= 14279.7ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.10 (= 142.8ms / 129.3ms)

OneFlow resnet50 time: 83.8ms (= 8377.9ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 83.6ms (= 8363.2ms / 100, input_shape=[8, 3, 224, 224])
❌ Relative speed: 1.00 (= 83.6ms / 83.8ms)

OneFlow resnet50 time: 54.3ms (= 10858.2ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 55.1ms (= 11025.3ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.02 (= 55.1ms / 54.3ms)

OneFlow resnet50 time: 42.6ms (= 8525.0ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 48.3ms (= 9662.9ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.13 (= 48.3ms / 42.6ms)

OneFlow resnet50 time: 38.1ms (= 7629.5ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 37.8ms (= 7559.0ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 0.99 (= 37.8ms / 38.1ms)

OneFlow swin dataloader time: 0.256s (= 51.126s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 30.057s / 200, num_workers=1)
Relative speed: 0.588 (= 0.150s / 0.256s)

OneFlow swin dataloader time: 0.067s (= 13.370s / 200, num_workers=4)
PyTorch swin dataloader time: 0.043s (= 8.551s / 200, num_workers=4)
Relative speed: 0.640 (= 0.043s / 0.067s)

OneFlow swin dataloader time: 0.036s (= 7.218s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.395s / 200, num_workers=8)
Relative speed: 0.609 (= 0.022s / 0.036s)

❌ OneFlow resnet50 time: 145.5ms (= 14549.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 167.8ms (= 16776.4ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.15 (= 167.8ms / 145.5ms)

OneFlow resnet50 time: 97.5ms (= 9747.4ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 122.9ms (= 12286.2ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.26 (= 122.9ms / 97.5ms)

OneFlow resnet50 time: 75.3ms (= 15052.3ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 87.9ms (= 17577.8ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.17 (= 87.9ms / 75.3ms)

OneFlow resnet50 time: 64.8ms (= 12969.6ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 85.4ms (= 17075.2ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.32 (= 85.4ms / 64.8ms)

OneFlow resnet50 time: 55.3ms (= 11065.8ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 75.4ms (= 15077.8ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.36 (= 75.4ms / 55.3ms)

github-actions bot commented May 3, 2022

CI failed when running job: cuda-speed-test. PR label automerge has been removed

@github-actions github-actions bot removed the automerge label May 3, 2022
@zhongshsh zhongshsh removed the request for review from oneflow-ci-bot May 4, 2022 07:07
@zhongshsh zhongshsh requested a review from oneflow-ci-bot May 4, 2022 11:48
github-actions bot commented May 4, 2022

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/8085/

github-actions bot commented May 4, 2022

CI failed when running job: cuda-benchmark. PR label automerge has been removed

github-actions bot commented May 5, 2022

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/8085/

github-actions bot commented May 5, 2022

Speed stats:
GPU Name: NVIDIA GeForce GTX 1080 

❌ OneFlow resnet50 time: 129.4ms (= 12943.0ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 140.7ms (= 14074.9ms / 100, input_shape=[16, 3, 224, 224])
✔️ Relative speed: 1.09 (= 140.7ms / 129.4ms)

OneFlow resnet50 time: 83.5ms (= 8346.1ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 85.2ms (= 8517.4ms / 100, input_shape=[8, 3, 224, 224])
❌ Relative speed: 1.02 (= 85.2ms / 83.5ms)

OneFlow resnet50 time: 51.8ms (= 10367.6ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 54.0ms (= 10790.8ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.04 (= 54.0ms / 51.8ms)

OneFlow resnet50 time: 41.6ms (= 8313.7ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 41.6ms (= 8316.8ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.00 (= 41.6ms / 41.6ms)

OneFlow resnet50 time: 37.0ms (= 7405.1ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 37.8ms (= 7553.1ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.02 (= 37.8ms / 37.0ms)

OneFlow swin dataloader time: 0.253s (= 50.523s / 200, num_workers=1)
PyTorch swin dataloader time: 0.151s (= 30.230s / 200, num_workers=1)
Relative speed: 0.598 (= 0.151s / 0.253s)

OneFlow swin dataloader time: 0.068s (= 13.692s / 200, num_workers=4)
PyTorch swin dataloader time: 0.044s (= 8.825s / 200, num_workers=4)
Relative speed: 0.644 (= 0.044s / 0.068s)

OneFlow swin dataloader time: 0.038s (= 7.522s / 200, num_workers=8)
PyTorch swin dataloader time: 0.022s (= 4.392s / 200, num_workers=8)
Relative speed: 0.584 (= 0.022s / 0.038s)

❌ OneFlow resnet50 time: 146.5ms (= 14645.7ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 169.2ms (= 16924.0ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.16 (= 169.2ms / 146.5ms)

OneFlow resnet50 time: 99.4ms (= 9941.8ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 112.3ms (= 11234.3ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.13 (= 112.3ms / 99.4ms)

OneFlow resnet50 time: 77.8ms (= 15557.6ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 88.3ms (= 17667.2ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.14 (= 88.3ms / 77.8ms)

OneFlow resnet50 time: 65.3ms (= 13064.5ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 77.1ms (= 15417.2ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.18 (= 77.1ms / 65.3ms)

OneFlow resnet50 time: 55.8ms (= 11166.6ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 72.0ms (= 14406.0ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.29 (= 72.0ms / 55.8ms)

@zhongshsh zhongshsh merged commit f86de84 into master May 5, 2022
@zhongshsh zhongshsh deleted the modify_reduce_ops branch May 5, 2022 06:18