Add elementwise maximum/minimum ops #4069
Conversation
} // namespace

namespace user_op {
The macro definitions and op registrations below don't need to live inside the user_op namespace; putting them directly in the oneflow namespace is enough.
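For illustration, a minimal sketch of the suggested placement (the op name and argument list are placeholders patterned on this PR, not the exact registration):

namespace oneflow {

// The registration sits directly in the oneflow namespace; the
// REGISTER_USER_OP macro pulls in the user_op machinery on its own.
REGISTER_USER_OP("elementwise_maximum")
    .Input("x")
    .Input("y")
    .Output("z");

}  // namespace oneflow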
.InputBind("dz", ctx->FwOp().output_grad("z", 0)) \ | ||
.InputBind("x", ctx->FwOp().input("x", 0)) \ | ||
.InputBind("y", ctx->FwOp().input("y", 0)) \ | ||
.Output("dx") \ |
.Output("dx")
.Output("dy")
这两个 output 是不是需要分别根据 x.need_grad 和 y.need_grad 来设置的
if (x.need_grad()) {
  builder.Output("dx");
}
if (y.need_grad()) {
  builder.Output("dy");
}
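Combining the two points above, a hedged sketch of what the grad-op definition could look like (assuming the user_op::BackwardOpBuilder API used elsewhere in this diff; grad_op_name and the op type name are placeholders):

ctx->DefineOp(grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) {
  builder.OpTypeName("elementwise_maximum_backward")
      .InputBind("dz", ctx->FwOp().output_grad("z", 0))
      .InputBind("x", ctx->FwOp().input("x", 0))
      .InputBind("y", ctx->FwOp().input("y", 0));
  // Emit only the gradient outputs that are actually required.
  if (ctx->FwOp().NeedGenGradTensor4OpInput("x", 0)) { builder.Output("dx"); }
  if (ctx->FwOp().NeedGenGradTensor4OpInput("y", 0)) { builder.Output("dy"); }
  return builder.Build();
});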
}
};
} // namespace
template<typename FunctorT, typename T>
inline cudaError_t XimumGrad(FunctorT functor, int64_t n, T* dx, T* dy const T* x, const T* y, const T* dz,
cudaStream_t stream) {
using FactoryT = SimpleFactory<FunctorT>;
return GenericLauncher<FactoryT, R, A, B, C>::Launch(FactoryT(functor), n, dx, dy, x, y, dz, stream);
}
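A hedged usage sketch of the launcher above (elem_cnt, the pointers, and the functor are placeholders; the functor could be one like the MaximumBackwardFunctor sketched further down this thread, and OF_CUDA_CHECK is OneFlow's CUDA error-check macro):

OF_CUDA_CHECK(XimumGrad(MaximumBackwardFunctor<float>(), elem_cnt, dx, dy, x, y, dz,
                        ctx->device_ctx()->cuda_stream()));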
template<template<typename> class BackwardFunctor, typename T> | ||
__global__ void ElementwiseBackwardGradGpu(int64_t elem_cnt, const T* dz, const T* x, const T* y, | ||
T* dx, T* dy) { | ||
BackwardFunctor<T>::Backward(elem_cnt, dz, x, y, dx, dy); |
CUDA_1D_LOOP(i, elem_cnt) {
  BackwardFunctor<T>::Backward(dz, x, y, dx, dy);
}
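For context, a minimal sketch of what the per-element backward computes (the functor name and the tie-breaking choice at x == y are illustrative, not necessarily what the PR merged): dz flows to whichever input won the elementwise comparison.

template<typename T>
struct MaximumBackwardFunctor {
  // Route dz to x where x >= y, otherwise to y; dx/dy may be null when the
  // corresponding input does not need a gradient.
  __device__ void operator()(T dz, T x, T y, T* dx, T* dy) const {
    if (dx != nullptr) { *dx = (x >= y) ? dz : static_cast<T>(0); }
    if (dy != nullptr) { *dy = (x >= y) ? static_cast<T>(0) : dz; }
  }
};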
* start up of ADD grad for maximum and minimum
* refine batch axis
* add GPU version
* add minimum backward
* add static shape unit test
* add dynamic test
* add sbp and batchaxis infer
* refine files hierarchy
* elementwise maximum and minimum finished
* refine on checking dx/dy if exists
* refine (use template functors)
* refine test case

Co-authored-by: MARD1NO <359521840@qq.com>
Co-authored-by: ZZK <42901638+MARD1NO@users.noreply.github.com>
Co-authored-by: Zailiang <zailiangyu@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

Former-commit-id: 823603c
Overview

To support algorithm development at Zhejiang University of Technology, this PR adds the backward passes for the math.maximum and math.minimum operators. OneFlow's existing versions of these two interfaces support broadcasting, which makes their backward passes complex to implement. As a compromise, this PR adds two new ops, elementwise_maximum and elementwise_minimum, each with forward and backward implemented. The Python frontend then inspects the shapes of x and y and dispatches to either the elementwise or the broadcast op. In other words, once this PR is merged, OneFlow supports the backward pass for the elementwise maximum and minimum, but not yet for the broadcast variants.

Feature Checklist
Note: the feature checkboxes are all optional; if one is left unchecked, just state the reason. For example: an op assembled from Python interfaces needs no SetBatchAxisInferFn registration; likewise, an op with no inputs needs no SetInputArgModifyFn. Checkboxes that come with the template may be left blank but must not be deleted; extra checkbox items may be added as needed.
* Op
* Kernel
  * CPU x:float32 y:float32
  * CPU x:float32 y:float32
  * GPU x:float32 y:float32
  * GPU x:float32 y:float32
* Python Wrapper
Tests
GPU effective bandwidth: not measured, because the kernels are built on the CUDA elementwise template.
PR Checklist
* Type label (bug, enhancement, purge, feature, documentation)
* Module label (op, system, eager, build, xla, python, ci, test, tooling, onnx)