Skip to content

Conversation

huangxu96
Copy link
Contributor

@huangxu96 huangxu96 commented Feb 16, 2022

PR types

Performance optimization

PR changes

OPs

Describe

通过elementwise 接口优化了wehere_op和abs_grad_op。 elementwise 接口打包了一系列性能优化技巧,对于有elementwise行为的op有通用的性能提升。通过重写functor的形式,将代码里的循环遍历元素改写为通过elementwise接口调用functor实现。

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@huangxu96 huangxu96 changed the title Optimize the where_op by the elementwise_op funtion Optimize where_op and abs_grad_op by the elementwise interface Feb 23, 2022
AnnaTrainingG
AnnaTrainingG previously approved these changes Feb 23, 2022
#include "paddle/phi/kernels/abs_grad_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

phi下不能include fluid路径下的文件,参考cast 修改一下

Copy link
Contributor

@AnnaTrainingG AnnaTrainingG Feb 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR描述里面介绍清楚一点做的工作,比如:添加哪些functor,调用哪个Kernel等

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto functor = CondFunctor<T>();
std::vector<const framework::Tensor*> ins = {condition, X, Y};
std::vector<framework::Tensor*> outs = {out};
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议改成phi::funcs的那种调用

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
numel, cond_data, x_data, y_data, out_data);
auto functor = CondFunctor<T>();
std::vector<const framework::Tensor*> ins = {condition, X, Y};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

相关的framework Tensor后续可以改成DensorTensor


namespace paddle {
namespace operators {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下个PR里每个函数加上功能说明

@AnnaTrainingG AnnaTrainingG merged commit c969955 into PaddlePaddle:develop Feb 24, 2022
};

template <typename T>
__global__ void WhereCUDAKernel(const int N, const bool* cond, const T* x,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数可以删除了?


template <typename T>
struct CondFunctor {
HOSTDEVICE inline CondFunctor() {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

默认构造函数,可以不用显式写。


template <typename T>
struct AbsGradCUDAFunctor {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

默认构造函数可以不用显式定义。

};

template <>
struct AbsGradCUDAFunctor<phi::dtype::complex<float>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Functor定义可以简化下,参考:

template <typename T>
struct CudaAbsFunctor<T, phi::funcs::Complex<T, phi::funcs::Real<T>>> {
__device__ __forceinline__ phi::funcs::Real<T> operator()(const T x) const {
return abs(x);
}
};

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants