Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add elementwise maximum/minimum ops #4069

Merged
merged 22 commits into from
Jan 11, 2021
Merged

Conversation

doombeaker
Copy link
Contributor

@doombeaker doombeaker commented Dec 31, 2020

概述

为满足浙工大算法开发需求,增加 math.maximum 及 math.minimum 2 个算子的方向。OneFlow 原有的这两个接口,支持 broadcast,实现反向较为复杂。折中考虑,增加了两个 Op elementwise_maximumelementwise_minimum 分别实现了它们的前向和后向。在 Python 前端根据 xy 的形状,决定调用 elementwise 还是 broadcast 的算子。

换言之,此PR合并后,OneFlow 将支持 elementwise 类型的 maximumminimum 的后向,暂不支持 broadcast 类型的 maximumminimum 的后向。

功能 CheckList

注意 : 功能复选框均为可选项,若未选择,说明理由即可。例如:该 Op 由 Python 接口拼接而成,因此无 SetBatchAxisInferFn Op 注册;再比如:该 Op 无输入,因此无 SetInputArgModifyFn

模板中自带的复选框可留空,但是不能删除。可根据实际情况增加复选框选项。

Op

  • Op SetBatchAxisInferFn
  • Op SetGetSbpFn
  • Op SetInputArgModifyFn,未有特殊需求,没有实现
  • Op 反向梯度注册

Kernel

  • CPU x:float32 y:float32

  • CPU x:float32 y:float32

  • GPU x:float32 y:float32

  • GPU x:float32 y:float32

Python Wrapper

  • Python API 参数检查及异常提示(增加了shape的判断)
  • 接口注释(文档沿用之前已有的未做改变)
  • Example (Example 沿用之前已有的未做改变) 

测试

  • 单机单卡 CPU Test Case
  • 单机单卡 GPU Test Case
  • 单机多卡 CPU Test Case
  • 单机多卡 GPU Test Case
  • 分布式 CPU Test Case
  • 分布式 GPU Test Case

GPU 有效带宽

因为u使用的是 CUDA ELEMENTWISE 模板,所以未测此项。

PR Checklist

  • PR 标题语句通畅,明确表达 PR 内容,适合直接作为新版本发布时的 changelog
  • 代码格式化
  • 已经本地编译通过
  • 已本地针对改动测试
  • 已添加 type 标签:(填写 type 标签名,如 bug, enhancement, purge, feature, documentation)
  • 已添加 component 标签:(填写 component 标签名,如 op, system, eager, build, xla, python, ci, test, tooling, onnx)
  • Draft 转正式 PR 前已请人 Review

@oneflow-ci-bot oneflow-ci-bot removed their request for review January 5, 2021 10:08
@leaves-zwx leaves-zwx self-requested a review January 9, 2021 06:12
@doombeaker doombeaker changed the title start up of ADD grad for maximum and minimum Add elementwise maximum/minimum ops Jan 10, 2021
@doombeaker doombeaker marked this pull request as ready for review January 10, 2021 09:12
@oneflow-ci-bot oneflow-ci-bot removed their request for review January 10, 2021 12:19
@Zailiang Zailiang requested a review from MARD1NO January 11, 2021 06:16
@Zailiang Zailiang requested a review from chengtbf January 11, 2021 06:17
@Zailiang Zailiang self-requested a review January 11, 2021 08:06
@oneflow-ci-bot oneflow-ci-bot self-requested a review January 11, 2021 08:51
@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot January 11, 2021 11:22

} // namespace

namespace user_op {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个下面的 macro define 和 op registration 都不用放在 user op 的 namespace 下,直接放在 oneflow namespace 下就可以了

.InputBind("dz", ctx->FwOp().output_grad("z", 0)) \
.InputBind("x", ctx->FwOp().input("x", 0)) \
.InputBind("y", ctx->FwOp().input("y", 0)) \
.Output("dx") \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.Output("dx")
.Output("dy")

这两个 output 是不是需要分别根据 x.need_grad 和 y.need_grad 来设置的

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (x.need_grad()) {
   builder.Output("dx");
}
if (y.need_grad()) {
   builder.Output("dy");
}

}
};
} // namespace

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

template<typename FunctorT, typename T>
inline cudaError_t XimumGrad(FunctorT functor, int64_t n, T* dx, T* dy const T* x, const T* y, const T* dz,
                             cudaStream_t stream) {
  using FactoryT = SimpleFactory<FunctorT>;
  return GenericLauncher<FactoryT, R, A, B, C>::Launch(FactoryT(functor), n, dx, dy, x, y, dz, stream);
}

template<template<typename> class BackwardFunctor, typename T>
__global__ void ElementwiseBackwardGradGpu(int64_t elem_cnt, const T* dz, const T* x, const T* y,
T* dx, T* dy) {
BackwardFunctor<T>::Backward(elem_cnt, dz, x, y, dx, dy);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CUDA_1D_LOOP(i, elem_cnt) {
   BackwardFunctor<T>::Backward(dz, x, y, dx, dy);
}

@oneflow-ci-bot oneflow-ci-bot requested review from oneflow-ci-bot and removed request for oneflow-ci-bot January 11, 2021 14:56
@oneflow-ci-bot oneflow-ci-bot merged commit 823603c into master Jan 11, 2021
@oneflow-ci-bot oneflow-ci-bot deleted the add_grad_for_maximum_minimum branch January 11, 2021 17:56
@doombeaker doombeaker mentioned this pull request Jan 12, 2021
3 tasks
liujuncheng pushed a commit that referenced this pull request Jun 3, 2021
* start up of ADD grad for maximum and minimum

* refine batch axis

* add GPU version

* add minimum backward

* add static shape unit test

* add dynamic test

* add sbp and batchaxis infer

* refine files hierarchy

* elementwise maximum and minimum finished

* refine on checking dx/dy if exists

* refine (use template functors)

* refine test case

Co-authored-by: MARD1NO <359521840@qq.com>
Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com>
Co-authored-by: Zailiang <zailiangyu@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Former-commit-id: 823603c
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants