Skip to content

[Pten] Add reduce mean kernel, replace with mean API #37559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Nov 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e1ea408
add pten reduce kernel
MingMingShangTian Nov 22, 2021
4facf3e
add reduce_sum kernel
MingMingShangTian Nov 22, 2021
b4b0d62
update attribute args and order
MingMingShangTian Nov 22, 2021
1e8e06c
make out dtype undefined
MingMingShangTian Nov 22, 2021
366a3ff
fix empty input error
MingMingShangTian Nov 22, 2021
253dc18
merge develop branch
MingMingShangTian Nov 23, 2021
0033c69
merge develop branch
MingMingShangTian Nov 23, 2021
e67cd57
Merge branch 'develop' into pten_reduce_kernel
MingMingShangTian Nov 23, 2021
3383c3f
rename sum as reduce function
MingMingShangTian Nov 23, 2021
13423cb
rename sum as reduce function
MingMingShangTian Nov 23, 2021
8ab3a40
Merge branch 'develop' into pten_reduce_kernel
MingMingShangTian Nov 23, 2021
3b3ec9b
fix reducekernelImpl args error
MingMingShangTian Nov 23, 2021
e2a5f4b
add reduce cuda kernel
MingMingShangTian Nov 24, 2021
2e0343d
merge develop branch
MingMingShangTian Nov 24, 2021
d4969df
modify dims type to const &
MingMingShangTian Nov 24, 2021
197e62c
remove unused log
MingMingShangTian Nov 24, 2021
5e6cb33
fix reduce_all out eigen function error
MingMingShangTian Nov 24, 2021
c376481
remove unused codes
MingMingShangTian Nov 25, 2021
d200fd6
add the missing sum api define and testcase
MingMingShangTian Nov 25, 2021
33c1b3d
merge develop branch
MingMingShangTian Nov 25, 2021
63e4391
merge develop branch
MingMingShangTian Nov 25, 2021
788e49f
fix sum test axis value error
MingMingShangTian Nov 25, 2021
fa512f4
replace pten mean kernel with reduce_mean
MingMingShangTian Nov 25, 2021
a3437a6
recover mean cuda to original implementation
MingMingShangTian Nov 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 45 additions & 4 deletions paddle/fluid/operators/mean_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ namespace cub = hipcub;
namespace paddle {
namespace operators {

// Element-wise scaling functor: multiplies each value by 1/n. The
// reciprocal is computed once at construction so every application is a
// single multiply (used to turn a device-wide sum into a mean).
template <typename T>
struct DivideFunctor {
  HOSTDEVICE explicit inline DivideFunctor(int n)
      : reciprocal_(static_cast<T>(1.0 / n)) {}

  HOSTDEVICE inline T operator()(const T& x) const { return x * reciprocal_; }

 private:
  T reciprocal_;  // precomputed 1/n, cast to the element type
};

template <typename T>
__global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
Expand All @@ -34,6 +45,37 @@ __global__ void MeanRunKernel(const T* in_data, T* out_data, int N) {
}
}

// CUDA kernel for the "mean" op: reduces every element of input X into a
// single scalar mean written to Out.
//
// Implementation: the input is wrapped in a cub::TransformInputIterator
// that multiplies each element by 1/numel (DivideFunctor), then a
// device-wide cub::DeviceReduce::Sum produces the mean in one pass. cub is
// first invoked with a null workspace pointer, which only computes the
// required temporary-storage size; the workspace is then allocated through
// a framework::Tensor so it follows normal tensor lifetime management.
template <typename DeviceContext, typename T>
class MeanCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<Tensor>("X");
    auto* output = context.Output<Tensor>("Out");

    auto size_prob = input->numel();
    const T* in_data = input->data<T>();
    // Allocate the output exactly once. The previous version called
    // mutable_data<T>() twice in a row; the first result was discarded.
    T* out_data = output->mutable_data<T>(context.GetPlace());
    auto stream = context.cuda_device_context().stream();

    // NOTE(review): numel() returns int64_t but DivideFunctor takes int —
    // confirm inputs stay below INT_MAX or widen the functor's parameter.
    DivideFunctor<T> transformer(size_prob);
    cub::TransformInputIterator<T, DivideFunctor<T>, const T*> trans_x(
        in_data, transformer);
    size_t temp_storage_bytes = 0;

    // First call (null workspace) only fills in temp_storage_bytes.
    auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, trans_x,
                                      out_data, size_prob, stream);
    PADDLE_ENFORCE_CUDA_SUCCESS(err);
    framework::Tensor tmp;
    auto* temp_storage = tmp.mutable_data<uint8_t>(
        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
        context.GetPlace());
    // Second call performs the actual reduction on the given stream.
    err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, trans_x,
                                 out_data, size_prob, stream);
    PADDLE_ENFORCE_CUDA_SUCCESS(err);
  }
};

template <typename DeviceContext, typename T>
class MeanCUDAGradKernel : public framework::OpKernel<T> {
public:
Expand Down Expand Up @@ -62,11 +104,10 @@ class MeanCUDAGradKernel : public framework::OpKernel<T> {

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
mean, ops::MeanKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanKernel<paddle::platform::CUDADeviceContext, plat::float16>);
mean, ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::MeanCUDAKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
mean_grad,
ops::MeanCUDAGradKernel<paddle::platform::CUDADeviceContext, float>,
Expand Down
43 changes: 9 additions & 34 deletions paddle/fluid/operators/mean_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h"

// only can include the headers in paddle/top/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"

namespace paddle {
namespace operators {
Expand All @@ -33,40 +27,21 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

/** [ Why still keep the original kernel implementation? ]
*
* Removal of the original kernel implementation and kernel registration needs
* to ensure that the new kernel mechanism adapts to multiple sets of execution
* mechanisms, including:
*
* 1. Executor and ParallelExecutor
* 2. Dygraph OpBase (Tracer and Engine)
* 3. New Executor
* 4. Predictor
* 5. NPU and XPU lack kernel and need to reuse CPU Kernel
*
* Removal of the original Kernel requires a more complete solution to ensure
* that it will not affect the current execution system.
* Currently, only the first two cases are adapted.
*
* The principle here is that the implementation in the kernel must reuse the
* corresponding functions in the Tensor Operation library and cannot maintain
* two copies of the code.
*/
// CPU/generic kernel for the "mean" op: reduces all elements of input X to
// a single scalar mean written to Out, using Eigen expressions.
//
// NOTE(review): this span in the scraped diff interleaved the removed
// pten-based lines with the restored Eigen lines and had review-comment
// text spliced into the body; it is reconstructed here as the post-merge
// Eigen implementation (the PR reverted mean to its original form).
template <typename DeviceContext, typename T>
class MeanKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<Tensor>("X");
    auto* output = context.Output<Tensor>("Out");

    output->mutable_data<T>(context.GetPlace());

    // Flatten X to a 1-D Eigen vector and reduce it to a rank-0 scalar on
    // the execution context's Eigen device.
    auto X = EigenVector<T>::Flatten(*input);
    auto y = EigenScalar<T>::From(*output);
    auto& place =
        *context.template device_context<DeviceContext>().eigen_device();

    y.device(place) = X.mean();
  }
};

Expand Down
55 changes: 24 additions & 31 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"

// only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/math.h"
#include "paddle/pten/kernels/functions/general/reduce_impl.h"

#if defined(__HIPCC__) || defined(__NVCC__)
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#endif
Expand Down Expand Up @@ -232,43 +239,29 @@ class ReduceKernel : public framework::OpKernel<T> {
bool keep_dim = context.Attr<bool>("keep_dim");
int out_dtype = context.Attr<int>("out_dtype");
framework::proto::VarType::Type cast_out_dtype;

// The dims has full dim, set the reduce_all is True
const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
std::set<int> dims_set(dims.begin(), dims.end());
bool full_dim = true;
for (auto i = 0; i < input_dim_size; i++) {
if (dims_set.find(i) == dims_set.end()) {
full_dim = false;
break;
}
}
reduce_all = (reduce_all || full_dim);
auto* input = context.Input<Tensor>("X");

if (out_dtype < 0) {
auto* cast_input = context.Input<Tensor>("X");
cast_out_dtype =
static_cast<framework::proto::VarType::Type>(cast_input->type());
framework::VisitDataType(
cast_out_dtype,
ReduceKernelFunctor<DeviceContext, T, Functor>(
cast_input, output, dims, keep_dim, reduce_all, context));
static_cast<framework::proto::VarType::Type>(input->type());
} else {
Tensor tmp_tensor;
cast_out_dtype = static_cast<framework::proto::VarType::Type>(out_dtype);
auto* input = context.Input<Tensor>("X");

tmp_tensor.Resize(input->dims());
framework::VisitDataType(
cast_out_dtype,
CastOpFunctor<DeviceContext, T>(
input, &tmp_tensor,
context.template device_context<DeviceContext>()));
framework::VisitDataType(
cast_out_dtype,
ReduceKernelFunctor<DeviceContext, T, Functor>(
&tmp_tensor, output, dims, keep_dim, reduce_all, context));
}

auto& dev_ctx = context.device_context<DeviceContext>();
output->mutable_data(
dev_ctx.GetPlace(),
static_cast<framework::proto::VarType::Type>(cast_out_dtype));

auto pt_x = paddle::experimental::MakePtenDenseTensor(*input);
auto pt_out = paddle::experimental::MakePtenDenseTensor(*output);

std::vector<int64_t> tmp_dims(dims.begin(), dims.end());

// call new kernel
pten::general::Reduce<DeviceContext, T, Functor>(
dev_ctx, *pt_x.get(), reduce_all, tmp_dims, keep_dim,
pten::TransToPtenDataType(cast_out_dtype), pt_out.get());
}
};
template <typename DeviceContext, typename OutT, typename Functor>
Expand Down
9 changes: 8 additions & 1 deletion paddle/pten/api/include/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ namespace experimental {

// TODO(chenweihang): add scale API
// TODO(chenweihang): move mean API into stat.h/cc
PD_DLL_DECL Tensor mean(const Tensor& x);
PD_DLL_DECL Tensor mean(const Tensor& x,
const std::vector<int64_t>& axis,
bool keep_dim);
Comment on lines +24 to +26
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需不需要再提供一个带有默认值的API接口? (Should we also provide an API overload with default argument values?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续PR 再加默认值的接口 (The overload with default values will be added in a follow-up PR.)


PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y);

Expand All @@ -31,5 +33,10 @@ PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y);

PD_DLL_DECL Tensor multiply(const Tensor& x, const Tensor& y);

PD_DLL_DECL Tensor sum(const Tensor& x,
const std::vector<int64_t>& axis,
DataType dtype,
bool keep_dim);

Comment on lines +36 to +40
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续PR 再加默认值的接口

} // namespace experimental
} // namespace paddle
73 changes: 70 additions & 3 deletions paddle/pten/api/lib/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ PT_DECLARE_MODULE(MathCUDA);
namespace paddle {
namespace experimental {

PD_DLL_DECL Tensor mean(const Tensor& x) {
PD_DLL_DECL Tensor mean(const Tensor& x,
const std::vector<int64_t>& axis,
bool keep_dim) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"mean", kernel_key);
"reduce_mean", kernel_key);

// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
Expand All @@ -50,8 +52,73 @@ PD_DLL_DECL Tensor mean(const Tensor& x) {
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(dense_x);

// The real value of reduce_all will be get in kernel
// so use default value(false) is OK.
bool reduce_all = false;

DataType out_dtype = DataType::UNDEFINED;

kernel_context.EmplaceBackAttr(axis);
kernel_context.EmplaceBackAttr(keep_dim);
kernel_context.EmplaceBackAttr(reduce_all);
kernel_context.EmplaceBackAttr(dense_x->dtype());
kernel_context.EmplaceBackAttr(out_dtype);

// 4. InferShape
auto out_meta = ReduceInferMeta(dense_x->meta(), axis, keep_dim);

// 5. Prepare outputs
Tensor out;
const auto allocator =
std::make_shared<paddle::experimental::DefaultAllocator>(
pten::TransToFluidPlace(kernel_key.backend()));
auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
kernel_context.EmplaceBackOutput(dense_out);
out.set_impl(dense_out);

// 6. Call kernel
kernel(&kernel_context);

return out;
}

PD_DLL_DECL Tensor sum(const Tensor& x,
const std::vector<int64_t>& axis,
DataType dtype,
bool keep_dim) {
// 1. Get kernel signature and kernel
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"reduce_sum", kernel_key);

// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(dev_ctx);

// 3. Auto data transform
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackInput(dense_x);

// The real value of reduce_all will be get in kernel
// so use default value(false) is OK.
bool reduce_all = false;

DataType out_dtype = DataType::UNDEFINED;
if (dense_x->dtype() == DataType::BOOL ||
dense_x->dtype() == DataType::INT32 ||
dense_x->dtype() == DataType::INT64) {
out_dtype = DataType::INT64;
}

kernel_context.EmplaceBackAttr(axis);
kernel_context.EmplaceBackAttr(keep_dim);
kernel_context.EmplaceBackAttr(reduce_all);
kernel_context.EmplaceBackAttr(dense_x->dtype());
kernel_context.EmplaceBackAttr(out_dtype);

// 4. InferMeta
auto out_meta = ReductionInferMeta(dense_x->meta());
auto out_meta = ReduceInferMeta(dense_x->meta(), axis, keep_dim);

// 5. Prepare outputs
Tensor out;
Expand Down
37 changes: 34 additions & 3 deletions paddle/pten/include/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,44 @@ DenseTensor Sign(const ContextT& dev_ctx, const DenseTensor& x) {
}

// Reduce-mean convenience wrapper: allocates an output DenseTensor whose
// meta follows ReduceInferMeta(x.meta(), axis, keep_dim) and forwards to
// the pten Mean kernel.
//
// The scraped diff left the pre-change signature and call lines interleaved
// with the new ones; only the post-merge version is kept here.
template <typename T, typename ContextT>
DenseTensor Mean(const ContextT& dev_ctx,
                 const DenseTensor& x,
                 const std::vector<int64_t>& axis,
                 bool keep_dim) {
  auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim);
  const auto allocator =
      std::make_shared<paddle::experimental::DefaultAllocator>(
          dev_ctx.GetPlace());
  pten::DenseTensor dense_out(allocator, out_meta);
  // The real value of reduce_all is determined inside the kernel, so the
  // default (false) is sufficient here.
  bool reduce_all = false;
  DataType out_dtype = pten::DataType::UNDEFINED;
  Mean<T>(
      dev_ctx, x, axis, keep_dim, reduce_all, x.dtype(), out_dtype, &dense_out);
  return dense_out;
}

// Reduce-sum convenience wrapper: allocates an output DenseTensor whose
// meta follows ReduceInferMeta(x.meta(), axis, keep_dim) and forwards to
// the pten Sum kernel.
//
// NOTE(review): the scraped diff had review-comment text spliced into this
// body; it is reconstructed here without that chrome, logic unchanged.
template <typename T, typename ContextT>
DenseTensor Sum(const ContextT& dev_ctx,
                const DenseTensor& x,
                const std::vector<int64_t>& axis,
                DataType dtype,
                bool keep_dim) {
  auto out_meta = ReduceInferMeta(x.meta(), axis, keep_dim);
  const auto allocator =
      std::make_shared<paddle::experimental::DefaultAllocator>(
          dev_ctx.GetPlace());
  pten::DenseTensor dense_out(allocator, out_meta);

  // The real value of reduce_all will be get in kernel
  // so use default value(false) is OK.
  bool reduce_all = false;

  // bool/int32/int64 inputs accumulate as int64, overriding the caller's
  // dtype. NOTE(review): reviewers asked for this promotion to live in
  // InferMeta or the kernel instead — slated for a follow-up PR.
  if (x.dtype() == pten::DataType::BOOL || x.dtype() == pten::DataType::INT32 ||
      x.dtype() == pten::DataType::INT64) {
    dtype = pten::DataType::INT64;
  }

  Sum<T>(dev_ctx, x, axis, keep_dim, reduce_all, x.dtype(), dtype, &dense_out);
  return dense_out;
}

Expand Down
Loading